Repository: GrandaddyShmax/audiocraft_plus
Branch: main
Commit: c9ff851e7b50
Files: 227
Total size: 1.1 MB
Directory structure:
gitextract_7_a5iyu7/
├── .github/
│ └── actions/
│ └── audiocraft_build/
│ └── action.yml
├── .gitignore
├── CHANGELOG.md
├── CODE_OF_CONDUCT.md
├── CONTRIBUTING.md
├── Dockerfile
├── LICENSE
├── LICENSE_weights
├── MANIFEST.in
├── Makefile
├── README.md
├── app.py
├── audiocraft/
│ ├── __init__.py
│ ├── adversarial/
│ │ ├── __init__.py
│ │ ├── discriminators/
│ │ │ ├── __init__.py
│ │ │ ├── base.py
│ │ │ ├── mpd.py
│ │ │ ├── msd.py
│ │ │ └── msstftd.py
│ │ └── losses.py
│ ├── data/
│ │ ├── __init__.py
│ │ ├── audio.py
│ │ ├── audio_dataset.py
│ │ ├── audio_utils.py
│ │ ├── info_audio_dataset.py
│ │ ├── music_dataset.py
│ │ ├── sound_dataset.py
│ │ └── zip.py
│ ├── environment.py
│ ├── grids/
│ │ ├── __init__.py
│ │ ├── _base_explorers.py
│ │ ├── audiogen/
│ │ │ ├── __init__.py
│ │ │ ├── audiogen_base_16khz.py
│ │ │ └── audiogen_pretrained_16khz_eval.py
│ │ ├── compression/
│ │ │ ├── __init__.py
│ │ │ ├── _explorers.py
│ │ │ ├── debug.py
│ │ │ ├── encodec_audiogen_16khz.py
│ │ │ ├── encodec_base_24khz.py
│ │ │ └── encodec_musicgen_32khz.py
│ │ ├── diffusion/
│ │ │ ├── 4_bands_base_32khz.py
│ │ │ ├── __init__.py
│ │ │ └── _explorers.py
│ │ └── musicgen/
│ │ ├── __init__.py
│ │ ├── _explorers.py
│ │ ├── musicgen_base_32khz.py
│ │ ├── musicgen_base_cached_32khz.py
│ │ ├── musicgen_clapemb_32khz.py
│ │ ├── musicgen_melody_32khz.py
│ │ └── musicgen_pretrained_32khz_eval.py
│ ├── losses/
│ │ ├── __init__.py
│ │ ├── balancer.py
│ │ ├── sisnr.py
│ │ ├── specloss.py
│ │ └── stftloss.py
│ ├── metrics/
│ │ ├── __init__.py
│ │ ├── chroma_cosinesim.py
│ │ ├── clap_consistency.py
│ │ ├── fad.py
│ │ ├── kld.py
│ │ ├── rvm.py
│ │ └── visqol.py
│ ├── models/
│ │ ├── __init__.py
│ │ ├── audiogen.py
│ │ ├── builders.py
│ │ ├── encodec.py
│ │ ├── lm.py
│ │ ├── loaders.py
│ │ ├── multibanddiffusion.py
│ │ ├── musicgen.py
│ │ └── unet.py
│ ├── modules/
│ │ ├── __init__.py
│ │ ├── activations.py
│ │ ├── chroma.py
│ │ ├── codebooks_patterns.py
│ │ ├── conditioners.py
│ │ ├── conv.py
│ │ ├── diffusion_schedule.py
│ │ ├── lstm.py
│ │ ├── rope.py
│ │ ├── seanet.py
│ │ ├── streaming.py
│ │ └── transformer.py
│ ├── optim/
│ │ ├── __init__.py
│ │ ├── cosine_lr_scheduler.py
│ │ ├── dadam.py
│ │ ├── ema.py
│ │ ├── fsdp.py
│ │ ├── inverse_sqrt_lr_scheduler.py
│ │ ├── linear_warmup_lr_scheduler.py
│ │ └── polynomial_decay_lr_scheduler.py
│ ├── py.typed
│ ├── quantization/
│ │ ├── __init__.py
│ │ ├── base.py
│ │ ├── core_vq.py
│ │ └── vq.py
│ ├── solvers/
│ │ ├── __init__.py
│ │ ├── audiogen.py
│ │ ├── base.py
│ │ ├── builders.py
│ │ ├── compression.py
│ │ ├── diffusion.py
│ │ └── musicgen.py
│ ├── train.py
│ └── utils/
│ ├── __init__.py
│ ├── autocast.py
│ ├── best_state.py
│ ├── cache.py
│ ├── checkpoint.py
│ ├── cluster.py
│ ├── deadlock.py
│ ├── export.py
│ ├── export_legacy.py
│ ├── notebook.py
│ ├── profiler.py
│ ├── samples/
│ │ ├── __init__.py
│ │ └── manager.py
│ ├── ui.py
│ └── utils.py
├── config/
│ ├── conditioner/
│ │ ├── chroma2music.yaml
│ │ ├── clapemb2music.yaml
│ │ ├── none.yaml
│ │ ├── text2music.yaml
│ │ └── text2sound.yaml
│ ├── config.yaml
│ ├── dset/
│ │ ├── audio/
│ │ │ ├── audiocaps_16khz.yaml
│ │ │ ├── default.yaml
│ │ │ ├── example.yaml
│ │ │ └── musiccaps_32khz.yaml
│ │ ├── default.yaml
│ │ └── internal/
│ │ ├── music_10k_32khz.yaml
│ │ ├── music_400k_32khz.yaml
│ │ └── sounds_16khz.yaml
│ ├── model/
│ │ ├── encodec/
│ │ │ ├── default.yaml
│ │ │ ├── encodec_base_causal.yaml
│ │ │ ├── encodec_large_nq4_s320.yaml
│ │ │ └── encodec_large_nq4_s640.yaml
│ │ ├── lm/
│ │ │ ├── audiogen_lm.yaml
│ │ │ ├── default.yaml
│ │ │ ├── model_scale/
│ │ │ │ ├── base.yaml
│ │ │ │ ├── large.yaml
│ │ │ │ ├── medium.yaml
│ │ │ │ ├── small.yaml
│ │ │ │ └── xsmall.yaml
│ │ │ └── musicgen_lm.yaml
│ │ ├── none.yaml
│ │ └── score/
│ │ └── basic.yaml
│ ├── solver/
│ │ ├── audiogen/
│ │ │ ├── audiogen_base_16khz.yaml
│ │ │ ├── debug.yaml
│ │ │ ├── default.yaml
│ │ │ └── evaluation/
│ │ │ ├── none.yaml
│ │ │ └── objective_eval.yaml
│ │ ├── compression/
│ │ │ ├── debug.yaml
│ │ │ ├── default.yaml
│ │ │ ├── encodec_audiogen_16khz.yaml
│ │ │ ├── encodec_base_24khz.yaml
│ │ │ └── encodec_musicgen_32khz.yaml
│ │ ├── default.yaml
│ │ ├── diffusion/
│ │ │ ├── debug.yaml
│ │ │ ├── default.yaml
│ │ │ └── encodec_24khz.yaml
│ │ └── musicgen/
│ │ ├── debug.yaml
│ │ ├── default.yaml
│ │ ├── evaluation/
│ │ │ ├── none.yaml
│ │ │ └── objective_eval.yaml
│ │ ├── musicgen_base_32khz.yaml
│ │ └── musicgen_melody_32khz.yaml
│ └── teams/
│ ├── default.yaml
│ └── labs.yaml
├── dataset/
│ └── example/
│ ├── electro_1.json
│ └── electro_2.json
├── demos/
│ ├── audiogen_demo.ipynb
│ ├── musicgen_app.py
│ └── musicgen_demo.ipynb
├── dockerignore
├── docs/
│ ├── AUDIOGEN.md
│ ├── CONDITIONING.md
│ ├── DATASETS.md
│ ├── ENCODEC.md
│ ├── MBD.md
│ ├── METRICS.md
│ ├── MUSICGEN.md
│ └── TRAINING.md
├── egs/
│ └── example/
│ └── data.jsonl
├── model_cards/
│ ├── AUDIOGEN_MODEL_CARD.md
│ └── MUSICGEN_MODEL_CARD.md
├── models/
│ └── Put your models here.txt
├── mypy.ini
├── requirements.txt
├── scripts/
│ ├── __init__.py
│ ├── mos.py
│ ├── resample_dataset.py
│ ├── static/
│ │ └── style.css
│ └── templates/
│ ├── base.html
│ ├── index.html
│ ├── login.html
│ ├── results.html
│ └── survey.html
├── setup.cfg
├── setup.py
└── tests/
├── __init__.py
├── adversarial/
│ ├── __init__.py
│ ├── test_discriminators.py
│ └── test_losses.py
├── common_utils/
│ ├── __init__.py
│ ├── temp_utils.py
│ └── wav_utils.py
├── data/
│ ├── __init__.py
│ ├── test_audio.py
│ ├── test_audio_dataset.py
│ └── test_audio_utils.py
├── losses/
│ ├── __init__.py
│ └── test_losses.py
├── models/
│ ├── test_audiogen.py
│ ├── test_encodec_model.py
│ ├── test_multibanddiffusion.py
│ └── test_musicgen.py
├── modules/
│ ├── __init__.py
│ ├── test_activations.py
│ ├── test_codebooks_patterns.py
│ ├── test_conv.py
│ ├── test_lstm.py
│ ├── test_rope.py
│ ├── test_seanet.py
│ └── test_transformer.py
├── quantization/
│ └── test_vq.py
└── utils/
└── __init__.py
================================================
FILE CONTENTS
================================================
================================================
FILE: .github/actions/audiocraft_build/action.yml
================================================
name: audiocraft_build
description: 'Build audiocraft env.'
runs:
using: "composite"
steps:
- uses: actions/setup-python@v2
with:
python-version: 3.8
- uses: actions/cache@v2
id: cache
with:
path: env
key: audiocraft_env-${{ hashFiles('**/requirements.txt') }}
- if: ${{ steps.cache.outputs.cache-hit != 'true' }}
name: Install dependencies
shell: bash
run: |
sudo apt-get update
sudo apt-get install libsndfile1-dev ffmpeg
python3 -m venv env
. env/bin/activate
python -m pip install --upgrade pip
pip install -e '.[dev]'
- name: System Dependencies
shell: bash
run: |
sudo apt-get update
sudo apt-get install libsndfile1-dev ffmpeg
================================================
FILE: .gitignore
================================================
# Byte-compiled / optimized / DLL files
__pycache__
*.py[cod]
*$py.class
# C extensions
*.so
# macOS dir files
.DS_Store
# Distribution / packaging
.Python
env/
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
*.egg-info/
.installed.cfg
*.egg
.ipynb_checkpoints
# Tests and linter
.pytest_cache/
.mypy_cache/
.coverage
# docs
/api_docs
# dotenv
.env
.envrc
# virtualenv
.venv
venv/
ENV/
# egs with manifest files
egs/*
!egs/example
# local datasets
dataset/*
!dataset/example
# personal notebooks & scripts
*/local_scripts
*/notes
.vscode/
/notebooks
/local_scripts
/notes
/cache
================================================
FILE: CHANGELOG.md
================================================
# Changelog
All notable changes to this project will be documented in this file.
The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/).
## [1.0.0] - 2023-08-02
Major revision, added training code for EnCodec, AudioGen, MusicGen, and MultiBandDiffusion.
Added pretrained model for AudioGen and MultiBandDiffusion.
## [0.0.2] - 2023-08-01
Improved demo, fixed top p (thanks @jnordberg).
Compressor tanh on output to avoid clipping with some style (especially piano).
Now repeating the conditioning periodically if it is too short.
More options when launching Gradio app locally (thanks @ashleykleynhans).
Testing out PyTorch 2.0 memory efficient attention.
Added extended generation (infinite length) by slowly moving the windows.
Note that other implementations exist: https://github.com/camenduru/MusicGen-colab.
## [0.0.1] - 2023-06-09
Initial release, with model evaluation only.
================================================
FILE: CODE_OF_CONDUCT.md
================================================
# Code of Conduct
## Our Pledge
In the interest of fostering an open and welcoming environment, we as
contributors and maintainers pledge to make participation in our project and
our community a harassment-free experience for everyone, regardless of age, body
size, disability, ethnicity, sex characteristics, gender identity and expression,
level of experience, education, socio-economic status, nationality, personal
appearance, race, religion, or sexual identity and orientation.
## Our Standards
Examples of behavior that contributes to creating a positive environment
include:
* Using welcoming and inclusive language
* Being respectful of differing viewpoints and experiences
* Gracefully accepting constructive criticism
* Focusing on what is best for the community
* Showing empathy towards other community members
Examples of unacceptable behavior by participants include:
* The use of sexualized language or imagery and unwelcome sexual attention or
advances
* Trolling, insulting/derogatory comments, and personal or political attacks
* Public or private harassment
* Publishing others' private information, such as a physical or electronic
address, without explicit permission
* Other conduct which could reasonably be considered inappropriate in a
professional setting
## Our Responsibilities
Project maintainers are responsible for clarifying the standards of acceptable
behavior and are expected to take appropriate and fair corrective action in
response to any instances of unacceptable behavior.
Project maintainers have the right and responsibility to remove, edit, or
reject comments, commits, code, wiki edits, issues, and other contributions
that are not aligned to this Code of Conduct, or to ban temporarily or
permanently any contributor for other behaviors that they deem inappropriate,
threatening, offensive, or harmful.
## Scope
This Code of Conduct applies within all project spaces, and it also applies when
an individual is representing the project or its community in public spaces.
Examples of representing a project or community include using an official
project e-mail address, posting via an official social media account, or acting
as an appointed representative at an online or offline event. Representation of
a project may be further defined and clarified by project maintainers.
This Code of Conduct also applies outside the project spaces when there is a
reasonable belief that an individual's behavior may have a negative impact on
the project or its community.
## Enforcement
Instances of abusive, harassing, or otherwise unacceptable behavior may be
reported by contacting the project team at <opensource-conduct@fb.com>. All
complaints will be reviewed and investigated and will result in a response that
is deemed necessary and appropriate to the circumstances. The project team is
obligated to maintain confidentiality with regard to the reporter of an incident.
Further details of specific enforcement policies may be posted separately.
Project maintainers who do not follow or enforce the Code of Conduct in good
faith may face temporary or permanent repercussions as determined by other
members of the project's leadership.
## Attribution
This Code of Conduct is adapted from the [Contributor Covenant][homepage], version 1.4,
available at https://www.contributor-covenant.org/version/1/4/code-of-conduct.html
[homepage]: https://www.contributor-covenant.org
For answers to common questions about this code of conduct, see
https://www.contributor-covenant.org/faq
================================================
FILE: CONTRIBUTING.md
================================================
# Contributing to AudioCraft
We want to make contributing to this project as easy and transparent as
possible.
## Pull Requests
AudioCraft is the implementation of a research paper.
Therefore, we do not plan on accepting many pull requests for new features.
We certainly welcome them for bug fixes.
1. Fork the repo and create your branch from `main`.
2. If you've added code that should be tested, add tests.
3. If you've changed APIs, update the documentation.
4. Ensure the test suite passes.
5. Make sure your code lints.
6. If you haven't already, complete the Contributor License Agreement ("CLA").
## Contributor License Agreement ("CLA")
In order to accept your pull request, we need you to submit a CLA. You only need
to do this once to work on any of Meta's open source projects.
Complete your CLA here: <https://code.facebook.com/cla>
## Issues
We use GitHub issues to track public bugs. Please ensure your description is
clear and has sufficient instructions to be able to reproduce the issue.
Meta has a [bounty program](https://www.facebook.com/whitehat/) for the safe
disclosure of security bugs. In those cases, please go through the process
outlined on that page and do not file a public issue.
## License
By contributing to AudioCraft, you agree that your contributions will be licensed
under the LICENSE file in the root directory of this source tree.
================================================
FILE: Dockerfile
================================================
FROM nvidia/cuda:11.8.0-base-ubuntu22.04
ENV DEBIAN_FRONTEND=noninteractive \
PYTHONUNBUFFERED=1 \
PYTHONIOENCODING=UTF-8
RUN --mount=type=cache,target=/var/cache/apt --mount=type=cache,target=/var/lib/apt apt update &&\
apt install -y \
wget \
git \
pkg-config \
python3 \
python3-pip \
python-is-python3 \
ffmpeg \
libnvrtc11.2 \
libtcmalloc-minimal4
RUN useradd -m -u 1000 ac
RUN --mount=type=cache,target=/root/.cache python -m pip install --upgrade pip wheel
ENV TORCH_COMMAND="pip install torch==2.0.1+cu118 torchaudio --extra-index-url https://download.pytorch.org/whl/cu118"
RUN --mount=type=cache,target=/root/.cache python -m $TORCH_COMMAND
RUN ln -s /usr/lib/x86_64-linux-gnu/libnvrtc.so.11.2 /usr/lib/x86_64-linux-gnu/libnvrtc.so
USER 1000
RUN mkdir ~/.cache
RUN --mount=type=cache,target=/home/ac/.cache --mount=source=.,target=/home/ac/audiocraft python -m pip install -r /home/ac/audiocraft/requirements.txt
WORKDIR /home/ac/audiocraft
================================================
FILE: LICENSE
================================================
MIT License
Copyright (c) Meta Platforms, Inc. and affiliates.
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
================================================
FILE: LICENSE_weights
================================================
Attribution-NonCommercial 4.0 International
=======================================================================
Creative Commons Corporation ("Creative Commons") is not a law firm and
does not provide legal services or legal advice. Distribution of
Creative Commons public licenses does not create a lawyer-client or
other relationship. Creative Commons makes its licenses and related
information available on an "as-is" basis. Creative Commons gives no
warranties regarding its licenses, any material licensed under their
terms and conditions, or any related information. Creative Commons
disclaims all liability for damages resulting from their use to the
fullest extent possible.
Using Creative Commons Public Licenses
Creative Commons public licenses provide a standard set of terms and
conditions that creators and other rights holders may use to share
original works of authorship and other material subject to copyright
and certain other rights specified in the public license below. The
following considerations are for informational purposes only, are not
exhaustive, and do not form part of our licenses.
Considerations for licensors: Our public licenses are
intended for use by those authorized to give the public
permission to use material in ways otherwise restricted by
copyright and certain other rights. Our licenses are
irrevocable. Licensors should read and understand the terms
and conditions of the license they choose before applying it.
Licensors should also secure all rights necessary before
applying our licenses so that the public can reuse the
material as expected. Licensors should clearly mark any
material not subject to the license. This includes other CC-
licensed material, or material used under an exception or
limitation to copyright. More considerations for licensors:
wiki.creativecommons.org/Considerations_for_licensors
Considerations for the public: By using one of our public
licenses, a licensor grants the public permission to use the
licensed material under specified terms and conditions. If
the licensor's permission is not necessary for any reason--for
example, because of any applicable exception or limitation to
copyright--then that use is not regulated by the license. Our
licenses grant only permissions under copyright and certain
other rights that a licensor has authority to grant. Use of
the licensed material may still be restricted for other
reasons, including because others have copyright or other
rights in the material. A licensor may make special requests,
such as asking that all changes be marked or described.
Although not required by our licenses, you are encouraged to
respect those requests where reasonable. More_considerations
for the public:
wiki.creativecommons.org/Considerations_for_licensees
=======================================================================
Creative Commons Attribution-NonCommercial 4.0 International Public
License
By exercising the Licensed Rights (defined below), You accept and agree
to be bound by the terms and conditions of this Creative Commons
Attribution-NonCommercial 4.0 International Public License ("Public
License"). To the extent this Public License may be interpreted as a
contract, You are granted the Licensed Rights in consideration of Your
acceptance of these terms and conditions, and the Licensor grants You
such rights in consideration of benefits the Licensor receives from
making the Licensed Material available under these terms and
conditions.
Section 1 -- Definitions.
a. Adapted Material means material subject to Copyright and Similar
Rights that is derived from or based upon the Licensed Material
and in which the Licensed Material is translated, altered,
arranged, transformed, or otherwise modified in a manner requiring
permission under the Copyright and Similar Rights held by the
Licensor. For purposes of this Public License, where the Licensed
Material is a musical work, performance, or sound recording,
Adapted Material is always produced where the Licensed Material is
synched in timed relation with a moving image.
b. Adapter's License means the license You apply to Your Copyright
and Similar Rights in Your contributions to Adapted Material in
accordance with the terms and conditions of this Public License.
c. Copyright and Similar Rights means copyright and/or similar rights
closely related to copyright including, without limitation,
performance, broadcast, sound recording, and Sui Generis Database
Rights, without regard to how the rights are labeled or
categorized. For purposes of this Public License, the rights
specified in Section 2(b)(1)-(2) are not Copyright and Similar
Rights.
d. Effective Technological Measures means those measures that, in the
absence of proper authority, may not be circumvented under laws
fulfilling obligations under Article 11 of the WIPO Copyright
Treaty adopted on December 20, 1996, and/or similar international
agreements.
e. Exceptions and Limitations means fair use, fair dealing, and/or
any other exception or limitation to Copyright and Similar Rights
that applies to Your use of the Licensed Material.
f. Licensed Material means the artistic or literary work, database,
or other material to which the Licensor applied this Public
License.
g. Licensed Rights means the rights granted to You subject to the
terms and conditions of this Public License, which are limited to
all Copyright and Similar Rights that apply to Your use of the
Licensed Material and that the Licensor has authority to license.
h. Licensor means the individual(s) or entity(ies) granting rights
under this Public License.
i. NonCommercial means not primarily intended for or directed towards
commercial advantage or monetary compensation. For purposes of
this Public License, the exchange of the Licensed Material for
other material subject to Copyright and Similar Rights by digital
file-sharing or similar means is NonCommercial provided there is
no payment of monetary compensation in connection with the
exchange.
j. Share means to provide material to the public by any means or
process that requires permission under the Licensed Rights, such
as reproduction, public display, public performance, distribution,
dissemination, communication, or importation, and to make material
available to the public including in ways that members of the
public may access the material from a place and at a time
individually chosen by them.
k. Sui Generis Database Rights means rights other than copyright
resulting from Directive 96/9/EC of the European Parliament and of
the Council of 11 March 1996 on the legal protection of databases,
as amended and/or succeeded, as well as other essentially
equivalent rights anywhere in the world.
l. You means the individual or entity exercising the Licensed Rights
under this Public License. Your has a corresponding meaning.
Section 2 -- Scope.
a. License grant.
1. Subject to the terms and conditions of this Public License,
the Licensor hereby grants You a worldwide, royalty-free,
non-sublicensable, non-exclusive, irrevocable license to
exercise the Licensed Rights in the Licensed Material to:
a. reproduce and Share the Licensed Material, in whole or
in part, for NonCommercial purposes only; and
b. produce, reproduce, and Share Adapted Material for
NonCommercial purposes only.
2. Exceptions and Limitations. For the avoidance of doubt, where
Exceptions and Limitations apply to Your use, this Public
License does not apply, and You do not need to comply with
its terms and conditions.
3. Term. The term of this Public License is specified in Section
6(a).
4. Media and formats; technical modifications allowed. The
Licensor authorizes You to exercise the Licensed Rights in
all media and formats whether now known or hereafter created,
and to make technical modifications necessary to do so. The
Licensor waives and/or agrees not to assert any right or
authority to forbid You from making technical modifications
necessary to exercise the Licensed Rights, including
technical modifications necessary to circumvent Effective
Technological Measures. For purposes of this Public License,
simply making modifications authorized by this Section 2(a)
(4) never produces Adapted Material.
5. Downstream recipients.
a. Offer from the Licensor -- Licensed Material. Every
recipient of the Licensed Material automatically
receives an offer from the Licensor to exercise the
Licensed Rights under the terms and conditions of this
Public License.
b. No downstream restrictions. You may not offer or impose
any additional or different terms or conditions on, or
apply any Effective Technological Measures to, the
Licensed Material if doing so restricts exercise of the
Licensed Rights by any recipient of the Licensed
Material.
6. No endorsement. Nothing in this Public License constitutes or
may be construed as permission to assert or imply that You
are, or that Your use of the Licensed Material is, connected
with, or sponsored, endorsed, or granted official status by,
the Licensor or others designated to receive attribution as
provided in Section 3(a)(1)(A)(i).
b. Other rights.
1. Moral rights, such as the right of integrity, are not
licensed under this Public License, nor are publicity,
privacy, and/or other similar personality rights; however, to
the extent possible, the Licensor waives and/or agrees not to
assert any such rights held by the Licensor to the limited
extent necessary to allow You to exercise the Licensed
Rights, but not otherwise.
2. Patent and trademark rights are not licensed under this
Public License.
3. To the extent possible, the Licensor waives any right to
collect royalties from You for the exercise of the Licensed
Rights, whether directly or through a collecting society
under any voluntary or waivable statutory or compulsory
licensing scheme. In all other cases the Licensor expressly
reserves any right to collect such royalties, including when
the Licensed Material is used other than for NonCommercial
purposes.
Section 3 -- License Conditions.
Your exercise of the Licensed Rights is expressly made subject to the
following conditions.
a. Attribution.
1. If You Share the Licensed Material (including in modified
form), You must:
a. retain the following if it is supplied by the Licensor
with the Licensed Material:
i. identification of the creator(s) of the Licensed
Material and any others designated to receive
attribution, in any reasonable manner requested by
the Licensor (including by pseudonym if
designated);
ii. a copyright notice;
iii. a notice that refers to this Public License;
iv. a notice that refers to the disclaimer of
warranties;
v. a URI or hyperlink to the Licensed Material to the
extent reasonably practicable;
b. indicate if You modified the Licensed Material and
retain an indication of any previous modifications; and
c. indicate the Licensed Material is licensed under this
Public License, and include the text of, or the URI or
hyperlink to, this Public License.
2. You may satisfy the conditions in Section 3(a)(1) in any
reasonable manner based on the medium, means, and context in
which You Share the Licensed Material. For example, it may be
reasonable to satisfy the conditions by providing a URI or
hyperlink to a resource that includes the required
information.
3. If requested by the Licensor, You must remove any of the
information required by Section 3(a)(1)(A) to the extent
reasonably practicable.
4. If You Share Adapted Material You produce, the Adapter's
License You apply must not prevent recipients of the Adapted
Material from complying with this Public License.
Section 4 -- Sui Generis Database Rights.
Where the Licensed Rights include Sui Generis Database Rights that
apply to Your use of the Licensed Material:
a. for the avoidance of doubt, Section 2(a)(1) grants You the right
to extract, reuse, reproduce, and Share all or a substantial
portion of the contents of the database for NonCommercial purposes
only;
b. if You include all or a substantial portion of the database
contents in a database in which You have Sui Generis Database
Rights, then the database in which You have Sui Generis Database
Rights (but not its individual contents) is Adapted Material; and
c. You must comply with the conditions in Section 3(a) if You Share
all or a substantial portion of the contents of the database.
For the avoidance of doubt, this Section 4 supplements and does not
replace Your obligations under this Public License where the Licensed
Rights include other Copyright and Similar Rights.
Section 5 -- Disclaimer of Warranties and Limitation of Liability.
a. UNLESS OTHERWISE SEPARATELY UNDERTAKEN BY THE LICENSOR, TO THE
EXTENT POSSIBLE, THE LICENSOR OFFERS THE LICENSED MATERIAL AS-IS
AND AS-AVAILABLE, AND MAKES NO REPRESENTATIONS OR WARRANTIES OF
ANY KIND CONCERNING THE LICENSED MATERIAL, WHETHER EXPRESS,
IMPLIED, STATUTORY, OR OTHER. THIS INCLUDES, WITHOUT LIMITATION,
WARRANTIES OF TITLE, MERCHANTABILITY, FITNESS FOR A PARTICULAR
PURPOSE, NON-INFRINGEMENT, ABSENCE OF LATENT OR OTHER DEFECTS,
ACCURACY, OR THE PRESENCE OR ABSENCE OF ERRORS, WHETHER OR NOT
KNOWN OR DISCOVERABLE. WHERE DISCLAIMERS OF WARRANTIES ARE NOT
ALLOWED IN FULL OR IN PART, THIS DISCLAIMER MAY NOT APPLY TO YOU.
b. TO THE EXTENT POSSIBLE, IN NO EVENT WILL THE LICENSOR BE LIABLE
TO YOU ON ANY LEGAL THEORY (INCLUDING, WITHOUT LIMITATION,
NEGLIGENCE) OR OTHERWISE FOR ANY DIRECT, SPECIAL, INDIRECT,
INCIDENTAL, CONSEQUENTIAL, PUNITIVE, EXEMPLARY, OR OTHER LOSSES,
COSTS, EXPENSES, OR DAMAGES ARISING OUT OF THIS PUBLIC LICENSE OR
USE OF THE LICENSED MATERIAL, EVEN IF THE LICENSOR HAS BEEN
ADVISED OF THE POSSIBILITY OF SUCH LOSSES, COSTS, EXPENSES, OR
DAMAGES. WHERE A LIMITATION OF LIABILITY IS NOT ALLOWED IN FULL OR
IN PART, THIS LIMITATION MAY NOT APPLY TO YOU.
c. The disclaimer of warranties and limitation of liability provided
above shall be interpreted in a manner that, to the extent
possible, most closely approximates an absolute disclaimer and
waiver of all liability.
Section 6 -- Term and Termination.
a. This Public License applies for the term of the Copyright and
Similar Rights licensed here. However, if You fail to comply with
this Public License, then Your rights under this Public License
terminate automatically.
b. Where Your right to use the Licensed Material has terminated under
Section 6(a), it reinstates:
1. automatically as of the date the violation is cured, provided
it is cured within 30 days of Your discovery of the
violation; or
2. upon express reinstatement by the Licensor.
For the avoidance of doubt, this Section 6(b) does not affect any
right the Licensor may have to seek remedies for Your violations
of this Public License.
c. For the avoidance of doubt, the Licensor may also offer the
Licensed Material under separate terms or conditions or stop
distributing the Licensed Material at any time; however, doing so
will not terminate this Public License.
d. Sections 1, 5, 6, 7, and 8 survive termination of this Public
License.
Section 7 -- Other Terms and Conditions.
a. The Licensor shall not be bound by any additional or different
terms or conditions communicated by You unless expressly agreed.
b. Any arrangements, understandings, or agreements regarding the
Licensed Material not stated herein are separate from and
independent of the terms and conditions of this Public License.
Section 8 -- Interpretation.
a. For the avoidance of doubt, this Public License does not, and
shall not be interpreted to, reduce, limit, restrict, or impose
conditions on any use of the Licensed Material that could lawfully
be made without permission under this Public License.
b. To the extent possible, if any provision of this Public License is
deemed unenforceable, it shall be automatically reformed to the
minimum extent necessary to make it enforceable. If the provision
cannot be reformed, it shall be severed from this Public License
without affecting the enforceability of the remaining terms and
conditions.
c. No term or condition of this Public License will be waived and no
failure to comply consented to unless expressly agreed to by the
Licensor.
d. Nothing in this Public License constitutes or may be interpreted
as a limitation upon, or waiver of, any privileges and immunities
that apply to the Licensor or You, including from the legal
processes of any jurisdiction or authority.
=======================================================================
Creative Commons is not a party to its public
licenses. Notwithstanding, Creative Commons may elect to apply one of
its public licenses to material it publishes and in those instances
will be considered the “Licensor.” The text of the Creative Commons
public licenses is dedicated to the public domain under the CC0 Public
Domain Dedication. Except for the limited purpose of indicating that
material is shared under a Creative Commons public license or as
otherwise permitted by the Creative Commons policies published at
creativecommons.org/policies, Creative Commons does not authorize the
use of the trademark "Creative Commons" or any other trademark or logo
of Creative Commons without its prior written consent including,
without limitation, in connection with any unauthorized modifications
to any of its public licenses or any other arrangements,
understandings, or agreements concerning use of licensed material. For
the avoidance of doubt, this paragraph does not form part of the
public licenses.
Creative Commons may be contacted at creativecommons.org.
================================================
FILE: MANIFEST.in
================================================
include Makefile
include LICENSE
include LICENSE_weights
include *.md
include *.ini
include requirements.txt
include audiocraft/py.typed
include assets/*.mp3
recursive-include conf *.yaml
================================================
FILE: Makefile
================================================
# Base command shared by the integration runs below: a dora run on CPU,
# one epoch, no dataloader workers, a handful of samples and DEBUG logging.
INTEG=AUDIOCRAFT_DORA_DIR="/tmp/magma_$(USER)" python3 -m dora -v run --clear device=cpu dataset.num_workers=0 optim.epochs=1 \
dataset.train.num_samples=10 dataset.valid.num_samples=10 \
dataset.evaluate.num_samples=10 dataset.generate.num_samples=2 sample_rate=16000 \
logging.level=DEBUG
# Per-solver variants of $(INTEG). The musicgen and audiogen runs load the
# compression checkpoint //sig/5091833e named in the compression line's comment.
INTEG_COMPRESSION = $(INTEG) solver=compression/debug rvq.n_q=2 rvq.bins=48 checkpoint.save_last=true # SIG is 5091833e
INTEG_MUSICGEN = $(INTEG) solver=musicgen/debug dset=audio/example compression_model_checkpoint=//sig/5091833e \
transformer_lm.n_q=2 transformer_lm.card=48 transformer_lm.dim=16 checkpoint.save_last=false # Using compression model from 5091833e
INTEG_AUDIOGEN = $(INTEG) solver=audiogen/debug dset=audio/example compression_model_checkpoint=//sig/5091833e \
transformer_lm.n_q=2 transformer_lm.card=48 transformer_lm.dim=16 checkpoint.save_last=false # Using compression model from 5091833e
INTEG_MBD = $(INTEG) solver=diffusion/debug dset=audio/example \
checkpoint.save_last=false # Using compression model from 616d7b3c
# Default goal: lint, then run the unit tests.
default: linter tests
# Upgrade pip and install the package in editable mode with the dev extras.
install:
pip install -U pip
pip install -U -e '.[dev]'
# Style (flake8) and type (mypy) checks for the package and for the tests.
linter:
flake8 audiocraft && mypy audiocraft
flake8 tests && mypy tests
# Unit tests under coverage, followed by the coverage report.
tests:
coverage run -m pytest tests
coverage report
# Integration runs. The compression run comes first; the musicgen and
# audiogen runs reference its checkpoint signature.
tests_integ:
$(INTEG_COMPRESSION)
$(INTEG_MBD)
$(INTEG_MUSICGEN)
$(INTEG_AUDIOGEN)
# Build the HTML API documentation into ./api_docs.
api_docs:
pdoc3 --html -o api_docs -f audiocraft
# Build a source distribution.
dist:
python setup.py sdist
# NOTE(review): default, install and tests_integ are not listed as phony here.
.PHONY: linter tests api_docs dist
================================================
FILE: README.md
================================================
# AudioCraft Plus



AudioCraft is a PyTorch library for deep learning research on audio generation. AudioCraft contains inference and training code
for two state-of-the-art AI generative models producing high-quality audio: AudioGen and MusicGen.

## Features
AudioCraft Plus is an all-in-one WebUI for the original AudioCraft, adding many quality features on top.
- AudioGen Model
- Multiband Diffusion
- Custom Model Support
- Generation Metadata and Audio Info tab
- Mono to Stereo
- Multiprompt/Prompt Segmentation with Structure Prompts
- Video Output Customization
- Music Continuation
## Installation
If you are updating from the previous version of AudioCraft Plus, do the following steps in the AudioCraft Plus folder:
```shell
git pull
pip install transformers --upgrade
pip install torchmetrics --upgrade
```
#### Otherwise: Clean Installation
AudioCraft requires Python 3.9, PyTorch 2.0.0. To install AudioCraft, you can run the following:
```shell
# Best to make sure you have torch installed first, in particular before installing xformers.
# Don't run this if you already have PyTorch installed.
pip install 'torch>=2.0'
# Then proceed to one of the following
pip install -U audiocraft # stable release
pip install -U git+https://git@github.com/GrandaddyShmax/audiocraft_plus#egg=audiocraft # bleeding edge
pip install -e . # or if you cloned the repo locally (mandatory if you want to train).
```
We also recommend having `ffmpeg` installed, either through your system or Anaconda:
```bash
sudo apt-get install ffmpeg
# Or if you are using Anaconda or Miniconda
conda install 'ffmpeg<5' -c conda-forge
```
Installation video thanks to Pogs Cafe:
[Installing MusicGen+ Locally](http://www.youtube.com/watch?v=WjGk4bcbUOI "Installing MusicGen+ Locally")
Additional installation guide by [radaevm](https://github.com/radaevm) can be found [HERE](https://github.com/GrandaddyShmax/audiocraft_plus/discussions/31)
## Models
At the moment, AudioCraft contains the training code and inference code for:
* [MusicGen](./docs/MUSICGEN.md): A state-of-the-art controllable text-to-music model.
* [AudioGen](./docs/AUDIOGEN.md): A state-of-the-art text-to-sound model.
* [EnCodec](./docs/ENCODEC.md): A state-of-the-art high fidelity neural audio codec.
* [Multi Band Diffusion](./docs/MBD.md): An EnCodec compatible decoder using diffusion.
## Training code
AudioCraft contains PyTorch components for deep learning research in audio and training pipelines for the developed models.
For a general introduction of AudioCraft design principles and instructions to develop your own training pipeline, refer to
the [AudioCraft training documentation](./docs/TRAINING.md).
For reproducing existing work and using the developed training pipelines, refer to the instructions for each specific model
that provides pointers to configuration, example grids and model/task-specific information and FAQ.
## API documentation
We provide some [API documentation](https://facebookresearch.github.io/audiocraft/api_docs/audiocraft/index.html) for AudioCraft.
## FAQ
#### Is the training code available?
Yes! We provide the training code for [EnCodec](./docs/ENCODEC.md), [MusicGen](./docs/MUSICGEN.md) and [Multi Band Diffusion](./docs/MBD.md).
#### Where are the models stored?
Hugging Face stores the models in a specific location, which can be overridden by setting the `AUDIOCRAFT_CACHE_DIR` environment variable.
## License
* The code in this repository is released under the MIT license as found in the [LICENSE file](LICENSE).
* The model weights in this repository are released under the CC-BY-NC 4.0 license as found in the [LICENSE_weights file](LICENSE_weights).
## Citation
For the general framework of AudioCraft, please cite the following.
```
@article{copet2023simple,
title={Simple and Controllable Music Generation},
author={Jade Copet and Felix Kreuk and Itai Gat and Tal Remez and David Kant and Gabriel Synnaeve and Yossi Adi and Alexandre Défossez},
year={2023},
journal={arXiv preprint arXiv:2306.05284},
}
```
When referring to a specific model, please cite as mentioned in the model-specific README, e.g.
[./docs/MUSICGEN.md](./docs/MUSICGEN.md), [./docs/AUDIOGEN.md](./docs/AUDIOGEN.md), etc.
================================================
FILE: app.py
================================================
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
# Updated to account for UI changes from https://github.com/rkfg/audiocraft/blob/long/app.py
# also released under the MIT license.
import argparse
from concurrent.futures import ProcessPoolExecutor
import os
from pathlib import Path
import subprocess as sp
from tempfile import NamedTemporaryFile
import time
import warnings
import glob
import re
from PIL import Image
from pydub import AudioSegment
from datetime import datetime
import json
import shutil
import taglib
import torch
import torchaudio
import gradio as gr
import numpy as np
import typing as tp
from audiocraft.data.audio_utils import convert_audio
from audiocraft.data.audio import audio_write
from audiocraft.models import AudioGen, MusicGen, MultiBandDiffusion
from audiocraft.utils import ui
import random, string
# Version of this WebUI build; add_tags() stores it in the metadata of every generated file.
version = "2.0.1"
# Gradio theme: lime primary/secondary hues on a neutral base, with the
# primary-500 shade used as the hover fill of primary and secondary buttons.
theme = gr.themes.Base(
primary_hue="lime",
secondary_hue="lime",
neutral_hue="neutral",
).set(
button_primary_background_fill_hover='*primary_500',
button_primary_background_fill_hover_dark='*primary_500',
button_secondary_background_fill_hover='*primary_500',
button_secondary_background_fill_hover_dark='*primary_500'
)
MODEL = None # Last used model
# Optional cache of loaded models keyed by version; load_model() bypasses the cache while this is None.
MODELS = None
# When true, _do_predictions() drops the model reference after a batch.
UNLOAD_MODEL = False
# When true, _do_predictions() moves the model to the CPU after a batch.
MOVE_TO_CPU = False
# True only when the SPACE_ID environment variable contains "facebook/MusicGen".
IS_BATCHED = "facebook/MusicGen" in os.environ.get('SPACE_ID', '')
print(IS_BATCHED)
MAX_BATCH_SIZE = 12
# Duration passed to _do_predictions() by predict_batched().
BATCHED_DURATION = 15
# Set to True by interrupt().
INTERRUPTING = False
# MultiBandDiffusion decoder, created by load_diffusion() and released by unload_diffusion().
MBD = None
# We have to wrap subprocess.call to quiet the log a bit when using gr.make_waveform.
_old_call = sp.call
def generate_random_string(length):
    """Return ``length`` random characters drawn from ASCII letters and digits.

    Each character is picked independently with ``random.choice``, so the
    result follows the state of the global ``random`` generator.
    """
    alphabet = string.ascii_letters + string.digits
    picked = []
    for _ in range(length):
        picked.append(random.choice(alphabet))
    return ''.join(picked)
def resize_video(input_path, output_path, target_width, target_height):
    """Rescale a video with ffmpeg and write the result to ``output_path``.

    The video stream goes through ffmpeg's ``scale`` filter at
    ``target_width`` x ``target_height``; the audio stream is copied as is.
    An existing ``output_path`` is overwritten (``-y``). The ffmpeg exit
    status is not checked.
    """
    scale_filter = f'scale={target_width}:{target_height}'
    command = ['ffmpeg', '-y']
    command += ['-i', input_path]
    command += ['-vf', scale_filter]
    command += ['-c:a', 'copy']
    command.append(output_path)
    sp.run(command)
def _call_nostderr(*args, **kwargs):
# Avoid ffmpeg vomiting on the logs.
kwargs['stderr'] = sp.DEVNULL
kwargs['stdout'] = sp.DEVNULL
_old_call(*args, **kwargs)
# From here on, every subprocess.call made through the `sp` module goes
# through the silenced wrapper.
sp.call = _call_nostderr
# Preallocating the pool of processes.
pool = ProcessPoolExecutor(4)
# Entered by hand so the pool stays open for the lifetime of the module;
# no matching __exit__ is visible in this part of the file.
pool.__enter__()
def interrupt():
    """Raise the module-level ``INTERRUPTING`` flag.

    The function only sets the flag; whatever reacts to it lives elsewhere.
    """
    global INTERRUPTING
    INTERRUPTING = True
class FileCleaner:
    """Remember files and delete them once they are older than a lifetime.

    Expired files are only removed when a new file is registered with
    :meth:`add`; nothing runs in the background.
    """

    def __init__(self, file_lifetime: float = 3600):
        # Age in seconds after which a tracked file is deleted.
        self.file_lifetime = file_lifetime
        # (time added, path) pairs; oldest first because entries are only appended.
        self.files = []

    def add(self, path: tp.Union[str, Path]):
        """Purge expired entries, then start tracking ``path``."""
        self._cleanup()
        entry = (time.time(), Path(path))
        self.files.append(entry)

    def _cleanup(self):
        """Delete and forget every entry at the front of the list that expired."""
        current = time.time()
        while self.files:
            added_at, target = self.files[0]
            if not current - added_at > self.file_lifetime:
                # Entries are in insertion order, so nothing after this one
                # can be older.
                return
            if target.exists():
                target.unlink()
            self.files.pop(0)
# Module-wide cleaner with the default one-hour lifetime; _do_predictions()
# registers every temporary wav and video file with it.
file_cleaner = FileCleaner()
def make_waveform(*args, **kwargs):
    """Render a waveform video with gradio and resize it into a new mp4.

    ``height`` and ``width`` are required keyword arguments; they are removed
    from ``kwargs`` (values below 256 are raised to 256) before everything
    else is forwarded to ``gr.make_waveform``. Without a ``bg_image`` the
    video is resized to 900x300, otherwise to ``width`` x ``height``.

    Returns:
        The name of the resized video: a random 12-character name with an
        ``.mp4`` suffix, relative to the current working directory.
    """
    # Further remove some warnings.
    started = time.time()
    with warnings.catch_warnings():
        warnings.simplefilter('ignore')
        height = max(kwargs.pop('height'), 256)
        width = max(kwargs.pop('width'), 256)
        waveform_video = gr.make_waveform(*args, **kwargs)
        out = f"{generate_random_string(12)}.mp4"
        if kwargs.get('bg_image', None) is None:
            target_width, target_height = 900, 300
        else:
            target_width, target_height = width, height
        resize_video(waveform_video, out, target_width, target_height)
        print("Make a video took", time.time() - started)
        return out
def load_model(version='GrandaddyShmax/musicgen-melody', custom_model=None, gen_type="music"):
    """Load a generation model into the module-level ``MODEL``.

    Args:
        version: Name of the pretrained model. The special value
            'GrandaddyShmax/musicgen-custom' selects ``custom_model`` instead.
            (The parameter shadows the module-level ``version`` string inside
            this function.)
        custom_model: Argument handed to ``MusicGen.get_pretrained`` when the
            custom version is requested.
        gen_type: "music" loads through ``MusicGen``, "audio" through
            ``AudioGen``.

    With ``MODELS`` set to None every call loads the model afresh. Otherwise
    ``MODELS`` acts as a cache keyed by ``version``: the previous model is
    moved to the CPU and the requested one is loaded or taken from the cache.
    """
    global MODEL, MODELS
    print("Loading model", version)
    if MODELS is None:
        if version == 'GrandaddyShmax/musicgen-custom':
            MODEL = MusicGen.get_pretrained(custom_model)
        else:
            # NOTE(review): any other gen_type leaves MODEL unchanged without warning.
            if gen_type == "music":
                MODEL = MusicGen.get_pretrained(version)
            elif gen_type == "audio":
                MODEL = AudioGen.get_pretrained(version)
        return
    else:
        t1 = time.monotonic()
        if MODEL is not None:
            MODEL.to('cpu') # move to cache
            print("Previous model moved to CPU in %.2fs" % (time.monotonic() - t1))
            t1 = time.monotonic()
        if version != 'GrandaddyShmax/musicgen-custom' and MODELS.get(version) is None:
            print("Loading model %s from disk" % version)
            # NOTE(review): `result` stays unbound for a gen_type other than
            # "music"/"audio", which would fail on the next line.
            if gen_type == "music":
                result = MusicGen.get_pretrained(version)
            elif gen_type == "audio":
                result = AudioGen.get_pretrained(version)
            MODELS[version] = result
            print("Model loaded in %.2fs" % (time.monotonic() - t1))
            MODEL = result
            return
        # NOTE(review): the custom version is never stored in MODELS by this
        # function, so this lookup looks like a KeyError for it - confirm.
        result = MODELS[version].to('cuda')
        print("Cached model loaded in %.2fs" % (time.monotonic() - t1))
        MODEL = result
# (metadata key, label) pairs in the order they appear in the report.
_AUDIO_INFO_FIELDS = (
    ("global_prompt", "Global Prompt"),
    ("bpm", "BPM"),
    ("key", "Key"),
    ("scale", "Scale"),
    ("texts", "Prompts"),
    ("duration", "Duration"),
    ("overlap", "Overlap"),
    ("seed", "Seed"),
    ("audio_mode", "Audio Mode"),
    ("input_length", "Input Length"),
    ("channel", "Channel"),
    ("sr_select", "Sample Rate"),
    ("model", "Model"),
    ("custom_model", "Custom Model"),
    ("decoder", "Decoder"),
    ("topk", "Topk"),
    ("topp", "Topp"),
    ("temperature", "Temperature"),
    ("cfg_coef", "Classifier Free Guidance"),
)
# Stored values that mean "nothing was entered" and are reported as "none".
_AUDIO_INFO_BLANKS = {"global_prompt": "", "texts": "['']"}


def _format_audio_info(data):
    """Turn a metadata dict into the multi-line report shown to the user.

    Keys missing from ``data`` are left out of the report. Values are passed
    through ``str`` so metadata holding numbers formats the same way as
    metadata holding strings.
    """
    if 'version' in data:
        lines = ["Version: " + str(data['version'])]
    else:
        lines = ["Version: Unknown"]
    for key, label in _AUDIO_INFO_FIELDS:
        if key not in data:
            continue
        value = str(data[key])
        if key in _AUDIO_INFO_BLANKS and value == _AUDIO_INFO_BLANKS[key]:
            value = "none"
        if key == "model" and 'generator' in data:
            # The generator ("music"/"audio") is shown as a "<generator>gen-" prefix.
            value = str(data['generator']) + "gen-" + value
        lines.append("\n" + label + ": " + value)
    return "".join(lines)


def get_audio_info(audio_path):
    """Describe the generation settings stored with a generated file.

    Args:
        audio_path: Object whose ``name`` attribute is the path of a .wav,
            .mp4 or .json file, or None.

    Returns:
        None when ``audio_path`` is None; an explanatory message when the
        extension is unsupported or a .wav/.mp4 file has no 'COMMENT' tag;
        otherwise the settings, one per line, starting with the version.
        For .wav/.mp4 the settings are read as JSON from the first 'COMMENT'
        tag, for .json from the file itself.
    """
    if audio_path is None:
        return None
    name = audio_path.name
    if not name.endswith((".wav", ".mp4", ".json")):
        return "Only .wav ,.mp4 and .json files are supported"
    if name.endswith(".json"):
        with open(name) as json_file:
            data = json.load(json_file)
    else:
        with taglib.File(name, save_on_exit=False) as song:
            if 'COMMENT' not in song.tags:
                return "No tags found. Either the file is not generated by MusicGen+ V1.2.7 and higher or the tags are corrupted. (Discord removes metadata from mp4 and wav files, so you can't use them)"
            data = json.loads(song.tags['COMMENT'][0])
    # The report always starts with a "Version: ..." line, so it can never be
    # empty; the former empty-report fallback was unreachable and is gone.
    return _format_audio_info(data)
# Number of prompt / repeat slots carried in the returned tuples.
_PROMPT_SLOTS = 10


def _read_generation_metadata(audio_path):
    """Return the metadata dict stored with a generated file.

    .json files are parsed directly; .wav/.mp4 files carry the JSON in their
    first 'COMMENT' tag. An empty dict is returned when ``audio_path`` is
    None, the extension is unsupported or the tag is missing, which makes
    every setting fall back to its default.
    """
    if audio_path is None:
        return {}
    name = audio_path.name
    if name.endswith(".json"):
        with open(name) as json_file:
            return json.load(json_file)
    if name.endswith((".wav", ".mp4")):
        with taglib.File(name, save_on_exit=False) as song:
            if 'COMMENT' not in song.tags:
                return {}
            return json.loads(song.tags['COMMENT'][0])
    return {}


def _unless_none(data, key, default):
    """Return ``data[key]``, or ``default`` when it is missing or the string "none"."""
    value = data.get(key, "none")
    return default if value == "none" else value


def _parse_prompt_schedule(data):
    """Rebuild the prompt and repeat lists from the stored 'texts' string.

    Every single-quoted item of ``data['texts']`` is one segment. Blank items
    are skipped and an item equal to the one right before it increases the
    repeat count of the current prompt instead of starting a new one.

    Returns:
        ``(unique_prompts, text, repeat)`` where both lists are padded to at
        least ``_PROMPT_SLOTS`` entries (with "" and 1 respectively).
    """
    if 'texts' not in data:
        return 1, [""] * _PROMPT_SLOTS, [1] * _PROMPT_SLOTS
    segments = re.findall(r"'(.*?)'", data['texts'])
    text = []
    repeat = []
    for index, segment in enumerate(segments):
        if not segment.strip():
            continue
        if index == 0 or segment != segments[index - 1]:
            text.append(segment)
            repeat.append(1)
        else:
            repeat[-1] += 1
    text.extend([""] * (_PROMPT_SLOTS - len(text)))
    repeat.extend([1] * (_PROMPT_SLOTS - len(repeat)))
    unique_prompts = len([t for t in text if t])
    return unique_prompts, text, repeat


def _common_generation_params(data):
    """Return the settings that close both parameter tuples, with defaults.

    Order: duration, topk, topp, temperature, cfg_coef, seed, overlap,
    channel, sr_select.
    """
    return (
        int(data['duration']) if 'duration' in data else 10,
        float(data['topk']) if 'topk' in data else 250,
        float(data['topp']) if 'topp' in data else 0,
        float(data['temperature']) if 'temperature' in data else 1.0,
        float(data['cfg_coef']) if 'cfg_coef' in data else 5.0,
        int(data['seed']) if 'seed' in data else -1,
        int(data['overlap']) if 'overlap' in data else 12,
        data.get('channel', "stereo"),
        data.get('sr_select', "48000"),
    )


def info_to_params(audio_path):
    """Recover the music-generation settings stored with a generated file.

    Args:
        audio_path: Object whose ``name`` attribute is the path of a .wav,
            .mp4 or .json file, or None.

    Returns:
        A 39-element tuple: decoder, struc_prompt, global_prompt, bpm, key,
        scale, model, custom_model, unique_prompts, ten prompt texts, ten
        repeat counts, audio_mode, duration, topk, topp, temperature,
        cfg_coef, seed, overlap, channel, sr_select. Settings that are not
        stored (or the whole tuple, when no metadata is available) take
        their default values.
    """
    data = _read_generation_metadata(audio_path)
    custom_model = None
    # Only offer a stored custom model that is still available locally.
    if 'custom_model' in data and data['custom_model'] in get_available_folders():
        custom_model = data['custom_model']
    unique_prompts, text, repeat = _parse_prompt_schedule(data)
    return (
        data.get('decoder', "Default"),
        # Structure prompts count as enabled when a real bpm was stored.
        data.get('bpm', "none") != "none",
        data.get('global_prompt', ""),
        int(_unless_none(data, 'bpm', 120)),
        _unless_none(data, 'key', "C"),
        _unless_none(data, 'scale', "Major"),
        data.get('model', "large"),
        custom_model,
        unique_prompts,
        *text[:_PROMPT_SLOTS],
        *repeat[:_PROMPT_SLOTS],
        _unless_none(data, 'audio_mode', "sample"),
        *_common_generation_params(data),
    )


def info_to_params_a(audio_path):
    """Recover the audio-generation settings stored with a generated file.

    Args:
        audio_path: Object whose ``name`` attribute is the path of a .wav,
            .mp4 or .json file, or None.

    Returns:
        A 33-element tuple: decoder, struc_prompt, global_prompt,
        unique_prompts, ten prompt texts, ten repeat counts, duration, topk,
        topp, temperature, cfg_coef, seed, overlap, channel, sr_select.
        Settings that are not stored (or the whole tuple, when no metadata
        is available) take their default values.
    """
    data = _read_generation_metadata(audio_path)
    unique_prompts, text, repeat = _parse_prompt_schedule(data)
    return (
        data.get('decoder', "Default"),
        # Structure prompts count as enabled when a global prompt was stored.
        data.get('global_prompt', "") != "",
        data.get('global_prompt', ""),
        unique_prompts,
        *text[:_PROMPT_SLOTS],
        *repeat[:_PROMPT_SLOTS],
        *_common_generation_params(data),
    )
def make_pseudo_stereo(filename, sr_select, pan, delay):
    """Rewrite a wav file in place as pseudo-stereo.

    Args:
        filename: Path of the wav file; it is read and then overwritten.
        sr_select: Target sample rate as a string; the audio is resampled
            during the pan step unless it is "32000".
        pan: When true, overlay a left-panned and a right-panned copy of the
            audio, each lowered by 5 (dB, per pydub's ``-`` operator -
            TODO confirm), and export the mix.
        delay: When true, write a two-channel file whose second channel is
            the first channel delayed by 10 ms.
    """
    if pan:
        segment = AudioSegment.from_wav(filename)
        if sr_select != "32000":
            segment = segment.set_frame_rate(int(sr_select))
        left_copy = segment.pan(-0.5) - 5
        right_copy = segment.pan(0.6) - 5
        mixed = left_copy.overlay(right_copy, position=5)
        mixed.export(filename, format="wav")
    if delay:
        # Only the first channel of the loaded file is used from here on.
        waveform, sample_rate = torchaudio.load(filename)
        delay_samples = int(0.01 * sample_rate)  # 10 ms expressed in samples
        first_channel = waveform[0]
        silence = torch.zeros(delay_samples)
        # Shift the channel right by the delay, dropping the samples pushed off the end.
        delayed_channel = torch.cat((silence, first_channel[:-delay_samples]))
        stereo_waveform = torch.stack([first_channel, delayed_channel])
        torchaudio.save(filename, stereo_waveform, sample_rate)
    return
def normalize_audio(audio_data):
    """Scale an audio array so that its largest absolute sample is 1.0.

    Args:
        audio_data: NumPy array of samples of any numeric dtype.

    Returns:
        A new float32 array; the input is not modified. Empty and all-zero
        (silent) input is returned unscaled, as there is no peak to
        normalise to.
    """
    audio_data = audio_data.astype(np.float32)
    if audio_data.size == 0:
        # np.max raises on an empty array.
        return audio_data
    max_value = np.max(np.abs(audio_data))
    if max_value > 0:
        audio_data /= max_value
    # A zero peak would turn every sample into NaN, so silence stays as it is.
    return audio_data
def load_diffusion():
    """Create the module-level MultiBandDiffusion decoder if there is none yet.

    Later calls are no-ops while ``MBD`` stays set.
    """
    global MBD
    if MBD is not None:
        return
    print("loading MBD")
    MBD = MultiBandDiffusion.get_mbd_musicgen()
def unload_diffusion():
    """Drop the module-level reference to the MultiBandDiffusion decoder.

    Does nothing when no decoder is loaded.
    """
    global MBD
    if MBD is None:
        return
    print("unloading MBD")
    MBD = None
def _do_predictions(gen_type, texts, melodies, sample, trim_start, trim_end, duration, image, height, width, background, bar1, bar2, channel, sr_select, progress=False, **gen_kwargs):
    """Generate audio for a batch of prompts and write wav and video files.

    Args:
        gen_type: "music" or "audio"; selects the size limit for the prompt
            audio, the model sample rate and which generate calls are used.
        texts: Text descriptions, one per item of the batch.
        melodies: One entry per item, either None or a
            ``(sample_rate, numpy array)`` pair used as melody conditioning.
        sample: None, or a ``(sample_rate, numpy array)`` pair to continue from.
        trim_start, trim_end: Amount trimmed from both ends of ``sample``
            (same unit as ``duration``, presumably seconds).
        duration: Requested length of the result.
        image, height, width, background, bar1, bar2: Forwarded to
            ``make_waveform`` for the video rendering.
        channel: "stereo", "mono" or "stereo effect".
        sr_select: Output sample rate as a string.
        progress: Forwarded to the model's generate calls.
        **gen_kwargs: Forwarded to ``MODEL.set_generation_params``.

    Returns:
        ``(res, res_audio, res_backup, input_length)``: the video files
        returned by ``make_waveform``, the wav files in the selected channel
        layout, wav files written from the unconverted model output, and the
        length of the trimmed ``sample`` (0 without a sample).
    """
    # Longest stretch of the sample that is handed to the model.
    # NOTE(review): stays unbound for any other gen_type - confirm callers.
    if gen_type == "music":
        maximum_size = 29.5
    elif gen_type == "audio":
        maximum_size = 9.5
    cut_size = 0
    input_length = 0
    sampleP = None
    if sample is not None:
        globalSR, sampleM = sample[0], sample[1]
        sampleM = normalize_audio(sampleM)
        sampleM = torch.from_numpy(sampleM).t()
        if sampleM.dim() == 1:
            sampleM = sampleM.unsqueeze(0)
        sample_length = sampleM.shape[sampleM.dim() - 1] / globalSR
        # Clamp the trims so that some of the sample is left over.
        if trim_start >= sample_length:
            trim_start = sample_length - 0.5
        if trim_end >= sample_length:
            trim_end = sample_length - 0.5
        if trim_start + trim_end >= sample_length:
            tmp = sample_length - 0.5
            trim_start = tmp / 2
            trim_end = tmp / 2
        sampleM = sampleM[..., int(globalSR * trim_start):int(globalSR * (sample_length - trim_end))]
        sample_length = sample_length - (trim_start + trim_end)
        if sample_length > maximum_size:
            # Only the last `maximum_size` of the sample goes to the model;
            # the head (sampleP) is put back in front of the output later.
            cut_size = sample_length - maximum_size
            sampleP = sampleM[..., :int(globalSR * cut_size)]
            sampleM = sampleM[..., int(globalSR * cut_size):]
        if sample_length >= duration:
            duration = sample_length + 0.5
        input_length = sample_length
    global MODEL
    MODEL.set_generation_params(duration=(duration - cut_size), **gen_kwargs)
    print("new batch", len(texts), texts, [None if m is None else (m[0], m[1].shape) for m in melodies], [None if sample is None else (sample[0], sample[1].shape)])
    be = time.time()
    processed_melodies = []
    # NOTE(review): target_sr stays unbound for any other gen_type.
    if gen_type == "music":
        target_sr = 32000
    elif gen_type == "audio":
        target_sr = 16000
    target_ac = 1
    for melody in melodies:
        if melody is None:
            processed_melodies.append(None)
        else:
            sr, melody = melody[0], torch.from_numpy(melody[1]).to(MODEL.device).float().t()
            if melody.dim() == 1:
                melody = melody[None]
            # Keep at most `duration` worth of melody, then convert to the model format.
            melody = melody[..., :int(sr * duration)]
            melody = convert_audio(melody, sr, target_sr, target_ac)
            processed_melodies.append(melody)
    # USE_DIFFUSION is a module-level name that is not defined in this function.
    if sample is not None:
        if sampleP is None:
            if gen_type == "music":
                outputs = MODEL.generate_continuation(
                    prompt=sampleM,
                    prompt_sample_rate=globalSR,
                    descriptions=texts,
                    progress=progress,
                    return_tokens=USE_DIFFUSION
                )
            elif gen_type == "audio":
                outputs = MODEL.generate_continuation(
                    prompt=sampleM,
                    prompt_sample_rate=globalSR,
                    descriptions=texts,
                    progress=progress
                )
        else:
            if sampleP.dim() > 1:
                sampleP = convert_audio(sampleP, globalSR, target_sr, target_ac)
            sampleP = sampleP.to(MODEL.device).float().unsqueeze(0)
            if gen_type == "music":
                outputs = MODEL.generate_continuation(
                    prompt=sampleM,
                    prompt_sample_rate=globalSR,
                    descriptions=texts,
                    progress=progress,
                    return_tokens=USE_DIFFUSION
                )
            elif gen_type == "audio":
                outputs = MODEL.generate_continuation(
                    prompt=sampleM,
                    prompt_sample_rate=globalSR,
                    descriptions=texts,
                    progress=progress
                )
            # NOTE(review): the diffusion step below indexes `outputs` as a
            # pair, so with USE_DIFFUSION set this cat presumably receives a
            # tuple instead of a tensor - confirm.
            outputs = torch.cat([sampleP, outputs], 2)
    elif any(m is not None for m in processed_melodies):
        if gen_type == "music":
            outputs = MODEL.generate_with_chroma(
                descriptions=texts,
                melody_wavs=processed_melodies,
                melody_sample_rate=target_sr,
                progress=progress,
                return_tokens=USE_DIFFUSION
            )
        elif gen_type == "audio":
            # NOTE(review): verify that the audio model offers generate_with_chroma.
            outputs = MODEL.generate_with_chroma(
                descriptions=texts,
                melody_wavs=processed_melodies,
                melody_sample_rate=target_sr,
                progress=progress
            )
    else:
        if gen_type == "music":
            outputs = MODEL.generate(texts, progress=progress, return_tokens=USE_DIFFUSION)
        elif gen_type == "audio":
            outputs = MODEL.generate(texts, progress=progress)
    if USE_DIFFUSION:
        # outputs[1] is decoded by the diffusion decoder; the result is
        # appended to outputs[0] along the first dimension.
        print("outputs: " + str(outputs))
        outputs_diffusion = MBD.tokens_to_wav(outputs[1])
        outputs = torch.cat([outputs[0], outputs_diffusion], dim=0)
    outputs = outputs.detach().cpu().float()
    # Keep the unconverted model output for the backup files.
    backups = outputs
    if channel == "stereo":
        outputs = convert_audio(outputs, target_sr, int(sr_select), 2)
    elif channel == "mono" and sr_select != "32000":
        outputs = convert_audio(outputs, target_sr, int(sr_select), 1)
    out_files = []
    out_audios = []
    out_backup = []
    for output in outputs:
        with NamedTemporaryFile("wb", suffix=".wav", delete=False) as file:
            # "stereo effect" is written at the model rate first;
            # make_pseudo_stereo receives sr_select and resamples itself.
            audio_write(
                file.name, output, (MODEL.sample_rate if channel == "stereo effect" else int(sr_select)), strategy="loudness",
                loudness_headroom_db=16, loudness_compressor=True, add_suffix=False)
            if channel == "stereo effect":
                make_pseudo_stereo(file.name, sr_select, pan=True, delay=True);
            # The video is rendered in the process pool; results are collected below.
            out_files.append(pool.submit(make_waveform, file.name, bg_image=image, bg_color=background, bars_color=(bar1, bar2), fg_alpha=1.0, bar_count=75, height=height, width=width))
            out_audios.append(file.name)
            file_cleaner.add(file.name)
            print(f'wav: {file.name}')
    for backup in backups:
        with NamedTemporaryFile("wb", suffix=".wav", delete=False) as file:
            audio_write(
                file.name, backup, MODEL.sample_rate, strategy="loudness",
                loudness_headroom_db=16, loudness_compressor=True, add_suffix=False)
            out_backup.append(file.name)
            file_cleaner.add(file.name)
    # Wait for every video rendering to finish.
    res = [out_file.result() for out_file in out_files]
    res_audio = out_audios
    res_backup = out_backup
    for file in res:
        file_cleaner.add(file)
        print(f'video: {file}')
    print("batch finished", len(texts), time.time() - be)
    print("Tempfiles currently stored: ", len(file_cleaner.files))
    if MOVE_TO_CPU:
        MODEL.to('cpu')
    if UNLOAD_MODEL:
        # Rebinds the module-level MODEL; the next batch needs load_model() again.
        MODEL = None
    torch.cuda.empty_cache()
    torch.cuda.ipc_collect()
    return res, res_audio, res_backup, input_length
def predict_batched(texts, melodies):
    """Run a batched generation with the 'melody' model.

    Every prompt in ``texts`` is clipped to 512 characters, the 'melody'
    model is requested through ``load_model`` and the clipped prompts are
    handed to ``_do_predictions`` together with ``melodies`` and
    ``BATCHED_DURATION``. Whatever ``_do_predictions`` returns is returned
    unchanged.
    """
    char_cap = 512
    clipped = [prompt[:char_cap] for prompt in texts]
    load_model('melody')
    return _do_predictions(clipped, melodies, BATCHED_DURATION)
def add_tags(filename, tags):
    """Embed the generation settings into ``filename`` and dump them to JSON.

    Args:
        filename: Path of the media file to tag. If it does not exist, the
            tagging step is skipped and only the JSON file is written.
        tags: Sequence of at least 20 values, in the order listed in
            ``keys`` below (index 7 is the seed).

    Returns:
        The path of the JSON file that was written. It is named
        ``<seed>.json`` and is relative to the current working directory.

    Raises:
        IndexError: If ``tags`` holds fewer than 20 entries.
    """
    keys = (
        "global_prompt", "bpm", "key", "scale", "texts",
        "duration", "overlap", "seed", "audio_mode", "input_length",
        "channel", "sr_select", "model", "custom_model", "decoder",
        "topk", "topp", "temperature", "cfg_coef", "generator",
    )
    # Index explicitly (rather than zip) so a short ``tags`` still raises.
    data = {key: tags[position] for position, key in enumerate(keys)}
    data["version"] = version
    json_string = json.dumps(data)

    if os.path.exists(filename):
        # The whole settings blob is stored as the file's COMMENT tag.
        with taglib.File(filename, save_on_exit=True) as song:
            song.tags = {'COMMENT': json_string}

    # str() keeps this in line with save_outputs, which also stringifies the seed.
    json_path = str(tags[7]) + '.json'
    # Context manager guarantees the handle is closed even if the write fails.
    with open(json_path, 'w') as json_file:
        json_file.write(json_string)
    return json_path
def save_outputs(mp4, wav_tmp, tags, gen_type):
    """Copy a generated video/audio pair into the dated ``output`` folder.

    Files are stored under ``./output/<YYYYMMDD>/<gen_type>/{wav,mp4,json}``
    and are named after the seed (``tags[7]``). If a wav with that name
    already exists, a ``(n)`` suffix with n = 1, 2, 3, ... is appended; the
    mp4 and json files reuse the same stem.

    Args:
        mp4: Path of the generated .mp4. A relative path is resolved against
            the current working directory.
        wav_tmp: Path of the temporary .wav file to copy.
        tags: Generation settings forwarded to ``add_tags``.
        gen_type: Sub-folder name for this kind of generation.

    Returns:
        Tuple ``(wav_target, mp4_target, json_target)`` of the stored paths.
    """
    current_date = datetime.now().strftime("%Y%m%d")
    base_directory = os.path.join(os.getcwd(), 'output', current_date, gen_type)
    wav_directory = os.path.join(base_directory, 'wav')
    mp4_directory = os.path.join(base_directory, 'mp4')
    json_directory = os.path.join(base_directory, 'json')
    for directory in (wav_directory, mp4_directory, json_directory):
        os.makedirs(directory, exist_ok=True)

    # Pick the first stem (seed, seed(1), seed(2), ...) whose wav is still free.
    stem = str(tags[7])
    wav_target = os.path.join(wav_directory, stem + '.wav')
    counter = 1
    while os.path.exists(wav_target):
        stem = str(tags[7]) + f'({counter})'
        wav_target = os.path.join(wav_directory, stem + '.wav')
        counter += 1

    # Build the sibling paths from their directories instead of using
    # str.replace('wav', 'mp4') on the full path: replace() rewrites every
    # occurrence, so a working directory containing "wav" or "mp4" would
    # yield a wrong target.
    mp4_target = os.path.join(mp4_directory, stem + '.mp4')
    json_target = os.path.join(json_directory, stem + '.json')

    shutil.copyfile(wav_tmp, wav_target)
    json_file = add_tags(wav_target, tags)

    # Prefixing './' would corrupt an absolute path, so only relative ones get it.
    mp4_source = mp4 if os.path.isabs(mp4) else r'./' + mp4
    shutil.copyfile(mp4_source, mp4_target)
    add_tags(mp4_target, tags)

    # add_tags leaves its JSON file in the working directory; move it next
    # to the other outputs.
    shutil.copyfile(json_file, json_target)
    os.remove(json_file)
    return wav_target, mp4_target, json_target
def clear_cash():
    """Delete the temporary files that were created today.

    Removes ``*.mp4`` files from the current working directory and
    ``tmp*.mp4`` / ``tmp*.wav`` / ``tmp*.png`` files from the temp directory,
    but only those whose ``os.path.getctime`` date equals today's date.

    The temp directory is taken from the ``TEMP`` environment variable. When
    that variable is unset or empty (it normally only exists on Windows),
    ``tempfile.gettempdir()`` is used instead of failing with a TypeError.
    """
    import tempfile

    today = datetime.now().date()
    temp_directory = os.environ.get('TEMP') or tempfile.gettempdir()
    patterns = (
        os.path.join(os.getcwd(), '*.mp4'),
        os.path.join(temp_directory, 'tmp*.mp4'),
        os.path.join(temp_directory, 'tmp*.wav'),
        os.path.join(temp_directory, 'tmp*.png'),
    )
    for pattern in patterns:
        for path in glob.glob(pattern):
            created = datetime.fromtimestamp(os.path.getctime(path)).date()
            if created == today:
                os.remove(path)
    return
def s2t(seconds, seconds2):
    """Format a segment as ``"MM:SS - MM:SS"``.

    Args:
        seconds: Segment start in seconds.
        seconds2: Segment end in seconds.

    Returns:
        The formatted range. A start that is non-zero and lies before the
        end is shown one second later than ``seconds``.
    """
    start = seconds
    if seconds != 0 and seconds < seconds2:
        # Shift before splitting into minutes/seconds so that 59 s rolls
        # over to 01:00 instead of being printed as 00:60.
        start = seconds + 1
    m, s = divmod(start, 60)
    m2, s2 = divmod(seconds2, 60)
    return ("%02d:%02d - %02d:%02d" % (m, s, m2, s2))
def calc_time(gen_type, s, duration, overlap, d0, d1, d2, d3, d4, d5, d6, d7, d8, d9):
    """Compute the displayed time range of each of the ten prompt rows.

    Args:
        gen_type: "music" (30 s window) or "audio" (10 s window); any other
            value gives a window of 0.
        s: Number of prompts in use; the row at index ``s - 1`` always runs
            up to ``duration``.
        duration: Total length in seconds; no range extends past it.
        overlap: Overlap in seconds; each repeat after the very first window
            adds ``window - overlap`` seconds.
        d0-d9: Repeat count of each prompt row (converted with ``int``).

    Returns:
        A 10-tuple of strings produced by ``s2t``, one per prompt row.
    """
    repeats = [int(value) for value in (d0, d1, d2, d3, d4, d5, d6, d7, d8, d9)]
    final_index = s - 1

    if gen_type == "music":
        window = 30
    elif gen_type == "audio":
        window = 10
    else:
        window = 0
    stride = window - overlap

    # The first row owns the full initial window; every further repeat
    # (of any row) only contributes the non-overlapping stride.
    lengths = [window + ((repeats[0] - 1) * stride)]
    lengths.extend(count * stride for count in repeats[1:])

    labels = []
    cursor = 0
    for index, length in enumerate(lengths):
        segment_end = cursor + length
        if segment_end >= duration or index == final_index:
            segment_end = duration
        labels.append(s2t(cursor, segment_end))
        cursor = segment_end
    return tuple(labels)
def predict_full(gen_type, model, decoder, custom_model, prompt_amount, struc_prompt, bpm, key, scale, global_prompt, p0, p1, p2, p3, p4, p5, p6, p7, p8, p9, d0, d1, d2, d3, d4, d5, d6, d7, d8, d9, audio, mode, trim_start, trim_end, duration, topk, topp, temperature, cfg_coef, seed, overlap, image, height, width, background, bar1, bar2, channel, sr_select, progress=gr.Progress()):
    """Handle one "Generate" request and store its outputs.

    Validates the sampling settings, selects the decoder and model, seeds
    torch, expands the prompt rows into one text per segment, runs
    ``_do_predictions`` and saves the first result through ``save_outputs``.

    Args:
        gen_type: "music" or "audio"; selects the model family and the
            prompt layout.
        model, custom_model: Model name suffix and, for music, the folder
            name under ``models/``.
        decoder: "MultiBand_Diffusion" enables the diffusion decoder; any
            other value disables it.
        prompt_amount: Number of prompt rows (p0.., d0..) to use.
        struc_prompt: When true, bpm/key/scale (music only) and the global
            prompt are prepended to each prompt text.
        p0-p9, d0-d9: Prompt texts and their repeat counts.
        audio, mode: Optional input audio and how to use it ("sample" or
            "melody").
        seed: Negative values request a random seed.
        Remaining arguments are forwarded to ``_do_predictions`` or recorded
        in the metadata tags.

    Returns:
        ``(mp4_target, wav_target, backup_wav, [mp4_target, wav_target,
        json_target], seed)``.

    Raises:
        gr.Error: If temperature, topk or topp is negative, or (from the
            progress callback) when ``INTERRUPTING`` has been set.
    """
    global INTERRUPTING
    global USE_DIFFUSION
    INTERRUPTING = False

    if gen_type == "audio":
        custom_model = None
        custom_model_shrt = "none"
    elif gen_type == "music":
        custom_model_shrt = custom_model
        custom_model = "models/" + custom_model

    if temperature < 0:
        raise gr.Error("Temperature must be >= 0.")
    if topk < 0:
        raise gr.Error("Topk must be non-negative.")
    if topp < 0:
        raise gr.Error("Topp must be non-negative.")

    # Negative trims are clamped to zero rather than rejected.
    if trim_start < 0:
        trim_start = 0
    if trim_end < 0:
        trim_end = 0

    topk = int(topk)

    if decoder == "MultiBand_Diffusion":
        USE_DIFFUSION = True
        load_diffusion()
    else:
        USE_DIFFUSION = False
        unload_diffusion()

    # Keep the short name for the metadata; the full name identifies the model.
    if gen_type == "music":
        model_shrt = model
        model = "GrandaddyShmax/musicgen-" + model
    elif gen_type == "audio":
        model_shrt = model
        model = "GrandaddyShmax/audiogen-" + model

    if MODEL is None or MODEL.name != (model):
        load_model(model, custom_model, gen_type)
    elif MOVE_TO_CPU:
        MODEL.to('cuda')

    if seed < 0:
        seed = random.randint(0, 0xffff_ffff_ffff)
    torch.manual_seed(seed)

    def _progress(generated, to_generate):
        progress((min(generated, to_generate), to_generate))
        if INTERRUPTING:
            raise gr.Error("Interrupted.")
    MODEL.set_custom_progress_callback(_progress)

    audio_mode = "none"
    melody = None
    sample = None
    if audio:
        audio_mode = mode
        if mode == "sample":
            sample = audio
        elif mode == "melody":
            melody = audio

    custom_model_shrt = "none" if model != "GrandaddyShmax/musicgen-custom" else custom_model_shrt

    text_cat = [p0, p1, p2, p3, p4, p5, p6, p7, p8, p9]
    drag_cat = [d0, d1, d2, d3, d4, d5, d6, d7, d8, d9]
    texts = []
    raw_texts = []
    for index, (prompt_text, repeat) in enumerate(zip(text_cat, drag_cat)):
        if index >= prompt_amount:
            break
        # Each prompt contributes one entry per repeat.
        for _ in range(int(repeat)):
            if not struc_prompt:
                texts.append(prompt_text)
                # The structure fields are unused in this mode; the metadata
                # records them as "none".
                global_prompt = "none"
                bpm = "none"
                key = "none"
                scale = "none"
            else:
                parts = []
                if gen_type == "music":
                    parts.append(str(bpm) + " bpm")
                    parts.append(str(key) + " " + str(scale))
                # Join only the non-empty pieces so an empty global prompt
                # never leaves a dangling ", " at the start of the text.
                parts.extend(piece for piece in (str(global_prompt), str(prompt_text)) if piece != "")
                texts.append(", ".join(parts))
            raw_texts.append(prompt_text)

    outs, outs_audio, outs_backup, input_length = _do_predictions(
        gen_type, [texts], [melody], sample, trim_start, trim_end, duration, image, height, width, background, bar1, bar2, channel, sr_select, progress=True,
        top_k=topk, top_p=topp, temperature=temperature, cfg_coef=cfg_coef, extend_stride=MODEL.max_duration-overlap)
    tags = [str(global_prompt), str(bpm), str(key), str(scale), str(raw_texts), str(duration), str(overlap), str(seed), str(audio_mode), str(input_length), str(channel), str(sr_select), str(model_shrt), str(custom_model_shrt), str(decoder), str(topk), str(topp), str(temperature), str(cfg_coef), str(gen_type)]
    wav_target, mp4_target, json_target = save_outputs(outs[0], outs_audio[0], tags, gen_type)

    # The stored copies are returned; the temporary originals are removed.
    for out in outs:
        os.remove(out)
    for out in outs_audio:
        os.remove(out)

    return mp4_target, wav_target, outs_backup[0], [mp4_target, wav_target, json_target], seed
# Upper bound of the "Prompts" sliders; ui_full also creates this many extra,
# initially hidden, prompt rows in each generation tab.
max_textboxes = 10
# Disabled helper that listed the '.pt' files found in models/ (kept for reference):
#def get_available_models():
#return sorted([re.sub('.pt$', '', item.name) for item in list(Path('models/').glob('*')) if item.name.endswith('.pt')])
def get_available_folders():
    """Return the sorted names of the sub-directories of ``models``.

    The path is relative to the current working directory. Plain files
    inside ``models`` are ignored. An empty list is returned when the
    ``models`` directory itself does not exist, so building the custom-model
    dropdown does not fail on a fresh checkout.
    """
    models_dir = "models"
    if not os.path.isdir(models_dir):
        return []
    return sorted(
        entry for entry in os.listdir(models_dir)
        if os.path.isdir(os.path.join(models_dir, entry))
    )
def toggle_audio_src(choice):
    """Build the Gradio update that switches the audio input source.

    ``choice == "mic"`` selects the microphone source; every other value
    selects file upload. In both cases the current value is cleared.
    """
    if choice == "mic":
        source, label = "microphone", "Microphone"
    else:
        source, label = "upload", "File"
    return gr.update(source=source, value=None, label=label)
def ui_full(launch_kwargs):
with gr.Blocks(title='AudioCraft Plus', theme=theme) as interface:
gr.Markdown(
"""
# AudioCraft Plus - v2.0.1
### An All-in-One AudioCraft WebUI
Thanks to: facebookresearch, Camenduru, rkfg, oobabooga, AlexHK and GrandaddyShmax
"""
)
with gr.Tab("MusicGen"):
gr.Markdown(
"""
### MusicGen
"""
)
with gr.Row():
with gr.Column():
with gr.Tab("Generation"):
with gr.Accordion("Structure Prompts", open=False):
with gr.Column():
with gr.Row():
struc_prompts = gr.Checkbox(label="Enable", value=False, interactive=True, container=False)
bpm = gr.Number(label="BPM", value=120, interactive=True, scale=1, precision=0)
key = gr.Dropdown(["C", "C#", "D", "D#", "E", "F", "F#", "G", "G#", "A", "Bb", "B"], label="Key", value="C", interactive=True)
scale = gr.Dropdown(["Major", "Minor"], label="Scale", value="Major", interactive=True)
with gr.Row():
global_prompt = gr.Text(label="Global Prompt", interactive=True, scale=3)
with gr.Row():
s = gr.Slider(1, max_textboxes, value=1, step=1, label="Prompts:", interactive=True, scale=2)
#s_mode = gr.Radio(["segmentation", "batch"], value="segmentation", interactive=True, scale=1, label="Generation Mode")
with gr.Column():
textboxes = []
prompts = []
repeats = []
calcs = []
with gr.Row():
text0 = gr.Text(label="Input Text", interactive=True, scale=4)
prompts.append(text0)
drag0 = gr.Number(label="Repeat", value=1, interactive=True, scale=1)
repeats.append(drag0)
calc0 = gr.Text(interactive=False, value="00:00 - 00:00", scale=1, label="Time")
calcs.append(calc0)
for i in range(max_textboxes):
with gr.Row(visible=False) as t:
text = gr.Text(label="Input Text", interactive=True, scale=3)
repeat = gr.Number(label="Repeat", minimum=1, value=1, interactive=True, scale=1)
calc = gr.Text(interactive=False, value="00:00 - 00:00", scale=1, label="Time")
textboxes.append(t)
prompts.append(text)
repeats.append(repeat)
calcs.append(calc)
to_calc = gr.Button("Calculate Timings", variant="secondary")
with gr.Row():
duration = gr.Slider(minimum=1, maximum=300, value=10, step=1, label="Duration", interactive=True)
with gr.Row():
overlap = gr.Slider(minimum=1, maximum=29, value=12, step=1, label="Overlap", interactive=True)
with gr.Row():
seed = gr.Number(label="Seed", value=-1, scale=4, precision=0, interactive=True)
gr.Button('\U0001f3b2\ufe0f', scale=1).click(fn=lambda: -1, outputs=[seed], queue=False)
reuse_seed = gr.Button('\u267b\ufe0f', scale=1)
with gr.Tab("Audio"):
with gr.Row():
with gr.Column():
input_type = gr.Radio(["file", "mic"], value="file", label="Input Type (optional)", interactive=True)
mode = gr.Radio(["melody", "sample"], label="Input Audio Mode (optional)", value="sample", interactive=True)
with gr.Row():
trim_start = gr.Number(label="Trim Start", value=0, interactive=True)
trim_end = gr.Number(label="Trim End", value=0, interactive=True)
audio = gr.Audio(source="upload", type="numpy", label="Input Audio (optional)", interactive=True)
with gr.Tab("Customization"):
with gr.Row():
with gr.Column():
background = gr.ColorPicker(value="#0f0f0f", label="background color", interactive=True, scale=0)
bar1 = gr.ColorPicker(value="#84cc16", label="bar color start", interactive=True, scale=0)
bar2 = gr.ColorPicker(value="#10b981", label="bar color end", interactive=True, scale=0)
with gr.Column():
image = gr.Image(label="Background Image", type="filepath", interactive=True, scale=4)
with gr.Row():
height = gr.Number(label="Height", value=512, interactive=True)
width = gr.Number(label="Width", value=768, interactive=True)
with gr.Tab("Settings"):
with gr.Row():
channel = gr.Radio(["mono", "stereo", "stereo effect"], label="Output Audio Channels", value="stereo", interactive=True, scale=1)
sr_select = gr.Dropdown(["11025", "16000", "22050", "24000", "32000", "44100", "48000"], label="Output Audio Sample Rate", value="48000", interactive=True)
with gr.Row():
model = gr.Radio(["melody", "small", "medium", "large", "custom"], label="Model", value="large", interactive=True, scale=1)
with gr.Column():
dropdown = gr.Dropdown(choices=get_available_folders(), value=("No models found" if len(get_available_folders()) < 1 else get_available_folders()[0]), label='Custom Model (models folder)', elem_classes='slim-dropdown', interactive=True)
ui.create_refresh_button(dropdown, lambda: None, lambda: {'choices': get_available_folders()}, 'refresh-button')
with gr.Row():
decoder = gr.Radio(["Default", "MultiBand_Diffusion"], label="Decoder", value="Default", interactive=True)
with gr.Row():
topk = gr.Number(label="Top-k", value=250, interactive=True)
topp = gr.Number(label="Top-p", value=0, interactive=True)
temperature = gr.Number(label="Temperature", value=1.0, interactive=True)
cfg_coef = gr.Number(label="Classifier Free Guidance", value=3.0, interactive=True)
with gr.Row():
submit = gr.Button("Generate", variant="primary")
# Adapted from https://github.com/rkfg/audiocraft/blob/long/app.py, MIT license.
_ = gr.Button("Interrupt").click(fn=interrupt, queue=False)
with gr.Column() as c:
with gr.Tab("Output"):
output = gr.Video(label="Generated Music", scale=0)
with gr.Row():
audio_only = gr.Audio(type="numpy", label="Audio Only", interactive=False)
backup_only = gr.Audio(type="numpy", label="Backup Audio", interactive=False, visible=False)
send_audio = gr.Button("Send to Input Audio")
seed_used = gr.Number(label='Seed used', value=-1, interactive=False)
download = gr.File(label="Generated Files", interactive=False)
with gr.Tab("Wiki"):
gr.Markdown(
"""
- **[Generate (button)]:**
Generates the music with the given settings and prompts.
- **[Interrupt (button)]:**
Stops the music generation as soon as it can, providing an incomplete output.
---
### Generation Tab:
#### Structure Prompts:
This feature helps reduce repetetive prompts by allowing you to set global prompts
that will be used for all prompt segments.
- **[Structure Prompts (checkbox)]:**
Enable/Disable the structure prompts feature.
- **[BPM (number)]:**
Beats per minute of the generated music.
- **[Key (dropdown)]:**
The key of the generated music.
- **[Scale (dropdown)]:**
The scale of the generated music.
- **[Global Prompt (text)]:**
Here write the prompt that you wish to be used for all prompt segments.
#### Multi-Prompt:
This feature allows you to control the music, adding variation to different time segments.
You have up to 10 prompt segments. the first prompt will always be 30s long
the other prompts will be [30s - overlap].
for example if the overlap is 10s, each prompt segment will be 20s.
- **[Prompt Segments (number)]:**
Amount of unique prompt to generate throughout the music generation.
- **[Prompt/Input Text (prompt)]:**
Here describe the music you wish the model to generate.
- **[Repeat (number)]:**
Write how many times this prompt will repeat (instead of wasting another prompt segment on the same prompt).
- **[Time (text)]:**
The time of the prompt segment.
- **[Calculate Timings (button)]:**
Calculates the timings of the prompt segments.
- **[Duration (number)]:**
How long you want the generated music to be (in seconds).
- **[Overlap (number)]:**
How much each new segment will reference the previous segment (in seconds).
For example, if you choose 20s: Each new segment after the first one will reference the previous segment 20s
and will generate only 10s of new music. The model can only process 30s of music.
- **[Seed (number)]:**
Your generated music id. If you wish to generate the exact same music,
place the exact seed with the exact prompts
(This way you can also extend specific song that was generated short).
- **[Random Seed (button)]:**
Gives "-1" as a seed, which counts as a random seed.
- **[Copy Previous Seed (button)]:**
Copies the seed from the output seed (if you don't feel like doing it manualy).
---
### Audio Tab:
- **[Input Type (selection)]:**
`File` mode allows you to upload an audio file to use as input
`Mic` mode allows you to use your microphone as input
- **[Input Audio Mode (selection)]:**
`Melody` mode only works with the melody model: it conditions the music generation to reference the melody
`Sample` mode works with any model: it gives a music sample to the model to generate its continuation.
- **[Trim Start and Trim End (numbers)]:**
`Trim Start` set how much you'd like to trim the input audio from the start
`Trim End` same as the above but from the end
- **[Input Audio (audio file)]:**
Input here the audio you wish to use with "melody" or "sample" mode.
---
### Customization Tab:
- **[Background Color (color)]:**
Works only if you don't upload image. Color of the background of the waveform.
- **[Bar Color Start (color)]:**
First color of the waveform bars.
- **[Bar Color End (color)]:**
Second color of the waveform bars.
- **[Background Image (image)]:**
Background image that you wish to be attached to the generated video along with the waveform.
- **[Height and Width (numbers)]:**
Output video resolution, only works with image.
(minimum height and width is 256).
---
### Settings Tab:
- **[Output Audio Channels (selection)]:**
With this you can select the amount of channels that you wish for your output audio.
`mono` is a straightforward single channel audio
`stereo` is a dual channel audio but it will sound more or less like mono
`stereo effect` this one is also dual channel but uses tricks to simulate a stereo audio.
- **[Output Audio Sample Rate (dropdown)]:**
The output audio sample rate, the model default is 32000.
- **[Model (selection)]:**
Here you can choose which model you wish to use:
`melody` model is based on the medium model with a unique feature that lets you use melody conditioning
`small` model is trained on 300M parameters
`medium` model is trained on 1.5B parameters
`large` model is trained on 3.3B parameters
`custom` model runs the custom model that you provided.
- **[Custom Model (selection)]:**
This dropdown will show you models that are placed in the `models` folder
you must select `custom` in the model options in order to use it.
- **[Refresh (button)]:**
Refreshes the dropdown list for custom model.
- **[Decoder (selection)]:**
Choose here the decoder that you wish to use:
`Default` is the default decoder
`MultiBand_Diffusion` is a decoder that uses diffusion to generate the audio.
- **[Top-k (number)]:**
is a parameter used in text generation models, including music generation models. It determines the number of most likely next tokens to consider at each step of the generation process. The model ranks all possible tokens based on their predicted probabilities, and then selects the top-k tokens from the ranked list. The model then samples from this reduced set of tokens to determine the next token in the generated sequence. A smaller value of k results in a more focused and deterministic output, while a larger value of k allows for more diversity in the generated music.
- **[Top-p (number)]:**
also known as nucleus sampling or probabilistic sampling, is another method used for token selection during text generation. Instead of specifying a fixed number like top-k, top-p considers the cumulative probability distribution of the ranked tokens. It selects the smallest possible set of tokens whose cumulative probability exceeds a certain threshold (usually denoted as p). The model then samples from this set to choose the next token. This approach ensures that the generated output maintains a balance between diversity and coherence, as it allows for a varying number of tokens to be considered based on their probabilities.
- **[Temperature (number)]:**
is a parameter that controls the randomness of the generated output. It is applied during the sampling process, where a higher temperature value results in more random and diverse outputs, while a lower temperature value leads to more deterministic and focused outputs. In the context of music generation, a higher temperature can introduce more variability and creativity into the generated music, but it may also lead to less coherent or structured compositions. On the other hand, a lower temperature can produce more repetitive and predictable music.
- **[Classifier Free Guidance (number)]:**
refers to a technique used in some music generation models where a separate classifier network is trained to provide guidance or control over the generated music. This classifier is trained on labeled data to recognize specific musical characteristics or styles. During the generation process, the output of the generator model is evaluated by the classifier, and the generator is encouraged to produce music that aligns with the desired characteristics or style. This approach allows for more fine-grained control over the generated music, enabling users to specify certain attributes they want the model to capture.
"""
)
with gr.Tab("AudioGen"):
gr.Markdown(
"""
### AudioGen
"""
)
with gr.Row():
with gr.Column():
with gr.Tab("Generation"):
with gr.Accordion("Structure Prompts", open=False):
with gr.Row():
struc_prompts_a = gr.Checkbox(label="Enable", value=False, interactive=True, container=False)
global_prompt_a = gr.Text(label="Global Prompt", interactive=True, scale=3)
with gr.Row():
s_a = gr.Slider(1, max_textboxes, value=1, step=1, label="Prompts:", interactive=True, scale=2)
with gr.Column():
textboxes_a = []
prompts_a = []
repeats_a = []
calcs_a = []
with gr.Row():
text0_a = gr.Text(label="Input Text", interactive=True, scale=4)
prompts_a.append(text0_a)
drag0_a = gr.Number(label="Repeat", value=1, interactive=True, scale=1)
repeats_a.append(drag0_a)
calc0_a = gr.Text(interactive=False, value="00:00 - 00:00", scale=1, label="Time")
calcs_a.append(calc0_a)
for i in range(max_textboxes):
with gr.Row(visible=False) as t_a:
text_a = gr.Text(label="Input Text", interactive=True, scale=3)
repeat_a = gr.Number(label="Repeat", minimum=1, value=1, interactive=True, scale=1)
calc_a = gr.Text(interactive=False, value="00:00 - 00:00", scale=1, label="Time")
textboxes_a.append(t_a)
prompts_a.append(text_a)
repeats_a.append(repeat_a)
calcs_a.append(calc_a)
to_calc_a = gr.Button("Calculate Timings", variant="secondary")
with gr.Row():
duration_a = gr.Slider(minimum=1, maximum=300, value=10, step=1, label="Duration", interactive=True)
with gr.Row():
overlap_a = gr.Slider(minimum=1, maximum=9, value=2, step=1, label="Overlap", interactive=True)
with gr.Row():
seed_a = gr.Number(label="Seed", value=-1, scale=4, precision=0, interactive=True)
gr.Button('\U0001f3b2\ufe0f', scale=1).click(fn=lambda: -1, outputs=[seed_a], queue=False)
reuse_seed_a = gr.Button('\u267b\ufe0f', scale=1)
with gr.Tab("Audio"):
with gr.Row():
with gr.Column():
input_type_a = gr.Radio(["file", "mic"], value="file", label="Input Type (optional)", interactive=True)
mode_a = gr.Radio(["sample"], label="Input Audio Mode (optional)", value="sample", interactive=False, visible=False)
with gr.Row():
trim_start_a = gr.Number(label="Trim Start", value=0, interactive=True)
trim_end_a = gr.Number(label="Trim End", value=0, interactive=True)
audio_a = gr.Audio(source="upload", type="numpy", label="Input Audio (optional)", interactive=True)
with gr.Tab("Customization"):
with gr.Row():
with gr.Column():
background_a = gr.ColorPicker(value="#0f0f0f", label="background color", interactive=True, scale=0)
bar1_a = gr.ColorPicker(value="#84cc16", label="bar color start", interactive=True, scale=0)
bar2_a = gr.ColorPicker(value="#10b981", label="bar color end", interactive=True, scale=0)
with gr.Column():
image_a = gr.Image(label="Background Image", type="filepath", interactive=True, scale=4)
with gr.Row():
height_a = gr.Number(label="Height", value=512, interactive=True)
width_a = gr.Number(label="Width", value=768, interactive=True)
with gr.Tab("Settings"):
with gr.Row():
channel_a = gr.Radio(["mono", "stereo", "stereo effect"], label="Output Audio Channels", value="stereo", interactive=True, scale=1)
sr_select_a = gr.Dropdown(["11025", "16000", "22050", "24000", "32000", "44100", "48000"], label="Output Audio Sample Rate", value="48000", interactive=True)
with gr.Row():
model_a = gr.Radio(["medium"], label="Model", value="medium", interactive=False, visible=False)
decoder_a = gr.Radio(["Default"], label="Decoder", value="Default", interactive=False, visible=False)
with gr.Row():
topk_a = gr.Number(label="Top-k", value=250, interactive=True)
topp_a = gr.Number(label="Top-p", value=0, interactive=True)
temperature_a = gr.Number(label="Temperature", value=1.0, interactive=True)
cfg_coef_a = gr.Number(label="Classifier Free Guidance", value=3.0, interactive=True)
with gr.Row():
submit_a = gr.Button("Generate", variant="primary")
_ = gr.Button("Interrupt").click(fn=interrupt, queue=False)
with gr.Column():
with gr.Tab("Output"):
output_a = gr.Video(label="Generated Audio", scale=0)
with gr.Row():
audio_only_a = gr.Audio(type="numpy", label="Audio Only", interactive=False)
backup_only_a = gr.Audio(type="numpy", label="Backup Audio", interactive=False, visible=False)
send_audio_a = gr.Button("Send to Input Audio")
seed_used_a = gr.Number(label='Seed used', value=-1, interactive=False)
download_a = gr.File(label="Generated Files", interactive=False)
with gr.Tab("Wiki"):
gr.Markdown(
"""
- **[Generate (button)]:**
Generates the audio with the given settings and prompts.
- **[Interrupt (button)]:**
Stops the audio generation as soon as it can, providing an incomplete output.
---
### Generation Tab:
#### Structure Prompts:
This feature helps reduce repetetive prompts by allowing you to set global prompts
that will be used for all prompt segments.
- **[Structure Prompts (checkbox)]:**
Enable/Disable the structure prompts feature.
- **[Global Prompt (text)]:**
Here write the prompt that you wish to be used for all prompt segments.
#### Multi-Prompt:
This feature allows you to control the audio, adding variation to different time segments.
You have up to 10 prompt segments. the first prompt will always be 10s long
the other prompts will be [10s - overlap].
for example if the overlap is 2s, each prompt segment will be 8s.
- **[Prompt Segments (number)]:**
Amount of unique prompt to generate throughout the audio generation.
- **[Prompt/Input Text (prompt)]:**
Here describe the audio you wish the model to generate.
- **[Repeat (number)]:**
Write how many times this prompt will repeat (instead of wasting another prompt segment on the same prompt).
- **[Time (text)]:**
The time of the prompt segment.
- **[Calculate Timings (button)]:**
Calculates the timings of the prompt segments.
- **[Duration (number)]:**
How long you want the generated audio to be (in seconds).
- **[Overlap (number)]:**
How much each new segment will reference the previous segment (in seconds).
For example, if you choose 2s: Each new segment after the first one will reference the previous segment 2s
and will generate only 8s of new audio. The model can only process 10s of music.
- **[Seed (number)]:**
Your generated audio id. If you wish to generate the exact same audio,
place the exact seed with the exact prompts
(This way you can also extend specific song that was generated short).
- **[Random Seed (button)]:**
Gives "-1" as a seed, which counts as a random seed.
- **[Copy Previous Seed (button)]:**
Copies the seed from the output seed (if you don't feel like doing it manualy).
---
### Audio Tab:
- **[Input Type (selection)]:**
`File` mode allows you to upload an audio file to use as input
`Mic` mode allows you to use your microphone as input
- **[Trim Start and Trim End (numbers)]:**
`Trim Start` set how much you'd like to trim the input audio from the start
`Trim End` same as the above but from the end
- **[Input Audio (audio file)]:**
Input here the audio you wish to use.
---
### Customization Tab:
- **[Background Color (color)]:**
Works only if you don't upload image. Color of the background of the waveform.
- **[Bar Color Start (color)]:**
First color of the waveform bars.
- **[Bar Color End (color)]:**
Second color of the waveform bars.
- **[Background Image (image)]:**
Background image that you wish to be attached to the generated video along with the waveform.
- **[Height and Width (numbers)]:**
Output video resolution, only works with image.
(minimum height and width is 256).
---
### Settings Tab:
- **[Output Audio Channels (selection)]:**
With this you can select the amount of channels that you wish for your output audio.
`mono` is a straightforward single channel audio
`stereo` is a dual channel audio but it will sound more or less like mono
`stereo effect` this one is also dual channel but uses tricks to simulate a stereo audio.
- **[Output Audio Sample Rate (dropdown)]:**
The output audio sample rate, the model default is 32000.
- **[Top-k (number)]:**
is a parameter used in text generation models, including music generation models. It determines the number of most likely next tokens to consider at each step of the generation process. The model ranks all possible tokens based on their predicted probabilities, and then selects the top-k tokens from the ranked list. The model then samples from this reduced set of tokens to determine the next token in the generated sequence. A smaller value of k results in a more focused and deterministic output, while a larger value of k allows for more diversity in the generated music.
- **[Top-p (number)]:**
also known as nucleus sampling or probabilistic sampling, is another method used for token selection during text generation. Instead of specifying a fixed number like top-k, top-p considers the cumulative probability distribution of the ranked tokens. It selects the smallest possible set of tokens whose cumulative probability exceeds a certain threshold (usually denoted as p). The model then samples from this set to choose the next token. This approach ensures that the generated output maintains a balance between diversity and coherence, as it allows for a varying number of tokens to be considered based on their probabilities.
- **[Temperature (number)]:**
is a parameter that controls the randomness of the generated output. It is applied during the sampling process, where a higher temperature value results in more random and diverse outputs, while a lower temperature value leads to more deterministic and focused outputs. In the context of music generation, a higher temperature can introduce more variability and creativity into the generated music, but it may also lead to less coherent or structured compositions. On the other hand, a lower temperature can produce more repetitive and predictable music.
- **[Classifier Free Guidance (number)]:**
refers to a technique used in some music generation models where a separate classifier network is trained to provide guidance or control over the generated music. This classifier is trained on labeled data to recognize specific musical characteristics or styles. During the generation process, the output of the generator model is evaluated by the classifier, and the generator is encouraged to produce music that aligns with the desired characteristics or style. This approach allows for more fine-grained control over the generated music, enabling users to specify certain attributes they want the model to capture.
"""
)
with gr.Tab("Audio Info"):
gr.Markdown(
"""
### Audio Info
"""
)
with gr.Row():
with gr.Column():
in_audio = gr.File(type="file", label="Input Any Audio", interactive=True)
with gr.Row():
send_gen = gr.Button("Send to MusicGen", variant="primary")
send_gen_a = gr.Button("Send to AudioGen", variant="primary")
with gr.Column():
info = gr.Textbox(label="Audio Info", lines=10, interactive=False)
with gr.Tab("Changelog"):
gr.Markdown(
"""
## Changelog:
### v2.0.1
- Changed custom model loading to support the official trained models
- Additional changes from the main facebookresearch repo
### v2.0.0a
- Forgot to move all the update to app.py from temp2.py... oops
### v2.0.0
- Changed name from MusicGen+ to AudioCraft Plus
- Complete overhaul of the repo "backend" with the latest changes from the main facebookresearch repo
- Added a new decoder: MultiBand_Diffusion
- Added AudioGen: a new tab for generating audio
### v1.2.8c
- Implemented Reverse compatibility for audio info tab with previous versions
### v1.2.8b
- Fixed the error when loading default models
### v1.2.8a
- Adapted Audio info tab to work with the new structure prompts feature
- Now custom models actually work, make sure you select the correct base model
### v1.2.8
- Now you will also recieve json file with metadata of generated audio
- Added error messages in Audio Info tab
- Added structure prompts: you can select bpm, key and global prompt for all prompts
- Added time display next to each prompt, can be calculated with "Calculate Timings" button
### v1.2.7
- When sending generated audio to Input Audio, it will send a backup audio with default settings
(best for continuos generation)
- Added Metadata to generated audio (Thanks to AlexHK ♥)
- Added Audio Info tab that will display the metadata of the input audio
- Added "send to Text2Audio" button in Audio Info tab
- Generated audio is now stored in the "output" folder (Thanks to AlexHK ♥)
- Added an output area with generated files and download buttons
- Enhanced Stereo effect (Thanks to AlexHK ♥)
### v1.2.6
- Added option to generate in stereo (instead of only mono)
- Added dropdown for selecting output sample rate (model default is 32000)
### v1.2.5a
- Added file cleaner (This comes from the main facebookresearch repo)
- Reorganized a little, moved audio to a seperate tab
### v1.2.5
- Gave a unique lime theme to the webui
- Added additional output for audio only
- Added button to send generated audio to Input Audio
- Added option to trim Input Audio
### v1.2.4
- Added mic input (This comes from the main facebookresearch repo)
### v1.2.3
- Added option to change video size to fit the image you upload
### v1.2.2
- Added Wiki, Changelog and About tabs
### v1.2.1
- Added tabs and organized the entire interface
- Added option to attach image to the output video
- Added option to load fine-tuned models (Yet to be tested)
### v1.2.0
- Added Multi-Prompt
### v1.1.3
- Added customization options for generated waveform
### v1.1.2
- Removed sample length limit: now you can input audio of any length as music sample
### v1.1.1
- Improved music sample audio quality when using music continuation
### v1.1.0
- Rebuilt the repo on top of the latest structure of the main MusicGen repo
- Improved Music continuation feature
### v1.0.0 - Stable Version
- Added Music continuation
"""
)
with gr.Tab("About"):
gen_type = gr.Text(value="music", interactive=False, visible=False)
gen_type_a = gr.Text(value="audio", interactive=False, visible=False)
gr.Markdown(
"""
This is your private demo for [MusicGen](https://github.com/facebookresearch/audiocraft), a simple and controllable model for music generation
presented at: ["Simple and Controllable Music Generation"](https://huggingface.co/papers/2306.05284)
## MusicGen+ is an extended version of the original MusicGen by facebookresearch.
### Repo: https://github.com/GrandaddyShmax/audiocraft_plus/tree/plus
---
### This project was possible thanks to:
#### GrandaddyShmax - https://github.com/GrandaddyShmax
#### Camenduru - https://github.com/camenduru
#### rkfg - https://github.com/rkfg
#### oobabooga - https://github.com/oobabooga
#### AlexHK - https://github.com/alanhk147
"""
)
send_gen.click(info_to_params, inputs=[in_audio], outputs=[decoder, struc_prompts, global_prompt, bpm, key, scale, model, dropdown, s, prompts[0], prompts[1], prompts[2], prompts[3], prompts[4], prompts[5], prompts[6], prompts[7], prompts[8], prompts[9], repeats[0], repeats[1], repeats[2], repeats[3], repeats[4], repeats[5], repeats[6], repeats[7], repeats[8], repeats[9], mode, duration, topk, topp, temperature, cfg_coef, seed, overlap, channel, sr_select], queue=False)
reuse_seed.click(fn=lambda x: x, inputs=[seed_used], outputs=[seed], queue=False)
send_audio.click(fn=lambda x: x, inputs=[backup_only], outputs=[audio], queue=False)
submit.click(predict_full, inputs=[gen_type, model, decoder, dropdown, s, struc_prompts, bpm, key, scale, global_prompt, prompts[0], prompts[1], prompts[2], prompts[3], prompts[4], prompts[5], prompts[6], prompts[7], prompts[8], prompts[9], repeats[0], repeats[1], repeats[2], repeats[3], repeats[4], repeats[5], repeats[6], repeats[7], repeats[8], repeats[9], audio, mode, trim_start, trim_end, duration, topk, topp, temperature, cfg_coef, seed, overlap, image, height, width, background, bar1, bar2, channel, sr_select], outputs=[output, audio_only, backup_only, download, seed_used])
input_type.change(toggle_audio_src, input_type, [audio], queue=False, show_progress=False)
to_calc.click(calc_time, inputs=[gen_type, s, duration, overlap, repeats[0], repeats[1], repeats[2], repeats[3], repeats[4], repeats[5], repeats[6], repeats[7], repeats[8], repeats[9]], outputs=[calcs[0], calcs[1], calcs[2], calcs[3], calcs[4], calcs[5], calcs[6], calcs[7], calcs[8], calcs[9]], queue=False)
send_gen_a.click(info_to_params_a, inputs=[in_audio], outputs=[decoder_a, struc_prompts_a, global_prompt_a, s_a, prompts_a[0], prompts_a[1], prompts_a[2], prompts_a[3], prompts_a[4], prompts_a[5], prompts_a[6], prompts_a[7], prompts_a[8], prompts_a[9], repeats_a[0], repeats_a[1], repeats_a[2], repeats_a[3], repeats_a[4], repeats_a[5], repeats_a[6], repeats_a[7], repeats_a[8], repeats_a[9], duration_a, topk_a, topp_a, temperature_a, cfg_coef_a, seed_a, overlap_a, channel_a, sr_select_a], queue=False)
reuse_seed_a.click(fn=lambda x: x, inputs=[seed_used_a], outputs=[seed_a], queue=False)
send_audio_a.click(fn=lambda x: x, inputs=[backup_only_a], outputs=[audio_a], queue=False)
submit_a.click(predict_full, inputs=[gen_type_a, model_a, decoder_a, dropdown, s_a, struc_prompts_a, bpm, key, scale, global_prompt_a, prompts_a[0], prompts_a[1], prompts_a[2], prompts_a[3], prompts_a[4], prompts_a[5], prompts_a[6], prompts_a[7], prompts_a[8], prompts_a[9], repeats_a[0], repeats_a[1], repeats_a[2], repeats_a[3], repeats_a[4], repeats_a[5], repeats_a[6], repeats_a[7], repeats_a[8], repeats_a[9], audio_a, mode_a, trim_start_a, trim_end_a, duration_a, topk_a, topp_a, temperature_a, cfg_coef_a, seed_a, overlap_a, image_a, height_a, width_a, background_a, bar1_a, bar2_a, channel_a, sr_select_a], outputs=[output_a, audio_only_a, backup_only_a, download_a, seed_used_a])
input_type_a.change(toggle_audio_src, input_type_a, [audio_a], queue=False, show_progress=False)
to_calc_a.click(calc_time, inputs=[gen_type_a, s_a, duration_a, overlap_a, repeats_a[0], repeats_a[1], repeats_a[2], repeats_a[3], repeats_a[4], repeats_a[5], repeats_a[6], repeats_a[7], repeats_a[8], repeats_a[9]], outputs=[calcs_a[0], calcs_a[1], calcs_a[2], calcs_a[3], calcs_a[4], calcs_a[5], calcs_a[6], calcs_a[7], calcs_a[8], calcs_a[9]], queue=False)
in_audio.change(get_audio_info, in_audio, outputs=[info])
def variable_outputs(k):
    """Build the list of textbox visibility updates for a prompt count.

    The first ``int(k) - 1`` entries are "visible" updates; the remaining
    ``max_textboxes - (int(k) - 1)`` entries are "hidden" updates.
    """
    visible_count = int(k) - 1
    hidden_count = max_textboxes - visible_count
    shown = [gr.Textbox.update(visible=True)] * visible_count
    hidden = [gr.Textbox.update(visible=False)] * hidden_count
    return shown + hidden
def get_size(image):
    """Return an even ``(height, width)`` pair for the given image.

    Args:
        image: Value accepted by ``Image.open`` (e.g. a file path), or ``None``.

    Returns:
        ``(512, 768)`` when ``image`` is ``None``; otherwise the image's
        height and width, each bumped by one when odd.
    """
    if image is None:
        return 512, 768
    # Open inside a context manager so the underlying file handle is released
    # as soon as the dimensions have been read.
    with Image.open(image) as img:
        img_height, img_width = img.height, img.width
    # Round odd dimensions up to the next even number.
    img_height += img_height % 2
    img_width += img_width % 2
    return img_height, img_width
image.change(get_size, image, outputs=[height, width])
image_a.change(get_size, image_a, outputs=[height_a, width_a])
s.change(variable_outputs, s, textboxes)
s_a.change(variable_outputs, s_a, textboxes_a)
interface.queue().launch(**launch_kwargs)
def ui_batched(launch_kwargs):
"""Build the batched MusicGen demo interface and launch it.

The interface has a text prompt, an optional melody input switchable
between "file" and "mic", a "Generate" button, and video/audio outputs.

Args:
launch_kwargs (dict): Keyword arguments unpacked into ``launch()``.
"""
with gr.Blocks() as demo:
gr.Markdown(
"""
# MusicGen
This is the demo for [MusicGen](https://github.com/facebookresearch/audiocraft),
a simple and controllable model for music generation
presented at: ["Simple and Controllable Music Generation"](https://huggingface.co/papers/2306.05284).
for longer sequences, more control and no queue.
"""
)
# Input widgets (prompt, melody source selector, melody audio) and outputs.
with gr.Row():
with gr.Column():
with gr.Row():
text = gr.Text(label="Describe your music", lines=2, interactive=True)
with gr.Column():
radio = gr.Radio(["file", "mic"], value="file",
label="Condition on a melody (optional) File or Mic")
melody = gr.Audio(source="upload", type="numpy", label="File",
interactive=True, elem_id="melody-input")
with gr.Row():
submit = gr.Button("Generate")
with gr.Column():
output = gr.Video(label="Generated Music")
audio_output = gr.Audio(label="Generated Music (wav)", type='filepath')
# Clicking "Generate" calls predict_batched in batch mode, with batches
# capped at MAX_BATCH_SIZE; it fills both the video and the audio output.
submit.click(predict_batched, inputs=[text, melody],
outputs=[output, audio_output], batch=True, max_batch_size=MAX_BATCH_SIZE)
# Changing the radio selection calls toggle_audio_src to update the melody widget.
radio.change(toggle_audio_src, radio, [melody], queue=False, show_progress=False)
# Example (prompt, melody file) pairs; a melody of None means text-only.
# NOTE(review): only `output` is listed here while submit.click wires two
# outputs for the same function -- confirm this is intended.
gr.Examples(
fn=predict_batched,
examples=[
[
"An 80s driving pop song with heavy drums and synth pads in the background",
"./assets/bach.mp3",
],
[
"A cheerful country song with acoustic guitars",
"./assets/bolero_ravel.mp3",
],
[
"90s rock song with electric guitar and heavy drums",
None,
],
[
"a light and cheerly EDM track, with syncopated drums, aery pads, and strong emotions bpm: 130",
"./assets/bach.mp3",
],
[
"lofi slow bpm electro chill with organic samples",
None,
],
],
inputs=[text, melody],
outputs=[output]
)
gr.Markdown("""
### More details
The model will generate 12 seconds of audio based on the description you provided.
You can optionally provide a reference audio from which a broad melody will be extracted.
The model will then try to follow both the description and melody provided.
All samples are generated with the `melody` model.
You can also use your own GPU or a Google Colab by following the instructions on our repo.
See [github.com/facebookresearch/audiocraft](https://github.com/facebookresearch/audiocraft)
for more details.
""")
# Enable the request queue (max_size=8 * 4) and start the server.
demo.queue(max_size=8 * 4).launch(**launch_kwargs)
if __name__ == "__main__":
    # Command-line options for the Gradio server and model memory handling.
    parser = argparse.ArgumentParser()
    parser.add_argument(
        '--listen',
        type=str,
        default='0.0.0.0' if 'SPACE_ID' in os.environ else '127.0.0.1',
        help='IP to listen on for connections to Gradio',
    )
    parser.add_argument(
        '--username', type=str, default='', help='Username for authentication'
    )
    parser.add_argument(
        '--password', type=str, default='', help='Password for authentication'
    )
    parser.add_argument(
        '--server_port',
        type=int,
        default=0,
        help='Port to run the server listener on',
    )
    parser.add_argument(
        '--inbrowser', action='store_true', help='Open in browser'
    )
    parser.add_argument(
        '--share', action='store_true', help='Share the gradio UI'
    )
    parser.add_argument(
        '--unload_model', action='store_true', help='Unload the model after every generation to save GPU memory'
    )
    parser.add_argument(
        '--unload_to_cpu', action='store_true', help='Move the model to main RAM after every generation to save GPU memory but reload faster than after full unload (see above)'
    )
    parser.add_argument(
        '--cache', action='store_true', help='Cache models in RAM to quickly switch between them'
    )

    args = parser.parse_args()

    # Module-level switches set from the command line.
    UNLOAD_MODEL = args.unload_model
    MOVE_TO_CPU = args.unload_to_cpu
    if args.cache:
        MODELS = {}

    # Only forward options that were actually provided so that Gradio's own
    # defaults apply otherwise.
    launch_kwargs = {'server_name': args.listen}
    # Authentication is enabled only when both a username and a password are given.
    if args.username and args.password:
        launch_kwargs['auth'] = (args.username, args.password)
    if args.server_port:
        launch_kwargs['server_port'] = args.server_port
    if args.inbrowser:
        launch_kwargs['inbrowser'] = args.inbrowser
    if args.share:
        launch_kwargs['share'] = args.share

    # Show the interface
    if IS_BATCHED:
        # This code already runs at module scope, so a plain assignment binds
        # the module-level name; no `global` declaration is needed.
        USE_DIFFUSION = False
        ui_batched(launch_kwargs)
    else:
        ui_full(launch_kwargs)
================================================
FILE: audiocraft/__init__.py
================================================
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
"""
AudioCraft is a general framework for training audio generative models.
At the moment we provide the training code for:
- [MusicGen](https://arxiv.org/abs/2306.05284), a state-of-the-art
text-to-music and melody+text autoregressive generative model.
For the solver, see `audiocraft.solvers.musicgen.MusicGenSolver`, and for the model,
`audiocraft.models.musicgen.MusicGen`.
- [AudioGen](https://arxiv.org/abs/2209.15352), a state-of-the-art
text-to-general-audio generative model.
- [EnCodec](https://arxiv.org/abs/2210.13438), efficient and high fidelity
neural audio codec which provides an excellent tokenizer for autoregressive language models.
See `audiocraft.solvers.compression.CompressionSolver`, and `audiocraft.models.encodec.EncodecModel`.
- [MultiBandDiffusion](TODO), alternative diffusion-based decoder compatible with EnCodec that
improves the perceived quality and reduces the artifacts coming from adversarial decoders.
"""
# flake8: noqa
from . import data, modules, models
__version__ = '1.0.0'
================================================
FILE: audiocraft/adversarial/__init__.py
================================================
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
"""Adversarial losses and discriminator architectures."""
# flake8: noqa
from .discriminators import (
MultiPeriodDiscriminator,
MultiScaleDiscriminator,
MultiScaleSTFTDiscriminator
)
from .losses import (
AdversarialLoss,
AdvLossType,
get_adv_criterion,
get_fake_criterion,
get_real_criterion,
FeatLossType,
FeatureMatchingLoss
)
================================================
FILE: audiocraft/adversarial/discriminators/__init__.py
================================================
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
# flake8: noqa
from .mpd import MultiPeriodDiscriminator
from .msd import MultiScaleDiscriminator
from .msstftd import MultiScaleSTFTDiscriminator
================================================
FILE: audiocraft/adversarial/discriminators/base.py
================================================
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
from abc import ABC, abstractmethod
import typing as tp
import torch
import torch.nn as nn
# One sub-discriminator yields a logits tensor and a list of feature-map tensors;
# a multi-discriminator returns one of each per sub-discriminator.
FeatureMapType = tp.List[torch.Tensor]
LogitsType = torch.Tensor
MultiDiscriminatorOutputType = tp.Tuple[tp.List[LogitsType], tp.List[FeatureMapType]]


class MultiDiscriminator(ABC, nn.Module):
    """Abstract base class for discriminators made of several sub-discriminators.

    Concrete subclasses must implement `forward` and the `num_discriminators` property.
    """

    def __init__(self):
        super().__init__()

    @abstractmethod
    def forward(self, x: torch.Tensor) -> MultiDiscriminatorOutputType:
        """Return the list of logits and the list of feature maps for `x`."""
        ...

    @property
    @abstractmethod
    def num_discriminators(self) -> int:
        """Number of sub-discriminators."""
        ...
================================================
FILE: audiocraft/adversarial/discriminators/mpd.py
================================================
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import typing as tp
import torch
import torch.nn as nn
import torch.nn.functional as F
from ...modules import NormConv2d
from .base import MultiDiscriminator, MultiDiscriminatorOutputType
def get_padding(kernel_size: int, dilation: int = 1) -> int:
    """Return ``dilation * (kernel_size - 1) / 2`` truncated to an int."""
    dilated_span = kernel_size * dilation - dilation
    return int(dilated_span / 2)
class PeriodDiscriminator(nn.Module):
    """Period sub-discriminator.

    The waveform is folded into a 2D map whose last dimension has size `period`,
    then passed through a stack of 2D convolutions applied along the folded time axis.

    Args:
        period (int): Period between samples of audio.
        in_channels (int): Number of input channels.
        out_channels (int): Number of output channels.
        n_layers (int): Number of convolutional layers.
        kernel_sizes (list of int): Kernel sizes for convolutions; the first is used by the
            stacked layers, the second by the final projection.
        stride (int): Stride for convolutions (the last stacked layer always uses stride 1).
        filters (int): Initial number of filters in convolutions.
        filters_scale (int): Multiplier of number of filters as we increase depth.
        max_filters (int): Maximum number of filters.
        norm (str): Normalization method.
        activation (str): Activation function.
        activation_params (dict): Parameters to provide to the activation function.
    """

    def __init__(self, period: int, in_channels: int = 1, out_channels: int = 1,
                 n_layers: int = 5, kernel_sizes: tp.List[int] = [5, 3], stride: int = 3,
                 filters: int = 8, filters_scale: int = 4, max_filters: int = 1024,
                 norm: str = 'weight_norm', activation: str = 'LeakyReLU',
                 activation_params: dict = {'negative_slope': 0.2}):
        super().__init__()
        self.period = period
        self.n_layers = n_layers
        self.activation = getattr(torch.nn, activation)(**activation_params)
        self.convs = nn.ModuleList()

        final_index = self.n_layers - 1
        prev_chs = in_channels
        for depth in range(self.n_layers):
            # Channel count grows geometrically with depth, capped at `max_filters`.
            next_chs = min(filters * (filters_scale ** (depth + 1)), max_filters)
            layer_stride = stride if depth != final_index else 1
            layer = NormConv2d(prev_chs, next_chs,
                               kernel_size=(kernel_sizes[0], 1),
                               stride=(layer_stride, 1),
                               padding=((kernel_sizes[0] - 1) // 2, 0),
                               norm=norm)
            self.convs.append(layer)
            prev_chs = next_chs

        self.conv_post = NormConv2d(prev_chs, out_channels,
                                    kernel_size=(kernel_sizes[1], 1), stride=1,
                                    padding=((kernel_sizes[1] - 1) // 2, 0), norm=norm)

    def forward(self, x: torch.Tensor):
        """Return ``(logits, feature_maps)`` for a 3-dimensional input ``x``.

        The feature maps are the activated outputs of every stacked convolution
        followed by the output of the final projection.
        """
        batch, channels, length = x.shape
        leftover = length % self.period
        if leftover != 0:
            # Reflect-pad on the right so the length becomes a multiple of the period.
            extra = self.period - leftover
            x = F.pad(x, (0, extra), 'reflect')
            length = length + extra
        # 1d to 2d
        x = x.view(batch, channels, length // self.period, self.period)

        fmap = []
        for conv in self.convs:
            x = self.activation(conv(x))
            fmap.append(x)
        x = self.conv_post(x)
        fmap.append(x)
        return x, fmap
class MultiPeriodDiscriminator(MultiDiscriminator):
    """Multi-Period (MPD) Discriminator.

    Holds one `PeriodDiscriminator` per entry of `periods` and runs all of them
    on the same input.

    Args:
        in_channels (int): Number of input channels.
        out_channels (int): Number of output channels.
        periods (Sequence[int]): Periods between samples of audio for the sub-discriminators.
        **kwargs: Additional args for `PeriodDiscriminator`
    """

    def __init__(self, in_channels: int = 1, out_channels: int = 1,
                 periods: tp.Sequence[int] = [2, 3, 5, 7, 11], **kwargs):
        super().__init__()
        sub_discriminators = [
            PeriodDiscriminator(period, in_channels, out_channels, **kwargs)
            for period in periods
        ]
        self.discriminators = nn.ModuleList(sub_discriminators)

    @property
    def num_discriminators(self):
        """Number of period sub-discriminators."""
        return len(self.discriminators)

    def forward(self, x: torch.Tensor) -> MultiDiscriminatorOutputType:
        """Apply every sub-discriminator to `x` and collect logits and feature maps."""
        results = [disc(x) for disc in self.discriminators]
        logits = [logit for logit, _ in results]
        fmaps = [fmap for _, fmap in results]
        return logits, fmaps
================================================
FILE: audiocraft/adversarial/discriminators/msd.py
================================================
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import typing as tp
import numpy as np
import torch
import torch.nn as nn
from ...modules import NormConv1d
from .base import MultiDiscriminator, MultiDiscriminatorOutputType
class ScaleDiscriminator(nn.Module):
    """Waveform sub-discriminator.

    Layer stack: a padded input convolution, one strided convolution per entry of
    `downsample_scales`, one more convolution, and a final projection (`conv_post`).

    Args:
        in_channels (int): Number of input channels.
        out_channels (int): Number of output channels.
        kernel_sizes (Sequence[int]): Kernel sizes for first and last convolutions.
        filters (int): Number of initial filters for convolutions.
        max_filters (int): Maximum number of filters.
        downsample_scales (Sequence[int]): Scale for downsampling implemented as strided convolutions.
        inner_kernel_sizes (Sequence[int] or None): Kernel sizes for inner convolutions.
        groups (Sequence[int] or None): Groups for inner convolutions.
        strides (Sequence[int] or None): Strides for inner convolutions.
        paddings (Sequence[int] or None): Paddings for inner convolutions.
        norm (str): Normalization method.
        activation (str): Activation function.
        activation_params (dict): Parameters to provide to the activation function.
        pad (str): Padding for initial convolution.
        pad_params (dict): Parameters to provide to the padding module.
    """

    def __init__(self, in_channels=1, out_channels=1, kernel_sizes: tp.Sequence[int] = [5, 3],
                 filters: int = 16, max_filters: int = 1024, downsample_scales: tp.Sequence[int] = [4, 4, 4, 4],
                 inner_kernel_sizes: tp.Optional[tp.Sequence[int]] = None, groups: tp.Optional[tp.Sequence[int]] = None,
                 strides: tp.Optional[tp.Sequence[int]] = None, paddings: tp.Optional[tp.Sequence[int]] = None,
                 norm: str = 'weight_norm', activation: str = 'LeakyReLU',
                 activation_params: dict = {'negative_slope': 0.2}, pad: str = 'ReflectionPad1d',
                 pad_params: dict = {}):
        super().__init__()
        n_scales = len(downsample_scales)
        assert len(kernel_sizes) == 2
        assert kernel_sizes[0] % 2 == 1
        assert kernel_sizes[1] % 2 == 1
        assert (inner_kernel_sizes is None or len(inner_kernel_sizes) == n_scales)
        assert (groups is None or len(groups) == n_scales)
        assert (strides is None or len(strides) == n_scales)
        assert (paddings is None or len(paddings) == n_scales)

        def pick(overrides, index, fallback):
            # Per-layer override when a non-empty sequence was supplied, else the default.
            return overrides[index] if overrides else fallback

        self.activation = getattr(torch.nn, activation)(**activation_params)
        self.convs = nn.ModuleList()

        # Input layer: its kernel size is the product of the two entries of `kernel_sizes`.
        first_kernel = np.prod(kernel_sizes)
        self.convs.append(
            nn.Sequential(
                getattr(torch.nn, pad)((first_kernel - 1) // 2, **pad_params),
                NormConv1d(in_channels, filters, kernel_size=first_kernel, stride=1, norm=norm)
            )
        )

        # Strided (downsampling) layers.
        in_chs = filters
        for idx, scale in enumerate(downsample_scales):
            out_chs = min(in_chs * scale, max_filters)
            default_kernel_size = scale * 10 + 1
            default_stride = scale
            default_padding = (default_kernel_size - 1) // 2
            default_groups = in_chs // 4
            self.convs.append(
                NormConv1d(in_chs, out_chs,
                           kernel_size=pick(inner_kernel_sizes, idx, default_kernel_size),
                           stride=pick(strides, idx, default_stride),
                           groups=pick(groups, idx, default_groups),
                           padding=pick(paddings, idx, default_padding),
                           norm=norm))
            in_chs = out_chs

        # Last stacked layer, then the final projection to `out_channels`.
        out_chs = min(in_chs * 2, max_filters)
        self.convs.append(NormConv1d(in_chs, out_chs, kernel_size=kernel_sizes[0], stride=1,
                                     padding=(kernel_sizes[0] - 1) // 2, norm=norm))
        self.conv_post = NormConv1d(out_chs, out_channels, kernel_size=kernel_sizes[1], stride=1,
                                    padding=(kernel_sizes[1] - 1) // 2, norm=norm)

    def forward(self, x: torch.Tensor):
        """Return ``(logits, feature_maps)`` for `x`.

        The feature maps are the activated outputs of every stacked layer
        followed by the output of the final projection.
        """
        fmap = []
        for layer in self.convs:
            x = self.activation(layer(x))
            fmap.append(x)
        x = self.conv_post(x)
        fmap.append(x)
        return x, fmap
class MultiScaleDiscriminator(MultiDiscriminator):
    """Multi-Scale (MSD) Discriminator.

    Holds one `ScaleDiscriminator` per entry of `scale_norms`. The first one sees the
    input as given; the signal is average-pooled before each subsequent one, so
    every sub-discriminator operates at a coarser scale than the previous.

    Args:
        in_channels (int): Number of input channels.
        out_channels (int): Number of output channels.
        downsample_factor (int): Downsampling factor between the different scales.
        scale_norms (Sequence[str]): Normalization for each sub-discriminator.
        **kwargs: Additional args for ScaleDiscriminator.
    """

    def __init__(self, in_channels: int = 1, out_channels: int = 1, downsample_factor: int = 2,
                 scale_norms: tp.Sequence[str] = ['weight_norm', 'weight_norm', 'weight_norm'], **kwargs):
        super().__init__()
        self.discriminators = nn.ModuleList([
            ScaleDiscriminator(in_channels, out_channels, norm=norm, **kwargs) for norm in scale_norms
        ])
        self.downsample = nn.AvgPool1d(downsample_factor * 2, downsample_factor, padding=downsample_factor)

    @property
    def num_discriminators(self):
        """Number of scale sub-discriminators."""
        return len(self.discriminators)

    def forward(self, x: torch.Tensor) -> MultiDiscriminatorOutputType:
        """Run each sub-discriminator on a progressively downsampled copy of `x`.

        Returns:
            A tuple ``(logits, fmaps)`` with one entry per sub-discriminator.
        """
        logits = []
        fmaps = []
        for i, disc in enumerate(self.discriminators):
            if i != 0:
                # The pooling module returns a new tensor rather than modifying its
                # input, so the result must be kept; otherwise every
                # sub-discriminator would receive the same full-resolution signal.
                x = self.downsample(x)
            logit, fmap = disc(x)
            logits.append(logit)
            fmaps.append(fmap)
        return logits, fmaps
================================================
FILE: audiocraft/adversarial/discriminators/msstftd.py
================================================
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import typing as tp
import torchaudio
import torch
from torch import nn
from einops import rearrange
from ...modules import NormConv2d
from .base import MultiDiscriminator, MultiDiscriminatorOutputType
def get_2d_padding(kernel_size: tp.Tuple[int, int], dilation: tp.Tuple[int, int] = (1, 1)):
    """Return the per-axis padding ``((k - 1) * d) // 2`` for a 2D kernel and dilation."""
    pad_first = ((kernel_size[0] - 1) * dilation[0]) // 2
    pad_second = ((kernel_size[1] - 1) * dilation[1]) // 2
    return (pad_first, pad_second)
class DiscriminatorSTFT(nn.Module):
    """STFT sub-discriminator.

    Computes a complex spectrogram of the input, stacks its real and imaginary
    parts along the channel dimension, and feeds the result to a stack of 2D
    convolutions.

    Args:
        filters (int): Number of filters in convolutions.
        in_channels (int): Number of input channels.
        out_channels (int): Number of output channels.
        n_fft (int): Size of FFT for each scale.
        hop_length (int): Length of hop between STFT windows for each scale.
        win_length (int): Window size for each scale.
        max_filters (int): Upper bound on the number of filters of any layer.
        filters_scale (int): Per-layer multiplier applied to the number of filters.
        kernel_size (tuple of int): Inner Conv2d kernel sizes.
        dilations (list of int): Inner Conv2d dilation on the time dimension.
        stride (tuple of int): Inner Conv2d strides.
        normalized (bool): Whether to normalize by magnitude after stft.
        norm (str): Normalization method.
        activation (str): Activation function.
        activation_params (dict): Parameters to provide to the activation function.
    """

    def __init__(self, filters: int, in_channels: int = 1, out_channels: int = 1,
                 n_fft: int = 1024, hop_length: int = 256, win_length: int = 1024, max_filters: int = 1024,
                 filters_scale: int = 1, kernel_size: tp.Tuple[int, int] = (3, 9), dilations: tp.List = [1, 2, 4],
                 stride: tp.Tuple[int, int] = (1, 2), normalized: bool = True, norm: str = 'weight_norm',
                 activation: str = 'LeakyReLU', activation_params: dict = {'negative_slope': 0.2}):
        super().__init__()
        assert len(kernel_size) == 2
        assert len(stride) == 2
        self.filters = filters
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.n_fft = n_fft
        self.hop_length = hop_length
        self.win_length = win_length
        self.normalized = normalized
        self.activation = getattr(torch.nn, activation)(**activation_params)
        # power=None makes the transform return a complex spectrogram.
        self.spec_transform = torchaudio.transforms.Spectrogram(
            n_fft=self.n_fft, hop_length=self.hop_length, win_length=self.win_length, window_fn=torch.hann_window,
            normalized=self.normalized, center=False, pad_mode=None, power=None)

        # Real and imaginary parts are concatenated on the channel axis in `forward`.
        spec_channels = 2 * self.in_channels
        square_kernel = (kernel_size[0], kernel_size[0])

        self.convs = nn.ModuleList()
        # NOTE(review): this first convolution does not receive `norm`, unlike the others.
        self.convs.append(
            NormConv2d(spec_channels, self.filters, kernel_size=kernel_size, padding=get_2d_padding(kernel_size))
        )

        in_chs = min(filters_scale * self.filters, max_filters)
        for depth, dilation in enumerate(dilations):
            out_chs = min((filters_scale ** (depth + 1)) * self.filters, max_filters)
            time_dilation = (dilation, 1)
            self.convs.append(NormConv2d(in_chs, out_chs, kernel_size=kernel_size, stride=stride,
                                         dilation=time_dilation,
                                         padding=get_2d_padding(kernel_size, time_dilation),
                                         norm=norm))
            in_chs = out_chs

        out_chs = min((filters_scale ** (len(dilations) + 1)) * self.filters, max_filters)
        self.convs.append(NormConv2d(in_chs, out_chs, kernel_size=square_kernel,
                                     padding=get_2d_padding(square_kernel),
                                     norm=norm))
        self.conv_post = NormConv2d(out_chs, self.out_channels,
                                    kernel_size=square_kernel,
                                    padding=get_2d_padding(square_kernel),
                                    norm=norm)

    def forward(self, x: torch.Tensor):
        """Return ``(logits, feature_maps)`` for `x`.

        The feature maps are the activated outputs of the stacked convolutions;
        the output of the final projection is returned as the logits only.
        """
        spec = self.spec_transform(x)
        # Split the complex spectrogram into real/imaginary channels.
        z = torch.cat([spec.real, spec.imag], dim=1)
        # Swap the last two axes before the convolutions.
        z = rearrange(z, 'b c w t -> b c t w')
        fmap = []
        for layer in self.convs:
            z = self.activation(layer(z))
            fmap.append(z)
        z = self.conv_post(z)
        return z, fmap
class MultiScaleSTFTDiscriminator(MultiDiscriminator):
    """Multi-Scale STFT (MS-STFT) discriminator.

    Holds one `DiscriminatorSTFT` per (n_fft, hop_length, win_length) triple and
    runs all of them on the same input.

    Args:
        filters (int): Number of filters in convolutions.
        in_channels (int): Number of input channels.
        out_channels (int): Number of output channels.
        sep_channels (bool): Separate channels to distinct samples for stereo support.
            NOTE(review): the flag is stored but `forward` below does not consult it.
        n_ffts (Sequence[int]): Size of FFT for each scale.
        hop_lengths (Sequence[int]): Length of hop between STFT windows for each scale.
        win_lengths (Sequence[int]): Window size for each scale.
        **kwargs: Additional args for STFTDiscriminator.
    """

    def __init__(self, filters: int, in_channels: int = 1, out_channels: int = 1, sep_channels: bool = False,
                 n_ffts: tp.List[int] = [1024, 2048, 512], hop_lengths: tp.List[int] = [256, 512, 128],
                 win_lengths: tp.List[int] = [1024, 2048, 512], **kwargs):
        super().__init__()
        assert len(n_ffts) == len(hop_lengths) == len(win_lengths)
        self.sep_channels = sep_channels
        sub_discriminators = []
        for scale, n_fft in enumerate(n_ffts):
            sub_discriminators.append(
                DiscriminatorSTFT(filters, in_channels=in_channels, out_channels=out_channels,
                                  n_fft=n_fft, win_length=win_lengths[scale],
                                  hop_length=hop_lengths[scale], **kwargs)
            )
        self.discriminators = nn.ModuleList(sub_discriminators)

    @property
    def num_discriminators(self):
        """Number of STFT sub-discriminators."""
        return len(self.discriminators)

    def _separate_channels(self, x: torch.Tensor) -> torch.Tensor:
        """Reshape a 3-dimensional tensor to ``(-1, 1, T)``, ``T`` being its last dimension."""
        _, _, n_steps = x.shape
        return x.view(-1, 1, n_steps)

    def forward(self, x: torch.Tensor) -> MultiDiscriminatorOutputType:
        """Apply every sub-discriminator to `x` and collect logits and feature maps."""
        results = [disc(x) for disc in self.discriminators]
        logits = [logit for logit, _ in results]
        fmaps = [fmap for _, fmap in results]
        return logits, fmaps
================================================
FILE: audiocraft/adversarial/losses.py
================================================
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
"""
Utility module to handle adversarial losses without requiring to mess up the main training loop.
"""
import typing as tp
import flashy
import torch
import torch.nn as nn
import torch.nn.functional as F
ADVERSARIAL_LOSSES = ['mse', 'hinge', 'hinge2']
AdvLossType = tp.Union[nn.Module, tp.Callable[[torch.Tensor], torch.Tensor]]
FeatLossType = tp.Union[nn.Module, tp.Callable[[torch.Tensor, torch.Tensor], torch.Tensor]]
class AdversarialLoss(nn.Module):
"""Adversary training wrapper.
Args:
adversary (nn.Module): The adversary module will be used to estimate the logits given the fake and real samples.
We assume here the adversary output is ``Tuple[List[torch.Tensor], List[List[torch.Tensor]]]``
where the first item is a list of logits and the second item is a list of feature maps.
optimizer (torch.optim.Optimizer): Optimizer used for training the given module.
loss (AdvLossType): Loss function for generator training.
loss_real (AdvLossType): Loss function for adversarial training on logits from real samples.
loss_fake (AdvLossType): Loss function for adversarial training on logits from fake samples.
loss_feat (FeatLossType): Feature matching loss function for generator training.
normalize (bool): Whether to normalize by number of sub-discriminators.
Example of usage:
adv_loss = AdversarialLoss(adversaries, optimizer, loss, loss_real, loss_fake)
for real in loader:
noise = torch.randn(...)
fake = model(noise)
adv_loss.train_adv(fake, real)
loss, _ = adv_loss(fake, real)
loss.backward()
"""
def __init__(self,
             adversary: nn.Module,
             optimizer: torch.optim.Optimizer,
             loss: AdvLossType,
             loss_real: AdvLossType,
             loss_fake: AdvLossType,
             loss_feat: tp.Optional[FeatLossType] = None,
             normalize: bool = True):
    """Store the adversary, its optimizer and the loss callables.

    `flashy.distrib.broadcast_model` is called on the adversary right after it is stored.
    """
    super().__init__()
    self.adversary: nn.Module = adversary
    flashy.distrib.broadcast_model(self.adversary)
    self.optimizer = optimizer
    # Assigned left to right, i.e. in the same order as the individual statements would be.
    self.loss, self.loss_real, self.loss_fake, self.loss_feat = loss, loss_real, loss_fake, loss_feat
    self.normalize = normalize
def _save_to_state_dict(self, destination, prefix, keep_vars):
# Add the optimizer state dict inside our own.
super()._save_to_state_dict(destination, prefix, keep_vars)
destination[prefix + 'optimizer'] = self.optimizer.state_dict()
return destination
def _load_from_state_dict(self, state_dict, prefix, *args, **kwargs):
# Load optimizer state.
self.optimizer.load_state_dict(state_dict.pop(prefix + 'optimizer'))
super()._load_from_state_dict(state_dict, prefix, *args, **kwargs)
def get_adversary_pred(self, x):
"""Run adversary model, validating expected output format."""
logits, fmaps = self.adversary(x)
assert isinstance(logits, list) and all([isinstance(t, torch.Tensor) for t in logits]), \
f'Expecting a list of tensors as logits but {type(logits)} found.'
assert isinstance(fmaps, list), f'Expecting a list of features maps but {type(fmaps)} found.'
for fmap in fmaps:
assert isinstance(fmap, list) and all([isinstance(f, torch.Tensor) for f in fmap]), \
f'Expecting a list of tensors as feature maps but {type(fmap)} found.'
return logits, fmaps
def train_adv(self, fake: torch.Tensor, real: torch.Tensor) -> torch.Tensor:
"""Train the adversary with the given fake and real example.
We assume the adversary output is the following format: Tuple[List[torch.Tensor], List[List[torch.Tensor]]].
The first item being the logits and second item being a list of feature maps for each sub-discriminator.
This will automatically synchronize gradients (with `flashy.distrib.eager_sync_model`)
and call the optimizer.
"""
loss = torch.tensor(0., device=fake.device)
all_logits_fake_is_fake, _ = self.get_adversary_pred(fake.detach())
all_logits_real_is_fake, _ = self.get_adversary_pred(real.detach())
n_sub_adversaries = len(all_logits_fake_is_fake)
for logit_fake_is_fake, logit_real_is_fake in zip(all_logits_fake_is_fake, all_logits_real_is_fake):
loss += self.loss_fake(logit_fake_is_fake) + self.loss_real(logit_real_is_fake)
if self.normalize:
loss /= n_sub_adversaries
self.optimizer.zero_grad()
with flashy.distrib.eager_sync_model(self.adversary):
loss.backward()
self.optimizer.step()
return loss
def forward(self, fake: torch.Tensor, real: torch.Tensor) -> tp.Tuple[torch.Tensor, torch.Tensor]:
"""Return the loss for the generator, i.e. trying to fool the adversary,
and feature matching loss if provided.
"""
adv = torch.tensor(0., device=fake.device)
feat = torch.tensor(0., device=fake.device)
with flashy.utils.readonly(self.adversary):
all_logits_fake_is_fake, all_fmap_fake = self.get_adversary_pred(fake)
all_logits_real_is_fake, all_fmap_real = self.get_adversary_pred(real)
n_sub_adversaries = len(all_logits_fake_is_fake)
for logit_fake_is_fake in all_logits_fake_is_fake:
adv += self.loss(logit_fake_is_fake)
if self.loss_feat:
for fmap_fake, fmap_real in zip(all_fmap_fake, all_fmap_real):
feat += self.loss_feat(fmap_fake, fmap_real)
if self.normalize:
adv /= n_sub_adversaries
feat /= n_sub_adversaries
return adv, feat
def get_adv_criterion(loss_type: str) -> tp.Callable:
    """Return the generator-side adversarial loss function registered under `loss_type`.

    Raises:
        ValueError: If no loss function is associated with `loss_type`.
    """
    assert loss_type in ADVERSARIAL_LOSSES
    candidates = (('mse', mse_loss), ('hinge', hinge_loss), ('hinge2', hinge2_loss))
    for name, criterion in candidates:
        if loss_type == name:
            return criterion
    raise ValueError('Unsupported loss')
def get_fake_criterion(loss_type: str) -> tp.Callable:
    """Return the adversary-side loss function applied to logits of fake samples.

    Both 'hinge' and 'hinge2' map to the same hinge loss.

    Raises:
        ValueError: If no loss function is associated with `loss_type`.
    """
    assert loss_type in ADVERSARIAL_LOSSES
    if loss_type == 'mse':
        return mse_fake_loss
    if loss_type in ('hinge', 'hinge2'):
        return hinge_fake_loss
    raise ValueError('Unsupported loss')
def get_real_criterion(loss_type: str) -> tp.Callable:
    """Return the adversary-side loss function applied to logits of real samples.

    Both 'hinge' and 'hinge2' map to the same hinge loss.

    Raises:
        ValueError: If no loss function is associated with `loss_type`.
    """
    assert loss_type in ADVERSARIAL_LOSSES
    if loss_type == 'mse':
        return mse_real_loss
    if loss_type in ('hinge', 'hinge2'):
        return hinge_real_loss
    raise ValueError('Unsupported loss')
def mse_real_loss(x: torch.Tensor) -> torch.Tensor:
    """Mean squared error between the logits `x` and a target of ones."""
    target = torch.tensor(1., device=x.device).expand_as(x)
    return F.mse_loss(x, target)
def mse_fake_loss(x: torch.Tensor) -> torch.Tensor:
    """Mean squared error between the logits `x` and a target of zeros."""
    target = torch.tensor(0., device=x.device).expand_as(x)
    return F.mse_loss(x, target)
def hinge_real_loss(x: torch.Tensor) -> torch.Tensor:
    """Hinge loss for real samples: mean of ``max(0, 1 - x)``."""
    zero = torch.tensor(0., device=x.device).expand_as(x)
    clipped = torch.min(x - 1, zero)
    return -clipped.mean()
def hinge_fake_loss(x: torch.Tensor) -> torch.Tensor:
    """Hinge loss for fake samples: mean of ``max(0, 1 + x)``."""
    zero = torch.tensor(0., device=x.device).expand_as(x)
    clipped = torch.min(-x - 1, zero)
    return -clipped.mean()
def mse_loss(x: torch.Tensor) -> torch.Tensor:
    """Generator-side MSE loss: mean squared error between `x` and a target of ones.

    An empty input yields a one-element zero tensor on the device of `x`.
    """
    if x.numel() != 0:
        target = torch.tensor(1., device=x.device).expand_as(x)
        return F.mse_loss(x, target)
    return torch.tensor([0.0], device=x.device)
def hinge_loss(x: torch.Tensor) -> torch.Tensor:
    """Generator-side hinge loss: the negated mean of the logits `x`.

    An empty input yields a one-element zero tensor on the device of `x`.
    """
    if x.numel() != 0:
        return -x.mean()
    return torch.tensor([0.0], device=x.device)
def hinge2_loss(x: torch.Tensor) -> torch.Tensor:
    """Generator-side clipped hinge loss: mean of ``max(0, 1 - x)``.

    Args:
        x (torch.Tensor): Logits produced by the adversary for fake samples.
    Returns:
        torch.Tensor: Scalar loss, or a one-element zero tensor when `x` is empty.
    """
    if x.numel() == 0:
        # Keep the placeholder on the device of `x`, like `mse_loss` and `hinge_loss` do,
        # so it can be accumulated with losses computed on that device.
        return torch.tensor([0.0], device=x.device)
    return -torch.mean(torch.min(x - 1, torch.tensor(0., device=x.device).expand_as(x)))
class FeatureMatchingLoss(nn.Module):
    """Feature matching loss for adversarial training.

    Args:
        loss (nn.Module): Loss to use for feature matching (default=torch.nn.L1Loss).
        normalize (bool): Whether to normalize the loss by the number of feature maps.
    """
    def __init__(self, loss: nn.Module = torch.nn.L1Loss(), normalize: bool = True):
        super().__init__()
        self.loss = loss
        self.normalize = normalize

    def forward(self, fmap_fake: tp.List[torch.Tensor], fmap_real: tp.List[torch.Tensor]) -> torch.Tensor:
        """Accumulate `self.loss` over matching pairs of fake and real feature maps.

        Args:
            fmap_fake (list of torch.Tensor): Feature maps obtained from fake samples.
            fmap_real (list of torch.Tensor): Feature maps obtained from real samples,
                same number of entries and same shapes as `fmap_fake`.
        Returns:
            torch.Tensor: Summed loss, divided by the number of feature maps if `normalize` is set.
        """
        assert len(fmap_fake) == len(fmap_real) and len(fmap_fake) > 0
        feat_loss = torch.tensor(0., device=fmap_fake[0].device)
        n_fmaps = 0
        for (feat_fake, feat_real) in zip(fmap_fake, fmap_real):
            assert feat_fake.shape == feat_real.shape
            n_fmaps += 1
            feat_loss += self.loss(feat_fake, feat_real)
        if self.normalize:
            feat_loss /= n_fmaps
        return feat_loss
================================================
FILE: audiocraft/data/__init__.py
================================================
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
"""Audio loading and writing support. Datasets for raw audio
or also including some metadata."""
# flake8: noqa
from . import audio, audio_dataset, info_audio_dataset, music_dataset, sound_dataset
================================================
FILE: audiocraft/data/audio.py
================================================
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
"""
Audio IO methods are defined in this module (info, read, write),
We rely on av library for faster read when possible, otherwise on torchaudio.
"""
from dataclasses import dataclass
from pathlib import Path
import logging
import typing as tp
import numpy as np
import soundfile
import torch
from torch.nn import functional as F
import torchaudio as ta
import av
from .audio_utils import f32_pcm, i16_pcm, normalize_audio
# Set to True by `_init_av()` once its one-time setup has been performed.
_av_initialized = False
def _init_av():
    """One-time setup for PyAV usage: raise the 'libav.mp3' logger level to ERROR.

    Subsequent calls are no-ops thanks to the module-level `_av_initialized` flag.
    """
    global _av_initialized
    if not _av_initialized:
        logging.getLogger('libav.mp3').setLevel(logging.ERROR)
        _av_initialized = True
@dataclass(frozen=True)
class AudioFileInfo:
    """Immutable summary of an audio file, as returned by `audio_info`."""
    # Sample rate reported for the file.
    sample_rate: int
    # Duration of the file (presumably in seconds -- confirm against the backends).
    duration: float
    # Number of audio channels.
    channels: int
def _av_info(filepath: tp.Union[str, Path]) -> AudioFileInfo:
    """Read sample rate, duration and channel count of the first audio stream using PyAV."""
    _init_av()
    with av.open(str(filepath)) as container:
        audio_stream = container.streams.audio[0]
        return AudioFileInfo(
            sample_rate=audio_stream.codec_context.sample_rate,
            # The stream duration is expressed in `time_base` units.
            duration=float(audio_stream.duration * audio_stream.time_base),
            channels=audio_stream.channels,
        )
def _soundfile_info(filepath: tp.Union[str, Path]) -> AudioFileInfo:
    """Read sample rate, duration and channel count of a file using `soundfile.info`."""
    sf_info = soundfile.info(filepath)
    return AudioFileInfo(
        sample_rate=sf_info.samplerate,
        duration=sf_info.duration,
        channels=sf_info.channels,
    )
def audio_info(filepath: tp.Union[str, Path]) -> AudioFileInfo:
    """Return the `AudioFileInfo` of a file, picking the backend from the file suffix.

    '.flac' and '.ogg' files go through soundfile, everything else through PyAV.
    """
    # torchaudio no longer returns useful duration informations for some formats like mp3s.
    filepath = Path(filepath)
    # ffmpeg has some weird issue with flac.
    # TODO: Validate .ogg can be safely read with av_info
    use_soundfile = filepath.suffix in ('.flac', '.ogg')
    reader = _soundfile_info if use_soundfile else _av_info
    return reader(filepath)
def _av_read(filepath: tp.Union[str, Path], seek_time: float = 0, duration: float = -1.) -> tp.Tuple[torch.Tensor, int]:
    """FFMPEG-based audio file reading using PyAV bindings.
    Soundfile cannot read mp3 and av_read is more efficient than torchaudio.

    Only the first audio stream of the file is decoded.

    Args:
        filepath (str or Path): Path to audio file to read.
        seek_time (float): Time at which to start reading in the file.
        duration (float): Duration to read from the file. If set to -1, the whole file is read.
    Returns:
        tuple of torch.Tensor, int: Tuple containing audio data and sample rate
    """
    _init_av()
    with av.open(str(filepath)) as af:
        stream = af.streams.audio[0]
        sr = stream.codec_context.sample_rate
        # Number of samples to return; any value <= 0 means "no limit" in the checks below,
        # which includes a `duration` of exactly 0.
        num_frames = int(sr * duration) if duration >= 0 else -1
        # Index of the first requested sample.
        frame_offset = int(sr * seek_time)
        # we need a small negative offset otherwise we get some edge artifact
        # from the mp3 decoder.
        af.seek(int(max(0, (seek_time - 0.1)) / stream.time_base), stream=stream)
        frames = []
        length = 0
        for frame in af.decode(streams=stream.index):
            # Sample index at which this decoded frame starts.
            current_offset = int(frame.rate * frame.pts * frame.time_base)
            # Number of leading samples of this frame that come before the requested start.
            strip = max(0, frame_offset - current_offset)
            buf = torch.from_numpy(frame.to_ndarray())
            if buf.shape[0] != stream.channels:
                # Presumably a packed/interleaved layout: reshape to (channels, samples).
                buf = buf.view(-1, stream.channels).t()
            buf = buf[:, strip:]
            frames.append(buf)
            length += buf.shape[1]
            # Stop decoding as soon as enough samples were collected.
            if num_frames > 0 and length >= num_frames:
                break
        assert frames
        # If the above assert fails, it is likely because we seeked past the end of file point,
        # in which case ffmpeg returns a single frame with only zeros, and a weird timestamp.
        # This will need proper debugging, in due time.
        wav = torch.cat(frames, dim=1)
        assert wav.shape[0] == stream.channels
        if num_frames > 0:
            # The last decoded frame may overshoot the requested length.
            wav = wav[:, :num_frames]
        return f32_pcm(wav), sr
def audio_read(filepath: tp.Union[str, Path], seek_time: float = 0.,
               duration: float = -1., pad: bool = False) -> tp.Tuple[torch.Tensor, int]:
    """Read an audio file with the backend best suited to its format.

    '.flac' and '.ogg' files are read with soundfile. '.wav' and '.mp3' files supported by
    the torchaudio sox backend are loaded with torchaudio when the whole file is requested.
    Everything else goes through `_av_read`.

    Args:
        filepath (str or Path): Path to audio file to read.
        seek_time (float): Time at which to start reading in the file.
        duration (float): Duration to read from the file. If set to -1, the whole file is read.
        pad (bool): Pad output audio if not reaching expected duration.
    Returns:
        tuple of torch.Tensor, int: Tuple containing audio data and sample rate.
    """
    fp = Path(filepath)
    suffix = fp.suffix
    if suffix in ['.flac', '.ogg']:  # TODO: check if we can safely use av_read for .ogg
        # There is some bug with ffmpeg and reading flac
        info = _soundfile_info(filepath)
        if duration <= 0:
            frames = -1
        else:
            frames = int(duration * info.sample_rate)
        start = int(seek_time * info.sample_rate)
        samples, sr = soundfile.read(filepath, start=start, frames=frames, dtype=np.float32)
        assert info.sample_rate == sr, f"Mismatch of sample rates {info.sample_rate} {sr}"
        wav = torch.from_numpy(samples).t().contiguous()
        if wav.dim() == 1:
            # Single channel data comes back without a channel dimension.
            wav = wav.unsqueeze(0)
    elif (
        suffix in ['.wav', '.mp3'] and suffix[1:] in ta.utils.sox_utils.list_read_formats()
        and duration <= 0 and seek_time == 0
    ):
        # Torchaudio is faster if we load an entire file at once.
        wav, sr = ta.load(fp)
    else:
        wav, sr = _av_read(filepath, seek_time, duration)

    if pad and duration > 0:
        missing_frames = int(duration * sr) - wav.shape[-1]
        wav = F.pad(wav, (0, missing_frames))
    return wav, sr
def audio_write(stem_name: tp.Union[str, Path],
                wav: torch.Tensor, sample_rate: int,
                format: str = 'wav', mp3_rate: int = 320, normalize: bool = True,
                strategy: str = 'peak', peak_clip_headroom_db: float = 1,
                rms_headroom_db: float = 18, loudness_headroom_db: float = 14,
                loudness_compressor: bool = False,
                log_clipping: bool = True, make_parent_dir: bool = True,
                add_suffix: bool = True) -> Path:
    """Normalize a waveform and save it to disk, returning the path that was written.

    Args:
        stem_name (str or Path): Filename without extension which will be added automatically.
        wav (torch.Tensor): Floating point waveform with at most 2 dimensions and only
            finite values. A 1D input gets a leading dimension added.
        sample_rate (int): Sample rate passed to the normalization and to the writer.
        format (str): Either "wav" or "mp3".
        mp3_rate (int): kbps when using mp3s.
        normalize (bool): if `True` (default), normalizes according to the prescribed
            strategy (see after). If `False`, the strategy is only used in case clipping
            would happen.
        strategy (str): Can be either 'clip', 'peak', or 'rms'. Default is 'peak',
            i.e. audio is normalized by its largest value. RMS normalizes by root-mean-square
            with extra headroom to avoid clipping. 'clip' just clips.
        peak_clip_headroom_db (float): Headroom in dB when doing 'peak' or 'clip' strategy.
        rms_headroom_db (float): Headroom in dB when doing 'rms' strategy. This must be much larger
            than the `peak_clip` one to avoid further clipping.
        loudness_headroom_db (float): Target loudness for loudness normalization.
        loudness_compressor (bool): Uses tanh for soft clipping when strategy is 'loudness'.
        log_clipping (bool): If True, basic logging on stderr when clipping still
            occurs despite strategy (only for 'rms').
        make_parent_dir (bool): Make parent directory if it doesn't exist.
        add_suffix (bool): If False, the file is written at `stem_name` with no extension appended.
    Returns:
        Path: Path of the saved audio.
    """
    assert wav.dtype.is_floating_point, "wav is not floating point"
    n_dims = wav.dim()
    if n_dims == 1:
        wav = wav[None]
    elif n_dims > 2:
        raise ValueError("Input wav should be at most 2 dimension.")
    assert wav.isfinite().all()
    wav = normalize_audio(wav, normalize, strategy, peak_clip_headroom_db,
                          rms_headroom_db, loudness_headroom_db, loudness_compressor,
                          log_clipping=log_clipping, sample_rate=sample_rate,
                          stem_name=str(stem_name))

    save_kwargs: dict
    if format == 'mp3':
        suffix = '.mp3'
        save_kwargs = {"compression": mp3_rate}
    elif format == 'wav':
        wav = i16_pcm(wav)
        suffix = '.wav'
        save_kwargs = {"encoding": "PCM_S", "bits_per_sample": 16}
    else:
        raise RuntimeError(f"Invalid format {format}. Only wav or mp3 are supported.")

    path = Path(str(stem_name) + (suffix if add_suffix else ''))
    if make_parent_dir:
        path.parent.mkdir(exist_ok=True, parents=True)
    try:
        ta.save(path, wav, sample_rate, **save_kwargs)
    except Exception:
        # we do not want to leave half written files around.
        if path.exists():
            path.unlink()
        raise
    return path
================================================
FILE: audiocraft/data/audio_dataset.py
================================================
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
"""AudioDataset support. In order to handle a larger number of files
without having to scan again the folders, we precompute some metadata
(filename, sample rate, duration), and use that to efficiently sample audio segments.
"""
import argparse
import copy
from concurrent.futures import ThreadPoolExecutor, Future
from dataclasses import dataclass, fields
from contextlib import ExitStack
from functools import lru_cache
import gzip
import json
import logging
import os
from pathlib import Path
import random
import sys
import typing as tp
import torch
import torch.nn.functional as F
from .audio import audio_read, audio_info
from .audio_utils import convert_audio
from .zip import PathInZip
try:
import dora
except ImportError:
dora = None # type: ignore
@dataclass(order=True)
class BaseInfo:
    """Base dataclass providing conversion from and to plain dictionaries."""

    @classmethod
    def _dict2fields(cls, dictionary: dict):
        """Keep only the entries of `dictionary` whose key is a field of `cls`."""
        field_names = (field.name for field in fields(cls))
        return {name: dictionary[name] for name in field_names if name in dictionary}

    @classmethod
    def from_dict(cls, dictionary: dict):
        """Instantiate `cls` from `dictionary`, ignoring keys that are not fields."""
        return cls(**cls._dict2fields(dictionary))

    def to_dict(self):
        """Return a dictionary mapping each field name to its current value."""
        return {field.name: getattr(self, field.name) for field in fields(self)}
@dataclass(order=True)
class AudioMeta(BaseInfo):
    """Metadata attached to one audio file."""
    path: str
    duration: float
    sample_rate: int
    amplitude: tp.Optional[float] = None
    weight: tp.Optional[float] = None
    # info_path is used to load additional information about the audio file that is stored in zip files.
    info_path: tp.Optional[PathInZip] = None

    @classmethod
    def from_dict(cls, dictionary: dict):
        """Build an `AudioMeta` from a dictionary, wrapping a non-null 'info_path' in `PathInZip`."""
        kwargs = cls._dict2fields(dictionary)
        raw_info_path = kwargs.get('info_path')
        if raw_info_path is not None:
            kwargs['info_path'] = PathInZip(raw_info_path)
        return cls(**kwargs)

    def to_dict(self):
        """Return the fields as a dictionary, with a non-null 'info_path' converted to `str`."""
        as_dict = super().to_dict()
        info_path = as_dict['info_path']
        if info_path is not None:
            as_dict['info_path'] = str(info_path)
        return as_dict
@dataclass(order=True)
class SegmentInfo(BaseInfo):
    """Information about one audio segment, as built by `AudioDataset.__getitem__`."""
    # Metadata of the file the segment was read from.
    meta: AudioMeta
    # Position in the file at which the segment starts.
    seek_time: float
    # The following values are given once the audio is processed, e.g.
    # at the target sample rate and target number of channels.
    n_frames: int      # actual number of frames without padding
    total_frames: int  # total number of frames, padding included
    sample_rate: int   # actual sample rate
    channels: int      # number of audio channels.
# File suffixes considered as audio files by default when scanning a folder.
DEFAULT_EXTS = ['.wav', '.mp3', '.flac', '.ogg', '.m4a']
# Module-level logger.
logger = logging.getLogger(__name__)
def _get_audio_meta(file_path: str, minimal: bool = True) -> AudioMeta:
    """Build the AudioMeta of a single audio file.

    Args:
        file_path (str): Resolved path of valid audio file.
        minimal (bool): Whether to only load the minimal set of metadata (takes longer if not).
            When False, the whole file is read to measure its peak absolute amplitude.
    Returns:
        AudioMeta: Audio file path and its metadata.
    """
    info = audio_info(file_path)
    amplitude: tp.Optional[float]
    if minimal:
        amplitude = None
    else:
        wav, _ = audio_read(file_path)
        amplitude = wav.abs().max().item()
    return AudioMeta(file_path, info.duration, info.sample_rate, amplitude)
def _resolve_audio_meta(m: AudioMeta, fast: bool = True) -> AudioMeta:
    """If Dora is available as a dependency, try to resolve potential relative paths
    in list of AudioMeta. This method is expected to be used when loading meta from file.

    The given `AudioMeta` is modified in place and returned.

    Args:
        m (AudioMeta): Audio meta to resolve.
        fast (bool): If True, uses a really fast check for determining if a file
            is already absolute or not. Only valid on Linux/Mac.
    Returns:
        AudioMeta: Audio meta with resolved path.
    """
    def is_abs(path) -> bool:
        path = str(path)
        if fast:
            # `startswith` also copes with an empty path, unlike indexing the first character.
            return path.startswith('/')
        return os.path.isabs(path)

    if not dora:
        return m

    if not is_abs(m.path):
        m.path = dora.git_save.to_absolute_path(m.path)
    if m.info_path is not None and not is_abs(m.info_path.zip_path):
        # Resolve the zip path itself, not the path of the audio file.
        m.info_path.zip_path = dora.git_save.to_absolute_path(m.info_path.zip_path)
    return m
def find_audio_files(path: tp.Union[Path, str],
                     exts: tp.List[str] = DEFAULT_EXTS,
                     resolve: bool = True,
                     minimal: bool = True,
                     progress: bool = False,
                     workers: int = 0) -> tp.List[AudioMeta]:
    """Build a list of AudioMeta from a given path,
    collecting relevant audio files and fetching meta info.

    Files whose metadata cannot be obtained are reported on stderr and skipped.

    Args:
        path (str or Path): Path to folder containing audio files.
        exts (list of str): List of file extensions to consider for audio files.
            The file suffix is lower-cased before being compared to this list.
        resolve (bool): Whether to pass each AudioMeta through `_resolve_audio_meta`.
        minimal (bool): Whether to only load the minimal set of metadata (takes longer if not).
        progress (bool): Whether to log progress on audio files collection.
        workers (int): number of parallel workers, if 0, use only the current thread.
    Returns:
        list of AudioMeta: List of audio file path and its metadata, sorted.
    """
    audio_files = []
    futures: tp.List[Future] = []
    pool: tp.Optional[ThreadPoolExecutor] = None
    with ExitStack() as stack:
        if workers > 0:
            # The pool is registered on the stack so that it is shut down when leaving the block.
            pool = ThreadPoolExecutor(workers)
            stack.enter_context(pool)

        if progress:
            print("Finding audio files...")
        for root, folders, files in os.walk(path, followlinks=True):
            for file in files:
                full_path = Path(root) / file
                if full_path.suffix.lower() in exts:
                    audio_files.append(full_path)
                    if pool is not None:
                        # One future per audio file, so `futures` and `audio_files` share indices.
                        futures.append(pool.submit(_get_audio_meta, str(audio_files[-1]), minimal))
                    if progress:
                        print(format(len(audio_files), " 8d"), end='\r', file=sys.stderr)

        if progress:
            print("Getting audio metadata...")
        meta: tp.List[AudioMeta] = []
        for idx, file_path in enumerate(audio_files):
            try:
                if pool is None:
                    m = _get_audio_meta(str(file_path), minimal)
                else:
                    m = futures[idx].result()
                if resolve:
                    m = _resolve_audio_meta(m)
            except Exception as err:
                # Best effort: report the faulty file and keep going with the others.
                print("Error with", str(file_path), err, file=sys.stderr)
                continue
            meta.append(m)
            if progress:
                print(format((1 + idx) / len(audio_files), " 3.1%"), end='\r', file=sys.stderr)
    meta.sort()
    return meta
def load_audio_meta(path: tp.Union[str, Path],
                    resolve: bool = True, fast: bool = True) -> tp.List[AudioMeta]:
    """Read a list of AudioMeta from a file holding one JSON document per line.

    The file is opened with gzip when its name ends with '.gz' (case insensitive).

    Args:
        path (str or Path): Path to JSON file.
        resolve (bool): Whether to resolve the path from AudioMeta (default=True).
        fast (bool): activates some tricks to make things faster.
    Returns:
        list of AudioMeta: List of audio file path and its total duration.
    """
    is_gzipped = str(path).lower().endswith('.gz')
    open_fn = gzip.open if is_gzipped else open
    with open_fn(path, 'rb') as fp:  # type: ignore
        lines = fp.readlines()
    # Lazily parsed, so that each entry is parsed and then resolved before the next one.
    entries = (AudioMeta.from_dict(json.loads(line)) for line in lines)
    if resolve:
        return [_resolve_audio_meta(entry, fast=fast) for entry in entries]
    return list(entries)
def save_audio_meta(path: tp.Union[str, Path], meta: tp.List[AudioMeta]):
    """Write the audio metadata to `path`, one JSON document per line.

    The parent directory is created if needed, and the file is written with gzip
    when its name ends with '.gz' (case insensitive).

    Args:
        path (str or Path): Path to JSON file.
        meta (list of AudioMeta): List of audio meta to save.
    """
    Path(path).parent.mkdir(exist_ok=True, parents=True)
    is_gzipped = str(path).lower().endswith('.gz')
    open_fn = gzip.open if is_gzipped else open
    with open_fn(path, 'wb') as fp:  # type: ignore
        for entry in meta:
            line = json.dumps(entry.to_dict()) + '\n'
            fp.write(line.encode('utf-8'))
class AudioDataset:
"""Base audio dataset.
The dataset takes a list of AudioMeta and create a dataset composed of segments of audio
and potentially additional information, by creating random segments from the list of audio
files referenced in the metadata and applying minimal data pre-processing such as resampling,
mixing of channels, padding, etc.
If no segment_duration value is provided, the AudioDataset will return the full wav for each
audio file. Otherwise, it will randomly sample audio files and create a segment of the specified
duration, applying padding if required.
By default, only the torch Tensor corresponding to the waveform is returned. Setting return_info=True
allows to return a tuple containing the torch Tensor and additional metadata on the segment and the
original audio meta.
Note that you can call `start_epoch(epoch)` in order to get
a deterministic "randomization" for `shuffle=True`.
For a given epoch and dataset index, this will always return the same extract.
You can get back some diversity by setting the `shuffle_seed` param.
Args:
meta (list of AudioMeta): List of audio files metadata.
segment_duration (float, optional): Optional segment duration of audio to load.
If not specified, the dataset will load the full audio segment from the file.
shuffle (bool): Set to `True` to have the data reshuffled at every epoch.
sample_rate (int): Target sample rate of the loaded audio samples.
channels (int): Target number of channels of the loaded audio samples.
sample_on_duration (bool): Set to `True` to sample segments with probability
dependent on audio file duration. This is only used if `segment_duration` is provided.
sample_on_weight (bool): Set to `True` to sample segments using the `weight` entry of
`AudioMeta`. If `sample_on_duration` is also True, the actual weight will be the product
of the file duration and file weight. This is only used if `segment_duration` is provided.
min_segment_ratio (float): Minimum segment ratio to use when the audio file
is shorter than the desired segment.
max_read_retry (int): Maximum number of retries to sample an audio segment from the dataset.
return_info (bool): Whether to return the wav only or return wav along with segment info and metadata.
min_audio_duration (float, optional): Minimum audio file duration, in seconds, if provided
audio shorter than this will be filtered out.
max_audio_duration (float, optional): Maximal audio file duration in seconds, if provided
audio longer than this will be filtered out.
shuffle_seed (int): can be used to further randomize
load_wav (bool): if False, skip loading the wav but returns a tensor of 0
with the expected segment_duration (which must be provided if load_wav is False).
permutation_on_files (bool): only if `sample_on_weight` and `sample_on_duration`
are False. Will ensure a permutation on files when going through the dataset.
In that case the epoch number must be provided in order for the model
to continue the permutation across epochs. In that case, it is assumed
that `num_samples = total_batch_size * num_updates_per_epoch`, with
`total_batch_size` the overall batch size accounting for all gpus.
"""
def __init__(self,
             meta: tp.List[AudioMeta],
             segment_duration: tp.Optional[float] = None,
             shuffle: bool = True,
             num_samples: int = 10_000,
             sample_rate: int = 48_000,
             channels: int = 2,
             pad: bool = True,
             sample_on_duration: bool = True,
             sample_on_weight: bool = True,
             min_segment_ratio: float = 0.5,
             max_read_retry: int = 10,
             return_info: bool = False,
             min_audio_duration: tp.Optional[float] = None,
             max_audio_duration: tp.Optional[float] = None,
             shuffle_seed: int = 0,
             load_wav: bool = True,
             permutation_on_files: bool = False,
             ):
    """Validate the arguments, filter `meta` on duration and store the configuration.

    `num_samples` is the length reported by `__len__`; it is replaced by the number of
    files kept after filtering when `segment_duration` is None. `pad` controls whether
    segments and batches are zero-padded to a common length.
    """
    assert len(meta) > 0, "No audio meta provided to AudioDataset. Please check loading of audio meta."
    assert segment_duration is None or segment_duration > 0
    assert segment_duration is None or min_segment_ratio >= 0
    self.segment_duration = segment_duration
    self.min_segment_ratio = min_segment_ratio
    # The duration bounds must be set before calling `_filter_duration`, which reads them.
    self.max_audio_duration = max_audio_duration
    self.min_audio_duration = min_audio_duration
    if self.min_audio_duration is not None and self.max_audio_duration is not None:
        assert self.min_audio_duration <= self.max_audio_duration
    self.meta: tp.List[AudioMeta] = self._filter_duration(meta)
    assert len(self.meta)  # Fail fast if all data has been filtered.
    self.total_duration = sum(d.duration for d in self.meta)

    if segment_duration is None:
        # Without segmenting, each item is a full file.
        num_samples = len(self.meta)
    self.num_samples = num_samples
    self.shuffle = shuffle
    self.sample_rate = sample_rate
    self.channels = channels
    self.pad = pad
    # The sampling flags must be set before computing the probabilities, which depend on them.
    self.sample_on_weight = sample_on_weight
    self.sample_on_duration = sample_on_duration
    self.sampling_probabilities = self._get_sampling_probabilities()
    self.max_read_retry = max_read_retry
    self.return_info = return_info
    self.shuffle_seed = shuffle_seed
    # Unknown until `start_epoch` is called.
    self.current_epoch: tp.Optional[int] = None
    self.load_wav = load_wav
    if not load_wav:
        assert segment_duration is not None
    self.permutation_on_files = permutation_on_files
    if permutation_on_files:
        # Permutation over files is incompatible with weighted sampling and requires shuffling.
        assert not self.sample_on_duration
        assert not self.sample_on_weight
        assert self.shuffle
def start_epoch(self, epoch: int):
    """Record the current epoch number, used to seed the sampling in `__getitem__`
    and to index the file permutation in `sample_file`."""
    self.current_epoch = epoch
def __len__(self):
    """Return `num_samples`, as set in `__init__`."""
    return self.num_samples
def _get_sampling_probabilities(self, normalized: bool = True):
    """Compute the sampling score of every file in `self.meta`.

    A score starts at 1 and is multiplied by the file weight (when `sample_on_weight`
    is set and the weight is defined) and by the file duration (when `sample_on_duration`
    is set). With `normalized`, the scores are divided by their sum.
    """
    use_weight = self.sample_on_weight
    use_duration = self.sample_on_duration

    def _score(file_meta) -> float:
        value = 1.
        if use_weight and file_meta.weight is not None:
            value *= file_meta.weight
        if use_duration:
            value *= file_meta.duration
        return value

    probabilities = torch.tensor([_score(file_meta) for file_meta in self.meta])
    if normalized:
        probabilities /= probabilities.sum()
    return probabilities
@staticmethod
@lru_cache(16)
def _get_file_permutation(num_files: int, permutation_index: int, base_seed: int):
    """Return a permutation of `range(num_files)` seeded with `base_seed + permutation_index`."""
    # The LRU cache implicitly keeps the most recent permutations in memory,
    # which works unless a lot of datasets are used in parallel.
    generator = torch.Generator().manual_seed(base_seed + permutation_index)
    return torch.randperm(num_files, generator=generator)
def sample_file(self, index: int, rng: torch.Generator) -> AudioMeta:
    """Pick one file from `self.meta`. Can be overridden in subclasses.

    This is only called if `segment_duration` is not None.
    The provided random number generator `rng` must be used for reproducibility,
    and the accessed `index` may be used as well.
    """
    n_files = len(self.meta)
    if self.permutation_on_files:
        assert self.current_epoch is not None
        # Position of this item in the stream of all items drawn since epoch 0.
        total_index = self.current_epoch * len(self) + index
        permutation_index, relative_index = divmod(total_index, n_files)
        permutation = AudioDataset._get_file_permutation(n_files, permutation_index, self.shuffle_seed)
        return self.meta[permutation[relative_index]]

    if self.sample_on_weight or self.sample_on_duration:
        drawn = torch.multinomial(self.sampling_probabilities, 1, generator=rng)
    else:
        drawn = torch.randint(len(self.sampling_probabilities), (1,), generator=rng)
    return self.meta[int(drawn.item())]
def _audio_read(self, path: str, seek_time: float = 0, duration: float = -1):
    """Read audio from `path`, or return silence when `load_wav` is disabled.

    Override this method in subclass if needed.

    Returns:
        tuple: The waveform and its sample rate. With `load_wav` disabled, the waveform
            is a zero tensor with `channels` rows and `sample_rate * segment_duration` frames.
    """
    if not self.load_wav:
        assert self.segment_duration is not None
        n_frames = int(self.sample_rate * self.segment_duration)
        return torch.zeros(self.channels, n_frames), self.sample_rate
    return audio_read(path, seek_time, duration, pad=False)
def __getitem__(self, index: int) -> tp.Union[torch.Tensor, tp.Tuple[torch.Tensor, SegmentInfo]]:
    """Return the item at `index`, optionally along with its `SegmentInfo`.

    Without `segment_duration`, the full file at position `index` is returned. Otherwise a
    file and a seek time are drawn from a generator seeded from `index` (and the epoch
    and `shuffle_seed` when shuffling), and a segment of `segment_duration` is read,
    retrying with another draw up to `max_read_retry` times on failure.

    Audio is read through `self._audio_read`, so that `load_wav=False` and subclass
    overrides of that method are honoured.

    Args:
        index (int): Index of the item.
    Returns:
        torch.Tensor or tuple: The waveform, or `(waveform, SegmentInfo)` if `return_info` is set.
    Raises:
        Exception: Whatever the last read attempt raised, once all retries are exhausted.
    """
    if self.segment_duration is None:
        file_meta = self.meta[index]
        out, sr = self._audio_read(file_meta.path)
        out = convert_audio(out, sr, self.sample_rate, self.channels)
        n_frames = out.shape[-1]
        segment_info = SegmentInfo(file_meta, seek_time=0., n_frames=n_frames, total_frames=n_frames,
                                   sample_rate=self.sample_rate, channels=out.shape[0])
    else:
        rng = torch.Generator()
        if self.shuffle:
            # We use index, plus extra randomness, either totally random if we don't know the epoch.
            # otherwise we make use of the epoch number and optional shuffle_seed.
            if self.current_epoch is None:
                rng.manual_seed(index + self.num_samples * random.randint(0, 2**24))
            else:
                rng.manual_seed(index + self.num_samples * (self.current_epoch + self.shuffle_seed))
        else:
            # We only use index
            rng.manual_seed(index)

        for retry in range(self.max_read_retry):
            file_meta = self.sample_file(index, rng)
            # We add some variance in the file position even if audio file is smaller than segment
            # without ending up with empty segments
            max_seek = max(0, file_meta.duration - self.segment_duration * self.min_segment_ratio)
            seek_time = torch.rand(1, generator=rng).item() * max_seek
            try:
                out, sr = self._audio_read(file_meta.path, seek_time, self.segment_duration)
                out = convert_audio(out, sr, self.sample_rate, self.channels)
                n_frames = out.shape[-1]
                target_frames = int(self.segment_duration * self.sample_rate)
                if self.pad:
                    out = F.pad(out, (0, target_frames - n_frames))
                segment_info = SegmentInfo(file_meta, seek_time, n_frames=n_frames, total_frames=target_frames,
                                           sample_rate=self.sample_rate, channels=out.shape[0])
            except Exception as exc:
                logger.warning("Error opening file %s: %r", file_meta.path, exc)
                # Only give up once the last allowed attempt has failed.
                if retry == self.max_read_retry - 1:
                    raise
            else:
                break

    if self.return_info:
        # Returns the wav and additional information on the wave segment
        return out, segment_info
    else:
        return out
def collater(self, samples):
    """The collater function has to be provided to the dataloader
    if AudioDataset has return_info=True in order to properly collate
    the samples of a batch.

    Args:
        samples (list): Items returned by `__getitem__`: tensors, or `(tensor, SegmentInfo)`
            pairs when `return_info` is set.
    Returns:
        torch.Tensor or tuple: The stacked waveforms, along with the list of (copied)
            `SegmentInfo` when `return_info` is set.
    """
    if self.segment_duration is None and len(samples) > 1:
        assert self.pad, "Must allow padding when batching examples of different durations."

    # In this case the audio reaching the collater is of variable length as segment_duration=None.
    to_pad = self.segment_duration is None and self.pad
    if to_pad:
        # Samples are (wav, info) pairs only when `return_info` is set, otherwise they are
        # bare tensors that must not be unpacked along their channel dimension.
        batch_wavs = [sample[0] for sample in samples] if self.return_info else samples
        max_len = max([wav.shape[-1] for wav in batch_wavs])

        def _pad_wav(wav):
            return F.pad(wav, (0, max_len - wav.shape[-1]))

    if self.return_info:
        if len(samples) > 0:
            assert len(samples[0]) == 2
            assert isinstance(samples[0][0], torch.Tensor)
            assert isinstance(samples[0][1], SegmentInfo)

        wavs = [wav for wav, _ in samples]
        # Copies are made since `total_frames` may be updated below.
        segment_infos = [copy.deepcopy(info) for _, info in samples]

        if to_pad:
            # Each wav could be of a different duration as they are not segmented.
            for i in range(len(samples)):
                # Determines the total length of the signal with padding, so we update here as we pad.
                segment_infos[i].total_frames = max_len
                wavs[i] = _pad_wav(wavs[i])

        wav = torch.stack(wavs)
        return wav, segment_infos
    else:
        assert isinstance(samples[0], torch.Tensor)
        if to_pad:
            samples = [_pad_wav(s) for s in samples]
        return torch.stack(samples)
def _filter_duration(self, meta: tp.List[AudioMeta]) -> tp.List[AudioMeta]:
"""Filters out audio files with audio durations that will not allow to sample examples from them."""
orig_len = len(meta)
# Filter data that is too short.
if self.min_audio_duration is not None:
meta = [m for m in meta if m.duration >= self.min_audio_duration]
# Filter data that is too long.
if self.max_audio_duration is not None:
meta = [m for m in meta if m.duration <= self.max_audio_duration]
filtered_len = len(meta)
removed_percentage = 100*(1-float(filtered_len)/orig_len)
msg = 'Removed %.2f percent of the data because it was too short or too long.' % removed_percentage
if removed_percentage < 10:
logging.debug(msg)
else:
logging.warning(msg)
return meta
@classmethod
def from_meta(cls, root: tp.Union[str, Path], **kwargs):
    """Build the dataset from a manifest, or from a directory that contains one.

    When `root` is a directory, the manifest is looked up inside it as
    `data.jsonl` first and `data.jsonl.gz` second.

    Args:
        root (str or Path): Path to the manifest file or to the folder holding it.
        kwargs: Additional keyword arguments for the AudioDataset.
    Raises:
        ValueError: If `root` is a directory containing neither manifest file.
    """
    manifest = Path(root)
    if manifest.is_dir():
        plain_manifest = manifest / 'data.jsonl'
        gzipped_manifest = manifest / 'data.jsonl.gz'
        if plain_manifest.exists():
            manifest = plain_manifest
        elif gzipped_manifest.exists():
            manifest = gzipped_manifest
        else:
            raise ValueError("Don't know where to read metadata from in the dir. "
                             "Expecting either a data.jsonl or data.jsonl.gz file but none found.")
    return cls(load_audio_meta(manifest), **kwargs)
@classmethod
def from_path(cls, root: tp.Union[str, Path], minimal_meta: bool = True,
              exts: tp.List[str] = DEFAULT_EXTS, **kwargs):
    """Build the dataset from a manifest file or by scanning a folder for audio files.

    Args:
        root (str or Path): Either a metadata file, or the root folder
            containing (possibly nested) audio files.
        minimal_meta (bool): Whether to only load minimal metadata or not.
        exts (list of str): Extensions for audio files.
        kwargs: Additional keyword arguments for the AudioDataset.
    """
    source = Path(root)
    if not source.is_file():
        # Not a manifest: look for audio files under the given location.
        meta = find_audio_files(source, exts, minimal=minimal_meta, resolve=True)
    else:
        meta = load_audio_meta(source, resolve=True)
    return cls(meta, **kwargs)
def main():
    """Command line entry point: scan a folder for audio files and save their metadata."""
    logging.basicConfig(stream=sys.stderr, level=logging.INFO)

    cli = argparse.ArgumentParser(prog='audio_dataset',
                                  description='Generate .jsonl files by scanning a folder.')
    cli.add_argument('root', help='Root folder with all the audio files')
    cli.add_argument('output_meta_file', help='Output file to store the metadata, ')
    # `--complete` switches off the `minimal` flag, which is on by default.
    cli.add_argument('--complete', dest='minimal', action='store_false', default=True,
                     help='Retrieve all metadata, even the one that are expansive '
                          'to compute (e.g. normalization).')
    cli.add_argument('--resolve', action='store_true', default=False,
                     help='Resolve the paths to be absolute and with no symlinks.')
    cli.add_argument('--workers', type=int, default=10, help='Number of workers.')
    options = cli.parse_args()

    scan_options = dict(progress=True, resolve=options.resolve,
                        minimal=options.minimal, workers=options.workers)
    found_meta = find_audio_files(options.root, DEFAULT_EXTS, **scan_options)
    save_audio_meta(options.output_meta_file, found_meta)


if __name__ == '__main__':
    main()
================================================
FILE: audiocraft/data/audio_utils.py
================================================
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
"""Various utilities for audio convertion (pcm format, sample rate and channels),
and volume normalization."""
import sys
import typing as tp
import julius
import torch
import torchaudio
def convert_audio_channels(wav: torch.Tensor, channels: int = 2) -> torch.Tensor:
    """Bring an audio tensor to the requested number of channels.

    Args:
        wav (torch.Tensor): Audio wave of shape [..., C, T].
        channels (int): Expected number of channels as output.
    Returns:
        torch.Tensor: Audio wave of shape [..., channels, T]; the input itself
            when it already has the requested number of channels.
    Raises:
        ValueError: If the input has several channels but fewer than requested.
    """
    *leading_dims, n_source, n_samples = wav.shape
    if n_source == channels:
        return wav
    if channels == 1:
        # Mono requested from a multi-channel stream: average all the channels.
        return wav.mean(dim=-2, keepdim=True)
    if n_source == 1:
        # Several channels requested from a mono stream: replicate the single channel.
        return wav.expand(*leading_dims, channels, n_samples)
    if n_source >= channels:
        # More channels available than requested: keep only the first ones.
        return wav[..., :channels, :]
    # Fewer channels than requested and not mono: no reasonable choice.
    raise ValueError('The audio file has less channels than requested but is not mono.')
def convert_audio(wav: torch.Tensor, from_rate: float,
                  to_rate: float, to_channels: int) -> torch.Tensor:
    """Resample the audio to `to_rate`, then convert it to `to_channels` channels.

    Both rates are truncated to integers before resampling.
    """
    resampled = julius.resample_frac(wav, int(from_rate), int(to_rate))
    return convert_audio_channels(resampled, to_channels)
def normalize_loudness(wav: torch.Tensor, sample_rate: int, loudness_headroom_db: float = 14,
                       loudness_compressor: bool = False, energy_floor: float = 2e-3):
    """Rescale a signal so that its measured loudness reaches `-loudness_headroom_db` dB.

    The loudness is measured with `torchaudio.transforms.Loudness`.

    Args:
        wav (torch.Tensor): Input multichannel audio data.
        sample_rate (int): Sample rate.
        loudness_headroom_db (float): Headroom in dB; the target loudness is its negation.
        loudness_compressor (bool): Uses tanh for soft clipping.
        energy_floor (float): anything below that RMS level will not be rescaled.
    Returns:
        torch.Tensor: Loudness normalized output data, or the input itself when
            its RMS is below `energy_floor`.
    """
    rms = wav.pow(2).mean().sqrt().item()
    if rms < energy_floor:
        # Too quiet to be rescaled: hand the signal back untouched.
        return wav

    loudness_meter = torchaudio.transforms.Loudness(sample_rate)
    input_loudness_db = loudness_meter(wav).item()
    # Gain (in dB, then linear) that moves the measured loudness onto the target.
    gain_db = -loudness_headroom_db - input_loudness_db
    linear_gain = 10.0 ** (gain_db / 20.0)
    output = linear_gain * wav
    if loudness_compressor:
        output = torch.tanh(output)
    assert output.isfinite().all(), (input_loudness_db, wav.pow(2).mean().sqrt())
    return output
def _clip_wav(wav: torch.Tensor, log_clipping: bool = False, stem_name: tp.Optional[str] = None) -> None:
"""Utility function to clip the audio with logging if specified."""
max_scale = wav.abs().max()
if log_clipping and max_scale > 1:
clamp_prob = (wav.abs() > 1).float().mean().item()
print(f"CLIPPING {stem_name or ''} happening with proba (a bit of clipping is okay):",
clamp_prob, "maximum scale: ", max_scale.item(), file=sys.stderr)
#wav.clamp_(-1, 1)
wav = wav.clone().clamp_(-1, 1)
def normalize_audio(wav: torch.Tensor, normalize: bool = True,
                    strategy: str = 'peak', peak_clip_headroom_db: float = 1,
                    rms_headroom_db: float = 18, loudness_headroom_db: float = 14,
                    loudness_compressor: bool = False, log_clipping: bool = False,
                    sample_rate: tp.Optional[int] = None,
                    stem_name: tp.Optional[str] = None) -> torch.Tensor:
    """Normalize the audio according to the prescribed strategy (see after).

    Args:
        wav (torch.Tensor): Audio data.
        normalize (bool): if `True` (default), normalizes according to the prescribed
            strategy (see after). If `False`, the strategy is only used in case clipping
            would happen.
        strategy (str): Can be either 'clip', 'peak', 'rms' or 'loudness'. Default is 'peak',
            i.e. audio is normalized by its largest value. RMS normalizes by root-mean-square
            with extra headroom to avoid clipping. 'clip' just clips. 'loudness' normalizes
            with `normalize_loudness`. An empty string or 'none' leaves the audio unchanged
            but asserts that it is strictly within [-1, 1].
        peak_clip_headroom_db (float): Headroom in dB when doing 'peak' or 'clip' strategy.
        rms_headroom_db (float): Headroom in dB when doing 'rms' strategy. This must be much larger
            than the `peak_clip` one to avoid further clipping.
        loudness_headroom_db (float): Target loudness for loudness normalization.
        loudness_compressor (bool): If True, uses tanh based soft clipping.
        log_clipping (bool): If True, basic logging on stderr when clipping still
            occurs despite strategy (only for 'rms' and 'loudness').
        sample_rate (int): Sample rate for the audio data (required for loudness).
        stem_name (str, optional): Stem name for clipping logging.
    Returns:
        torch.Tensor: Normalized audio.
    """
    scale_peak = 10 ** (-peak_clip_headroom_db / 20)
    scale_rms = 10 ** (-rms_headroom_db / 20)
    if strategy == 'peak':
        peak = wav.abs().max()
        # A silent signal has no peak to normalize by; rescaling it would only produce NaNs.
        if peak > 0:
            rescaling = scale_peak / peak
            if normalize or rescaling < 1:
                wav = wav * rescaling
    elif strategy == 'clip':
        wav = wav.clamp(-scale_peak, scale_peak)
    elif strategy == 'rms':
        mono = wav.mean(dim=0)
        rms = mono.pow(2).mean().sqrt()
        # Same guard as for 'peak': skip the rescaling when the downmix carries no energy.
        if rms > 0:
            rescaling = scale_rms / rms
            if normalize or rescaling < 1:
                wav = wav * rescaling
        # `_clip_wav` is called for its optional logging; the clipping itself is applied
        # out of place right after, so that the returned audio is guaranteed to be in range.
        _clip_wav(wav, log_clipping=log_clipping, stem_name=stem_name)
        wav = wav.clamp(-1, 1)
    elif strategy == 'loudness':
        assert sample_rate is not None, "Loudness normalization requires sample rate."
        wav = normalize_loudness(wav, sample_rate, loudness_headroom_db, loudness_compressor)
        _clip_wav(wav, log_clipping=log_clipping, stem_name=stem_name)
        wav = wav.clamp(-1, 1)
    else:
        # Validate the strategy first so that a typo is reported as such.
        assert strategy == '' or strategy == 'none', f"Unexpected strategy: '{strategy}'"
        assert wav.abs().max() < 1
    return wav
def f32_pcm(wav: torch.Tensor) -> torch.Tensor:
    """Convert audio to float 32 bits PCM format.

    Floating point input is returned as is; int16 and int32 inputs are cast to
    float and divided by 2**15 and 2**31 respectively.

    Raises:
        ValueError: For any other dtype.
    """
    dtype = wav.dtype
    if dtype.is_floating_point:
        return wav
    full_scale_by_dtype = {torch.int16: 2**15, torch.int32: 2**31}
    full_scale = full_scale_by_dtype.get(dtype)
    if full_scale is None:
        raise ValueError(f"Unsupported wav dtype: {wav.dtype}")
    return wav.float() / full_scale
def i16_pcm(wav: torch.Tensor) -> torch.Tensor:
    """Convert audio to int 16 bits PCM format.

    ..Warning:: There exist many formula for doing this conversion. None are perfect
    due to the asymmetry of the int16 range. One either have possible clipping, DC offset,
    or inconsistencies with f32_pcm. If the given wav doesn't have enough headroom,
    it is possible that `i16_pcm(f32_pcm)) != Identity`.

    Floating point input must lie within [-1, 1]; any other input must already be int16
    and is returned unchanged.
    """
    if not wav.dtype.is_floating_point:
        assert wav.dtype == torch.int16
        return wav

    assert wav.abs().max() <= 1
    scaled = (wav * 2 ** 15).round()
    if scaled.max() >= 2 ** 15:  # clipping would occur
        # Fall back to the largest scale that keeps +1.0 inside the int16 range.
        scaled = (wav * (2 ** 15 - 1)).round()
    return scaled.short()
================================================
FILE: audiocraft/data/info_audio_dataset.py
================================================
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
"""Base classes for the datasets that also provide non-audio metadata,
e.g. description, text transcription etc.
"""
from dataclasses import dataclass
import logging
import math
import re
import typing as tp
import torch
from .audio_dataset import AudioDataset, AudioMeta
from ..environment import AudioCraftEnvironment
from ..modules.conditioners import SegmentWithAttributes, ConditioningAttributes
logger = logging.getLogger(__name__)
def _clusterify_meta(meta: AudioMeta) -> AudioMeta:
    """Rewrite, in place, the paths of a meta entry with the environment's dataset mappers.

    Both `meta.path` and, when an info path is set, its `zip_path` are remapped.
    The same (mutated) object is returned.
    """
    remap = AudioCraftEnvironment.apply_dataset_mappers
    meta.path = remap(meta.path)
    info_path = meta.info_path
    if info_path is not None:
        info_path.zip_path = remap(info_path.zip_path)
    return meta
def clusterify_all_meta(meta: tp.List[AudioMeta]) -> tp.List[AudioMeta]:
    """Apply `_clusterify_meta` to every entry and return them as a new list."""
    patched = []
    for entry in meta:
        patched.append(_clusterify_meta(entry))
    return patched
@dataclass
class AudioInfo(SegmentWithAttributes):
    """Minimal segment info that carries no conditioning attributes.

    InfoAudioDataset wraps the metadata of each segment into this class so that
    what it returns inherits from SegmentWithAttributes and can be asked for
    conditioning attributes, as solvers with conditioners require.
    """
    # Populated when using cached batch for training a LM.
    audio_tokens: tp.Optional[torch.Tensor] = None

    def to_condition_attributes(self) -> ConditioningAttributes:
        """Return a freshly created, empty set of conditioning attributes."""
        empty_attributes = ConditioningAttributes()
        return empty_attributes
class InfoAudioDataset(AudioDataset):
    """AudioDataset whose returned metadata is wrapped as an `AudioInfo` (a SegmentWithAttributes).

    The meta entries are passed through `clusterify_all_meta` before being handed
    to the parent class.

    See `audiocraft.data.audio_dataset.AudioDataset` for initialization arguments.
    """
    def __init__(self, meta: tp.List[AudioMeta], **kwargs):
        remapped_meta = clusterify_all_meta(meta)
        super().__init__(remapped_meta, **kwargs)

    def __getitem__(self, index: int) -> tp.Union[torch.Tensor, tp.Tuple[torch.Tensor, SegmentWithAttributes]]:
        if self.return_info:
            wav, segment_meta = super().__getitem__(index)
            # Rebuild the parent's info as an AudioInfo carrying the same fields.
            return wav, AudioInfo(**segment_meta.to_dict())
        wav = super().__getitem__(index)
        assert isinstance(wav, torch.Tensor)
        return wav
def get_keyword_or_keyword_list(value: tp.Optional[str]) -> tp.Union[tp.Optional[str], tp.Optional[tp.List[str]]]:
    """Preprocess a single keyword or possibly a list of keywords.

    Lists go through `get_keyword_list`, anything else through `get_keyword`.
    """
    preprocess = get_keyword_list if isinstance(value, list) else get_keyword
    return preprocess(value)
def get_string(value: tp.Optional[str]) -> tp.Optional[str]:
    """Preprocess a single string.

    Returns None for anything that is not a string, for the empty string and for
    the literal 'None'; otherwise the value stripped of surrounding whitespace.
    """
    if not isinstance(value, str) or value in ('', 'None'):
        return None
    return value.strip()
def get_keyword(value: tp.Optional[str]) -> tp.Optional[str]:
    """Preprocess a single keyword.

    Returns None for anything that is not a string, for the empty string and for
    the literal 'None'; otherwise the value stripped and lower-cased.
    """
    if not isinstance(value, str) or value in ('', 'None'):
        return None
    return value.strip().lower()
def get_keyword_list(values: tp.Union[str, tp.List[str]]) -> tp.Optional[tp.List[str]]:
    """Preprocess a list of keywords.

    A string is split on commas and whitespace, a NaN float counts as an empty
    list, and any other non-list value is converted to a one-element list of its
    string form. Each candidate goes through `get_keyword`; those it rejects are dropped.

    Returns:
        list of str or None: The cleaned keywords, or None when none is left.
    """
    if isinstance(values, str):
        candidates = [piece.strip() for piece in re.split(r'[,\s]', values)]
    elif isinstance(values, float) and math.isnan(values):
        candidates = []
    else:
        candidates = values

    if not isinstance(candidates, list):
        logger.debug(f"Unexpected keyword list {candidates}")
        candidates = [str(candidates)]

    cleaned = []
    for candidate in candidates:
        keyword = get_keyword(candidate)
        if keyword is not None:
            cleaned.append(keyword)
    return cleaned or None
================================================
FILE: audiocraft/data/music_dataset.py
================================================
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
"""Dataset of music tracks with rich metadata.
"""
from dataclasses import dataclass, field, fields, replace
import gzip
import json
import logging
from pathlib import Path
import random
import typing as tp
import torch
from .info_audio_dataset import (
InfoAudioDataset,
AudioInfo,
get_keyword_list,
get_keyword,
get_string
)
from ..modules.conditioners import (
ConditioningAttributes,
JointEmbedCondition,
WavCondition,
)
from ..utils.utils import warn_once
logger = logging.getLogger(__name__)
@dataclass
class MusicInfo(AudioInfo):
    """Segment info extended with music-related metadata fields."""
    # music-specific metadata
    title: tp.Optional[str] = None
    artist: tp.Optional[str] = None  # anonymized artist id, used to ensure no overlap between splits
    key: tp.Optional[str] = None
    bpm: tp.Optional[float] = None
    genre: tp.Optional[str] = None
    moods: tp.Optional[list] = None
    keywords: tp.Optional[list] = None
    description: tp.Optional[str] = None
    name: tp.Optional[str] = None
    instrument: tp.Optional[str] = None
    # original wav accompanying the metadata
    self_wav: tp.Optional[WavCondition] = None
    # dict mapping attributes names to tuple of wav, text and metadata
    joint_embed: tp.Dict[str, JointEmbedCondition] = field(default_factory=dict)

    @property
    def has_music_meta(self) -> bool:
        """Whether a `name` was set on this info."""
        return self.name is not None

    def to_condition_attributes(self) -> ConditioningAttributes:
        """Spread the dataclass fields over the wav, joint_embed and text conditions.

        `self_wav` goes to the wav conditions, the entries of `joint_embed` to the
        joint embedding conditions, and every other field to the text conditions
        (list values being joined with spaces).
        """
        attributes = ConditioningAttributes()
        for entry in fields(self):
            name = entry.name
            value = getattr(self, name)
            if name == 'self_wav':
                attributes.wav[name] = value
                continue
            if name == 'joint_embed':
                for embed_name, embed_cond in value.items():
                    attributes.joint_embed[embed_name] = embed_cond
                continue
            attributes.text[name] = ' '.join(value) if isinstance(value, list) else value
        return attributes

    @staticmethod
    def attribute_getter(attribute):
        """Return the preprocessing function for the given attribute name, or None if it has none."""
        if attribute == 'bpm':
            return get_bpm
        if attribute == 'key':
            return get_musical_key
        if attribute in ('moods', 'keywords'):
            return get_keyword_list
        if attribute in ('genre', 'name', 'instrument'):
            return get_keyword
        if attribute in ('title', 'artist', 'description'):
            return get_string
        return None

    @classmethod
    def from_dict(cls, dictionary: dict, fields_required: bool = False):
        """Build a MusicInfo from a dictionary, preprocessing the values it provides.

        Args:
            dictionary (dict): Values keyed by field name; unknown keys are ignored.
            fields_required (bool): If True, a missing field raises, except for the
                post-init attributes ('self_wav', 'joint_embed') and 'keywords'.
        Raises:
            KeyError: If a required field is missing from the dictionary.
        """
        # These attributes are not loaded from the dictionary and may be populated later.
        post_init_attributes = ['self_wav', 'joint_embed']
        optional_fields = ['keywords']

        init_kwargs: tp.Dict[str, tp.Any] = {}
        for entry in fields(cls):
            name = entry.name
            if name in post_init_attributes:
                continue
            if name not in dictionary:
                if fields_required and name not in optional_fields:
                    raise KeyError(f"Unexpected missing key: {name}")
                continue
            preprocess: tp.Optional[tp.Callable] = cls.attribute_getter(name)
            value = dictionary[name]
            if preprocess:
                value = preprocess(value)
            init_kwargs[name] = value
        return cls(**init_kwargs)
def augment_music_info_description(music_info: MusicInfo, merge_text_p: float = 0.,
                                   drop_desc_p: float = 0., drop_other_p: float = 0.) -> MusicInfo:
    """Augment MusicInfo description with additional metadata fields and potential dropout.
    Additional textual attributes are added given probability 'merge_text_p' and
    the original textual description is dropped from the augmented description given probability drop_desc_p.

    Args:
        music_info (MusicInfo): The music metadata to augment.
        merge_text_p (float): Probability of merging additional metadata to the description.
            If provided value is 0, then no merging is performed.
        drop_desc_p (float): Probability of dropping the original description on text merge.
            if provided value is 0, then no drop out is performed.
        drop_other_p (float): Probability of dropping the other fields used for text augmentation.
            With 0 every valid field is kept, with 1 they are all dropped.
    Returns:
        MusicInfo: A copy of the MusicInfo with augmented textual description.
    """
    def is_valid_field(field_name: str, field_value: tp.Any) -> bool:
        valid_field_name = field_name in ['key', 'bpm', 'genre', 'moods', 'instrument', 'keywords']
        valid_field_value = field_value is not None and isinstance(field_value, (int, float, str, list))
        # `drop_other_p` is a dropping probability: the field is kept when the draw is not below it.
        keep_field = random.uniform(0, 1) >= drop_other_p
        return valid_field_name and valid_field_value and keep_field

    def process_value(v: tp.Any) -> str:
        if isinstance(v, (int, float, str)):
            return str(v)
        if isinstance(v, list):
            return ", ".join(v)
        else:
            raise ValueError(f"Unknown type for text value! ({type(v), v})")

    description = music_info.description

    metadata_text = ""
    if random.uniform(0, 1) < merge_text_p:
        meta_pairs = [f'{_field.name}: {process_value(getattr(music_info, _field.name))}'
                      for _field in fields(music_info) if is_valid_field(_field.name, getattr(music_info, _field.name))]
        # Shuffle so that the order of the merged fields carries no information.
        random.shuffle(meta_pairs)
        metadata_text = ". ".join(meta_pairs)
        description = description if not random.uniform(0, 1) < drop_desc_p else None
        logger.debug(f"Applying text augmentation on MMI info. description: {description}, metadata: {metadata_text}")

    if description is None:
        description = metadata_text if len(metadata_text) > 1 else None
    else:
        description = ". ".join([description.rstrip('.'), metadata_text])
    description = description.strip() if description else None

    # Work on a copy so that the caller's MusicInfo is left untouched.
    music_info = replace(music_info)
    music_info.description = description
    return music_info
class Paraphraser:
    """Randomly replaces a description with one of its paraphrases.

    Args:
        paraphrase_source (str or Path): Path to a .json or .json.gz file (gzip is
            selected from the '.gz' extension). Its content is expected to map the
            info path of a track (the audio path with a '.json' suffix) to a list
            of possible paraphrases.
        paraphrase_p (float): Probability of taking a paraphrase.
    """
    def __init__(self, paraphrase_source: tp.Union[str, Path], paraphrase_p: float = 0.):
        self.paraphrase_p = paraphrase_p
        open_fn = gzip.open if str(paraphrase_source).lower().endswith('.gz') else open
        with open_fn(paraphrase_source, 'rb') as f:  # type: ignore
            self.paraphrase_source = json.loads(f.read())
        logger.info(f"loaded paraphrasing source from: {paraphrase_source}")

    def sample_paraphrase(self, audio_path: str, description: str):
        """Return a random paraphrase for the track with probability `paraphrase_p`.

        The original description is returned when no paraphrase is drawn, or when
        the track is missing from the paraphrase source (a warning is then emitted once).
        """
        if random.random() >= self.paraphrase_p:
            return description
        # The source comes from JSON, so its keys are strings: a `Path` object would never match them.
        info_path = str(Path(audio_path).with_suffix('.json'))
        if info_path not in self.paraphrase_source:
            warn_once(logger, f"{info_path} not in paraphrase source!")
            return description
        new_desc = random.choice(self.paraphrase_source[info_path])
        logger.debug(f"{description} -> {new_desc}")
        return new_desc
class MusicDataset(InfoAudioDataset):
    """Music dataset is an AudioDataset with music-related metadata.

    Args:
        info_fields_required (bool): Whether to enforce having required fields.
        merge_text_p (float): Probability of merging additional metadata to the description.
        drop_desc_p (float): Probability of dropping the original description on text merge.
        drop_other_p (float): Probability of dropping the other fields used for text augmentation.
        joint_embed_attributes (list[str]): A list of attributes for which joint embedding metadata is returned.
        paraphrase_source (str, optional): Path to the .json or .json.gz file containing the
            paraphrases for the description. The json should be a dict with keys are the
            original info path (e.g. track_path.json) and each value is a list of possible
            paraphrased.
        paraphrase_p (float): probability of taking a paraphrase.

    See `audiocraft.data.info_audio_dataset.InfoAudioDataset` for full initialization arguments.
    """
    def __init__(self, *args, info_fields_required: bool = True,
                 merge_text_p: float = 0., drop_desc_p: float = 0., drop_other_p: float = 0.,
                 joint_embed_attributes: tp.List[str] = [],
                 paraphrase_source: tp.Optional[str] = None, paraphrase_p: float = 0,
                 **kwargs):
        kwargs['return_info'] = True  # We require the info for each song of the dataset.
        super().__init__(*args, **kwargs)
        self.info_fields_required = info_fields_required
        self.merge_text_p = merge_text_p
        self.drop_desc_p = drop_desc_p
        self.drop_other_p = drop_other_p
        # Keep a private copy: the default value is a list shared by every instance.
        self.joint_embed_attributes = list(joint_embed_attributes or [])
        self.paraphraser = None
        if paraphrase_source is not None:
            self.paraphraser = Paraphraser(paraphrase_source, paraphrase_p)

    def __getitem__(self, index):
        """Return the wav and its MusicInfo.

        The music metadata is read from the JSON file sitting next to the audio file
        (same path with a '.json' suffix) when it exists; the segment info fields take
        precedence over the ones of the JSON. Otherwise only the segment info is used.
        """
        wav, info = super().__getitem__(index)
        info_data = info.to_dict()
        music_info_path = Path(info.meta.path).with_suffix('.json')

        if Path(music_info_path).exists():
            with open(music_info_path, 'r') as json_file:
                music_data = json.load(json_file)
            music_data.update(info_data)
            music_info = MusicInfo.from_dict(music_data, fields_required=self.info_fields_required)
            if self.paraphraser is not None:
                # The Paraphraser method is `sample_paraphrase`; it has no `sample` attribute.
                music_info.description = self.paraphraser.sample_paraphrase(
                    music_info.meta.path, music_info.description)
            if self.merge_text_p:
                music_info = augment_music_info_description(
                    music_info, self.merge_text_p, self.drop_desc_p, self.drop_other_p)
        else:
            music_info = MusicInfo.from_dict(info_data, fields_required=False)

        music_info.self_wav = WavCondition(
            wav=wav[None], length=torch.tensor([info.n_frames]),
            sample_rate=[info.sample_rate], path=[info.meta.path], seek_time=[info.seek_time])

        for att in self.joint_embed_attributes:
            att_value = getattr(music_info, att)
            joint_embed_cond = JointEmbedCondition(
                wav[None], [att_value], torch.tensor([info.n_frames]),
                sample_rate=[info.sample_rate], path=[info.meta.path], seek_time=[info.seek_time])
            music_info.joint_embed[att] = joint_embed_cond

        return wav, music_info
def get_musical_key(value: tp.Optional[str]) -> tp.Optional[str]:
    """Preprocess a musical key, discarding it when several keys are defined.

    Returns None for anything that is not a string, for the empty string, for the
    literal 'None' and for values containing a comma; otherwise the value stripped
    and lower-cased.
    """
    if not isinstance(value, str) or value in ('', 'None'):
        return None
    if ',' in value:
        # For now, several keys separated with commas are discarded altogether.
        return None
    return value.strip().lower()
def get_bpm(value: tp.Optional[str]) -> tp.Optional[float]:
    """Preprocess to a float.

    Returns:
        float or None: The value converted with `float`, or None when the value is
            None or cannot be converted (unparsable string or unsupported type).
    """
    if value is None:
        return None
    try:
        return float(value)
    except (TypeError, ValueError):
        # TypeError covers malformed metadata such as a list or a dict instead of a number.
        return None
================================================
FILE: audiocraft/data/sound_dataset.py
================================================
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
"""Dataset of audio with a simple description.
"""
from dataclasses import dataclass, fields, replace
import json
from pathlib import Path
import random
import typing as tp
import numpy as np
import torch
from .info_audio_dataset import (
InfoAudioDataset,
get_keyword_or_keyword_list
)
from ..modules.conditioners import (
ConditioningAttributes,
SegmentWithAttributes,
WavCondition,
)
# Machine epsilon of float32, added to denominators in this module to avoid dividing by zero.
EPS = torch.finfo(torch.float32).eps
# Bounds (in dB) between which `snr_mixer` randomly draws the RMS level of the final mix.
TARGET_LEVEL_LOWER = -35
TARGET_LEVEL_UPPER = -15
@dataclass
class SoundInfo(SegmentWithAttributes):
    """Segment info extended with a textual description of the sound."""
    description: tp.Optional[str] = None
    # NOTE: `SoundDataset.__getitem__` stores a `WavCondition` in this field.
    self_wav: tp.Optional[torch.Tensor] = None

    @property
    def has_sound_meta(self) -> bool:
        """Whether a description was set on this info."""
        return self.description is not None

    def to_condition_attributes(self) -> ConditioningAttributes:
        """Put `self_wav` in the wav conditions and every other field in the text conditions."""
        attributes = ConditioningAttributes()
        for entry in fields(self):
            name = entry.name
            value = getattr(self, name)
            target = attributes.wav if name == 'self_wav' else attributes.text
            target[name] = value
        return attributes

    @staticmethod
    def attribute_getter(attribute):
        """Return the preprocessing function for the given attribute name, or None if it has none."""
        if attribute == 'description':
            return get_keyword_or_keyword_list
        return None

    @classmethod
    def from_dict(cls, dictionary: dict, fields_required: bool = False):
        """Build a SoundInfo from a dictionary, preprocessing the values it provides.

        Args:
            dictionary (dict): Values keyed by field name; unknown keys are ignored.
            fields_required (bool): If True, any missing field other than the
                post-init attribute 'self_wav' raises.
        Raises:
            KeyError: If a required field is missing from the dictionary.
        """
        # These attributes are not loaded from the dictionary and may be populated later.
        post_init_attributes = ['self_wav']

        init_kwargs: tp.Dict[str, tp.Any] = {}
        for entry in fields(cls):
            name = entry.name
            if name in post_init_attributes:
                continue
            if name not in dictionary:
                if fields_required:
                    raise KeyError(f"Unexpected missing key: {name}")
                continue
            preprocess: tp.Optional[tp.Callable] = cls.attribute_getter(name)
            value = dictionary[name]
            if preprocess:
                value = preprocess(value)
            init_kwargs[name] = value
        return cls(**init_kwargs)
class SoundDataset(InfoAudioDataset):
    """Sound audio dataset: Audio dataset with environmental sound-specific metadata.

    Args:
        info_fields_required (bool): Whether all the mandatory metadata fields should be in the loaded metadata.
        external_metadata_source (tp.Optional[str]): Folder containing JSON metadata for the corresponding dataset.
            The metadata files contained in this folder are expected to match the stem of the audio file with
            a json extension.
        aug_p (float): Probability of performing audio mixing augmentation on the batch.
        mix_p (float): Proportion of batch items that are mixed together when applying audio mixing augmentation.
        mix_snr_low (int): Lowerbound for SNR value sampled for mixing augmentation.
        mix_snr_high (int): Upperbound for SNR value sampled for mixing augmentation.
        mix_min_overlap (float): Minimum overlap between audio files when performing mixing augmentation.
        kwargs: Additional arguments for AudioDataset.

    See `audiocraft.data.info_audio_dataset.InfoAudioDataset` for full initialization arguments.
    """
    def __init__(
        self,
        *args,
        info_fields_required: bool = True,
        external_metadata_source: tp.Optional[str] = None,
        aug_p: float = 0.,
        mix_p: float = 0.,
        mix_snr_low: int = -5,
        mix_snr_high: int = 5,
        mix_min_overlap: float = 0.5,
        **kwargs
    ):
        # The info of each item is always needed, whatever the caller asked for.
        kwargs['return_info'] = True
        super().__init__(*args, **kwargs)
        self.info_fields_required = info_fields_required
        self.external_metadata_source = external_metadata_source
        self.aug_p = aug_p
        self.mix_p = mix_p
        if self.aug_p > 0:
            assert self.mix_p > 0, "Expecting some mixing proportion mix_p if aug_p > 0"
            assert self.channels == 1, "SoundDataset with audio mixing considers only monophonic audio"
        self.mix_snr_low = mix_snr_low
        self.mix_snr_high = mix_snr_high
        self.mix_min_overlap = mix_min_overlap

    def _get_info_path(self, path: tp.Union[str, Path]) -> Path:
        """Locate the JSON holding the metadata (description, etc.) of an audio file.

        The JSON sitting next to the audio file (same path, '.json' suffix) is used
        when it exists. Otherwise a file with the same name is looked up in the
        external metadata folder, if one was configured.

        Raises:
            Exception: If no metadata JSON is found at either location.
        """
        local_json = Path(path).with_suffix('.json')
        if local_json.exists():
            return local_json
        if self.external_metadata_source:
            external_json = Path(self.external_metadata_source) / local_json.name
            if external_json.exists():
                return external_json
        raise Exception(f"Unable to find a metadata JSON for path: {path}")

    def __getitem__(self, index):
        wav, info = super().__getitem__(index)
        segment_fields = info.to_dict()
        json_path = self._get_info_path(info.meta.path)
        if not Path(json_path).exists():
            sound_info = SoundInfo.from_dict(segment_fields, fields_required=False)
        else:
            with open(json_path, 'r') as handle:
                sound_data = json.load(handle)
            # The segment info fields take precedence over the ones read from the JSON.
            sound_data.update(segment_fields)
            sound_info = SoundInfo.from_dict(sound_data, fields_required=self.info_fields_required)
            # if there are multiple descriptions, sample one randomly
            if isinstance(sound_info.description, list):
                sound_info.description = random.choice(sound_info.description)

        sound_info.self_wav = WavCondition(
            wav=wav[None], length=torch.tensor([info.n_frames]),
            sample_rate=[sound_info.sample_rate], path=[info.meta.path], seek_time=[info.seek_time])

        return wav, sound_info

    def collater(self, samples):
        """Collate the samples, then apply the mixing augmentation when `aug_p > 0`."""
        # SoundDataset always returns infos, so the parent collater yields a pair.
        batch_wav, batch_info = super().collater(samples)
        if self.aug_p > 0:
            # when training, audio mixing is performed in the collate function
            batch_wav, batch_info = mix_samples(batch_wav, batch_info, self.aug_p, self.mix_p,
                                                snr_low=self.mix_snr_low, snr_high=self.mix_snr_high,
                                                min_overlap=self.mix_min_overlap)
        return batch_wav, batch_info
def rms_f(x: torch.Tensor) -> torch.Tensor:
    """Root-mean-square of `x` computed along dimension 1."""
    mean_square = x.pow(2).mean(dim=1)
    return torch.pow(mean_square, 0.5)
def normalize(audio: torch.Tensor, target_level: int = -25) -> torch.Tensor:
    """Scale each row of `audio` so that its RMS reaches `target_level` (in dB)."""
    target_amplitude = 10 ** (target_level / 20)
    # EPS keeps the gain finite for silent rows.
    gain = target_amplitude / (rms_f(audio) + EPS)
    return audio * gain.unsqueeze(1)
def is_clipped(audio: torch.Tensor, clipping_threshold: float = 0.99) -> torch.Tensor:
    """Tell, for each row, whether any absolute amplitude exceeds `clipping_threshold`."""
    over_threshold = audio.abs() > clipping_threshold
    return over_threshold.any(dim=1)
def mix_pair(src: torch.Tensor, dst: torch.Tensor, min_overlap: float) -> torch.Tensor:
    """Add `dst` onto `src`, in place, starting at a random offset.

    The offset is drawn between 0 and `src_length * (1 - min_overlap)`; the part
    of `dst` that would run past the end of `src` is discarded.

    Returns:
        torch.Tensor: `src` itself, modified in place.
    """
    src_length = src.shape[1]
    start = random.randint(0, int(src_length * (1 - min_overlap)))
    # Number of frames actually mixed: all of dst, or whatever room is left in src.
    n_mixed = min(dst.shape[1], src_length - start)
    stop = start + n_mixed
    src[:, start:stop] = src[:, start:stop] + dst[:, :n_mixed]
    return src
def snr_mixer(clean: torch.Tensor, noise: torch.Tensor, snr: int, min_overlap: float,
              target_level: int = -25, clipping_threshold: float = 0.99) -> torch.Tensor:
    """Function to mix clean speech and noise at various SNR levels.

    Args:
        clean (torch.Tensor): Clean audio source to mix, of shape [B, T].
        noise (torch.Tensor): Noise audio source to mix, of shape [B, T].
        snr (int): SNR level when mixing.
        min_overlap (float): Minimum overlap between the two mixed sources.
        target_level (int): Gain level in dB.
        clipping_threshold (float): Threshold for clipping the audio.
    Returns:
        torch.Tensor: The mixed audio, of shape [B, T].
    """
    # Bring the noise to the length of the clean source, by zero-padding or truncation.
    if clean.shape[1] > noise.shape[1]:
        noise = torch.nn.functional.pad(noise, (0, clean.shape[1] - noise.shape[1]))
    else:
        noise = noise[:, :clean.shape[1]]

    # Peak-normalize, then bring to the target RMS level (-25 dB FS by default).
    # The peak is the largest absolute amplitude, so negative peaks are accounted for.
    clean = clean / (clean.abs().max(1)[0].unsqueeze(1) + EPS)
    clean = normalize(clean, target_level)
    rmsclean = rms_f(clean)

    noise = noise / (noise.abs().max(1)[0].unsqueeze(1) + EPS)
    noise = normalize(noise, target_level)
    rmsnoise = rms_f(noise)

    # set the noise level for a given SNR
    noisescalar = (rmsclean / (10 ** (snr / 20)) / (rmsnoise + EPS)).unsqueeze(1)
    noisenewlevel = noise * noisescalar

    # mix noise and clean speech
    noisyspeech = mix_pair(clean, noisenewlevel, min_overlap)

    # randomly select RMS value between -15 dBFS and -35 dBFS and normalize noisyspeech with that value
    # there is a chance of clipping that might happen with very less probability, which is not a major issue.
    noisy_rms_level = np.random.randint(TARGET_LEVEL_LOWER, TARGET_LEVEL_UPPER)
    rmsnoisy = rms_f(noisyspeech)
    scalarnoisy = (10 ** (noisy_rms_level / 20) / (rmsnoisy + EPS)).unsqueeze(1)
    noisyspeech = noisyspeech * scalarnoisy

    # final check to see if there are any amplitudes exceeding the clipping threshold.
    # If so, rescale the offending rows so that their absolute peak sits at the threshold.
    clipped = is_clipped(noisyspeech, clipping_threshold)
    if clipped.any():
        noisyspeech_maxamplevel = noisyspeech[clipped].abs().max(1)[0].unsqueeze(1) / (clipping_threshold - EPS)
        noisyspeech[clipped] = noisyspeech[clipped] / noisyspeech_maxamplevel

    return noisyspeech
def snr_mix(src: torch.Tensor, dst: torch.Tensor, snr_low: int, snr_high: int, min_overlap: float):
    """Mix two batches of audio at an SNR picked from the given range.

    Args:
        src (torch.Tensor): Audio forwarded to `snr_mixer` as the clean source.
        dst (torch.Tensor): Audio forwarded to `snr_mixer` as the noise source.
        snr_low (int): Lower bound (inclusive) for the sampled SNR.
        snr_high (int): Upper bound (exclusive) for the sampled SNR.
        min_overlap (float): Minimum overlap between the two mixed sources.
    Returns:
        torch.Tensor: The output of `snr_mixer` for the selected SNR.
    """
    # np.random.randint rejects an empty range, so equal bounds are used as-is.
    snr = snr_low if snr_low == snr_high else np.random.randint(snr_low, snr_high)
    return snr_mixer(src, dst, snr, min_overlap)
def mix_text(src_text: str, dst_text: str):
    """Merge two text descriptions into one.

    Identical descriptions are returned unchanged; otherwise both are
    joined with a single space, source first.
    """
    if src_text != dst_text:
        return " ".join((src_text, dst_text))
    return src_text
def mix_samples(wavs: torch.Tensor, infos: tp.List[SoundInfo], aug_p: float, mix_p: float,
                snr_low: int, snr_high: int, min_overlap: float):
    """Mix samples within a batch, summing the waveforms and concatenating the text infos.

    With probability `aug_p` pairs of items are mixed together; otherwise the batch is
    only sub-sampled so that both paths return `int(mix_p * B)` items.

    Args:
        wavs (torch.Tensor): Audio tensors of shape [B, C, T].
        infos (list[SoundInfo]): List of SoundInfo items corresponding to the audio.
        aug_p (float): Augmentation probability.
        mix_p (float): Proportion of items in the batch to mix (and merge) together.
        snr_low (int): Lowerbound for sampling SNR.
        snr_high (int): Upperbound for sampling SNR.
        min_overlap (float): Minimum overlap between mixed samples.
    Returns:
        tuple[torch.Tensor, list[SoundInfo]]: A tuple containing the mixed wavs
            and mixed SoundInfo for the given batch.
    """
    # nothing to mix: hand the batch back untouched
    if mix_p == 0:
        return wavs, infos

    should_augment = random.uniform(0, 1) < aug_p
    if not should_augment:
        # no augmentation this time: keep a random subset of the batch so the
        # output size matches what the mixing path would have produced
        batch_size, _, _ = wavs.shape
        num_kept = int(mix_p * batch_size)
        kept_idx = torch.randperm(batch_size)[:num_kept]
        kept_wavs = wavs[kept_idx]
        kept_infos = [infos[idx] for idx in kept_idx]
        assert kept_wavs.shape[0] == len(kept_infos), "Mismatch between number of wavs and infos in the batch"
        return kept_wavs, kept_infos  # [B, C, T]

    # the mixing itself operates on [B, T] waveforms
    assert wavs.size(1) == 1, f"Mix samples requires monophonic audio but C={wavs.size(1)}"
    mono = wavs.mean(dim=1, keepdim=False)
    batch_size, _ = mono.shape
    num_mixed = int(mix_p * batch_size)
    # two independent permutations define the (source, target) pairs
    src_idx = torch.randperm(batch_size)[:num_mixed]
    dst_idx = torch.randperm(batch_size)[:num_mixed]
    mixed = snr_mix(mono[src_idx], mono[dst_idx], snr_low, snr_high, min_overlap)

    # merge the textual descriptions of each pair into a copy of the source info
    all_descriptions = [info.description for info in infos]
    mixed_infos = []
    for src, dst in zip(src_idx, dst_idx):
        merged = replace(infos[src])
        merged.description = mix_text(all_descriptions[src], all_descriptions[dst])
        mixed_infos.append(merged)

    # restore the channel dimension: [B, T] -> [B, C, T]
    mixed = mixed.unsqueeze(1)
    assert mixed.shape[0] > 0, "Samples mixing returned empty batch."
    assert mixed.dim() == 3, f"Returned wav should be [B, C, T] but dim = {mixed.dim()}"
    assert mixed.shape[0] == len(mixed_infos), "Mismatch between number of wavs and infos in the batch"
    return mixed, mixed_infos  # [B, C, T]
================================================
FILE: audiocraft/data/zip.py
================================================
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
"""Utility for reading some info from inside a zip file.
"""
import typing
import zipfile
from dataclasses import dataclass
from functools import lru_cache
from typing_extensions import Literal
# Default size of the LRU cache wrapping `_open_zip`.
DEFAULT_SIZE = 32
# Modes accepted by `_open_zip` when opening a zip archive.
MODE = Literal['r', 'w', 'x', 'a']
@dataclass(order=True)
class PathInZip:
    """Hold a path of file within a zip file.

    Args:
        path (str): The convention is <path_to_zip>:<relative_path_inside_zip>.
            Let's assume there is a zip file /some/location/foo.zip
            and inside of it is a json file located at /data/file1.json,
            Then we expect path = "/some/location/foo.zip:/data/file1.json".
    """

    # Separator between the zip archive path and the path of the file inside it.
    INFO_PATH_SEP = ':'
    zip_path: str
    file_path: str

    def __init__(self, path: str) -> None:
        """Split `path` into `zip_path` and `file_path`.

        Raises:
            AssertionError: If `path` does not contain exactly one separator.
        """
        split_path = path.split(self.INFO_PATH_SEP)
        assert len(split_path) == 2, (
            f"Expected a path of the form <path_to_zip>{self.INFO_PATH_SEP}<relative_path_inside_zip>, "
            f"got: {path!r}"
        )
        self.zip_path, self.file_path = split_path

    @classmethod
    def from_paths(cls, zip_path: str, file_path: str):
        """Build a PathInZip from the archive path and the path of the file inside it."""
        return cls(zip_path + cls.INFO_PATH_SEP + file_path)

    def __str__(self) -> str:
        """Return the combined path, i.e. the same format accepted by the constructor."""
        return self.zip_path + self.INFO_PATH_SEP + self.file_path
def _open_zip(path: str, mode: MODE = 'r'):
    """Open the zip archive located at `path` with the given mode."""
    return zipfile.ZipFile(path, mode=mode)


# Cached variant of `_open_zip`; rebuilt by `set_zip_cache_size` when the size changes.
_cached_open_zip = lru_cache(maxsize=DEFAULT_SIZE)(_open_zip)
def set_zip_cache_size(max_size: int):
    """Replace the zip-opening LRU cache with a fresh one of the given size.

    Args:
        max_size (int): the maximal LRU cache.
    """
    global _cached_open_zip
    resized_cache = lru_cache(maxsize=max_size)
    _cached_open_zip = resized_cache(_open_zip)
def open_file_in_zip(path_in_zip: PathInZip, mode: str = 'r') -> typing.IO:
    """Opens a file stored inside a zip and returns a file-like object.

    The enclosing archive is obtained through the LRU cache of open zip files
    (always opened for reading); `mode` applies to the member inside the archive.

    Args:
        path_in_zip (PathInZip): A PathInZip object representing the file to return a file-like object of.
        mode (str): The mode in which to open the file with. Forwarded to `zipfile.ZipFile.open`.
    Returns:
        A file-like object for PathInZip.
    """
    zf = _cached_open_zip(path_in_zip.zip_path)
    # `mode` used to be silently ignored; forward it so the documented parameter takes effect.
    return zf.open(path_in_zip.file_path, mode)
================================================
FILE: audiocraft/environment.py
================================================
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
"""
Provides cluster and tools configuration across clusters (slurm, dora, utilities).
"""
import logging
import os
from pathlib import Path
import re
import typing as tp
import omegaconf
from .utils.cluster import _guess_cluster_type
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)
class AudioCraftEnvironment:
    """Environment configuration for teams and clusters.

    AudioCraftEnvironment picks compute cluster settings (slurm, dora) from the current running environment
    or declared variable and the loaded team configuration. Additionally, the AudioCraftEnvironment
    provides pointers to a reference folder resolved automatically across clusters that is shared across team members,
    allowing to share sigs or other files to run jobs. Finally, it provides dataset mappers to automatically
    map dataset file paths to new locations across clusters, allowing to use the same manifest of files across clusters.

    The cluster type is identified automatically and the base configuration file is read from
    config/teams/<team>.yaml, resolved relative to the parent of this package directory.

    Use the following environment variables to specify the cluster, team or configuration:

        AUDIOCRAFT_CLUSTER (optional): Cluster type to enforce. Useful if the cluster type
            cannot be inferred automatically.
        AUDIOCRAFT_CONFIG (optional): Path to yaml config holding the teams configuration.
            If not set, configuration is read from config/teams/<team>.yaml.
        AUDIOCRAFT_TEAM (optional): Name of the team. Recommended to set to your own team.
            Cluster configuration are shared across teams to match compute allocation,
            specify your cluster configuration in the configuration file under a key mapping
            your team name.
    """
    # Lazily created singleton, see `instance()` and `reset()`.
    _instance = None
    DEFAULT_TEAM = "default"

    def __init__(self) -> None:
        """Loads configuration."""
        self.team: str = os.getenv("AUDIOCRAFT_TEAM", self.DEFAULT_TEAM)
        cluster_type = _guess_cluster_type()
        cluster = os.getenv(
            "AUDIOCRAFT_CLUSTER", cluster_type.value
        )
        logger.info("Detecting cluster type %s", cluster_type)

        self.cluster: str = cluster

        config_path = os.getenv(
            "AUDIOCRAFT_CONFIG",
            Path(__file__)
            .parent.parent.joinpath("config/teams", self.team)
            .with_suffix(".yaml"),
        )
        self.config = omegaconf.OmegaConf.load(config_path)
        # List of (compiled regex, replacement) pairs applied by `apply_dataset_mappers`.
        self._dataset_mappers = []
        cluster_config = self._get_cluster_config()
        if "dataset_mappers" in cluster_config:
            for pattern, repl in cluster_config["dataset_mappers"].items():
                regex = re.compile(pattern)
                self._dataset_mappers.append((regex, repl))

    def _get_cluster_config(self) -> omegaconf.DictConfig:
        """Returns the section of the loaded configuration for the selected cluster."""
        assert isinstance(self.config, omegaconf.DictConfig)
        return self.config[self.cluster]

    @classmethod
    def instance(cls):
        """Returns the shared environment, creating it on first use."""
        if cls._instance is None:
            cls._instance = cls()
        return cls._instance

    @classmethod
    def reset(cls):
        """Clears the environment and forces a reload on next invocation."""
        cls._instance = None

    @classmethod
    def get_team(cls) -> str:
        """Gets the selected team as dictated by the AUDIOCRAFT_TEAM env var.
        If not defined, defaults to `DEFAULT_TEAM` ("default").
        """
        return cls.instance().team

    @classmethod
    def get_cluster(cls) -> str:
        """Gets the detected cluster.
        This value can be overridden by the AUDIOCRAFT_CLUSTER env var.
        """
        return cls.instance().cluster

    @classmethod
    def get_dora_dir(cls) -> Path:
        """Gets the path to the dora directory for the current team and cluster.
        Value is overridden by the AUDIOCRAFT_DORA_DIR env var.
        """
        cluster_config = cls.instance()._get_cluster_config()
        dora_dir = os.getenv("AUDIOCRAFT_DORA_DIR", cluster_config["dora_dir"])
        logger.warning(f"Dora directory: {dora_dir}")
        return Path(dora_dir)

    @classmethod
    def get_reference_dir(cls) -> Path:
        """Gets the path to the reference directory for the current team and cluster.
        Value is overridden by the AUDIOCRAFT_REFERENCE_DIR env var.
        """
        cluster_config = cls.instance()._get_cluster_config()
        return Path(os.getenv("AUDIOCRAFT_REFERENCE_DIR", cluster_config["reference_dir"]))

    @classmethod
    def get_slurm_exclude(cls) -> tp.Optional[str]:
        """Get the list of nodes to exclude for that cluster."""
        cluster_config = cls.instance()._get_cluster_config()
        return cluster_config.get("slurm_exclude")

    @classmethod
    def get_slurm_partitions(cls, partition_types: tp.Optional[tp.List[str]] = None) -> str:
        """Gets the requested partitions for the current team and cluster as a comma-separated string.

        Args:
            partition_types (list[str], optional): partition types to retrieve. Values must be
                from ['global', 'team']. If not provided, the global partition is returned.
        """
        if not partition_types:
            partition_types = ["global"]

        cluster_config = cls.instance()._get_cluster_config()
        partitions = [
            cluster_config["partitions"][partition_type]
            for partition_type in partition_types
        ]
        return ",".join(partitions)

    @classmethod
    def resolve_reference_path(cls, path: tp.Union[str, Path]) -> Path:
        """Converts reference placeholder in path with configured reference dir to resolve paths.

        Args:
            path (str or Path): Path to resolve.
        Returns:
            Path: Resolved path.
        """
        path = str(path)
        placeholder = "//reference"

        if path.startswith(placeholder):
            reference_dir = cls.get_reference_dir()
            logger.warning(f"Reference directory: {reference_dir}")
            assert (
                reference_dir.exists() and reference_dir.is_dir()
            ), f"Reference directory does not exist: {reference_dir}."
            # Substitute the prefix literally: passing the directory to `re.sub` as the
            # replacement would interpret backslashes in it as escape sequences.
            path = str(reference_dir) + path[len(placeholder):]

        return Path(path)

    @classmethod
    def apply_dataset_mappers(cls, path: str) -> str:
        """Applies dataset mapping regex rules as defined in the configuration.
        If no rules are defined, the path is returned as-is.
        """
        instance = cls.instance()

        for pattern, repl in instance._dataset_mappers:
            path = pattern.sub(repl, path)

        return path
================================================
FILE: audiocraft/grids/__init__.py
================================================
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
"""Dora Grids."""
================================================
FILE: audiocraft/grids/_base_explorers.py
================================================
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
from abc import ABC, abstractmethod
import time
import typing as tp
from dora import Explorer
import treetable as tt
def get_sheep_ping(sheep) -> tp.Optional[str]:
    """Describe how long ago the Sheep last touched its log file.

    Returns None when the Sheep has no log or the log does not exist; otherwise
    a string with one decimal expressed in days, hours, minutes or seconds.
    """
    if sheep.log is None or not sheep.log.exists():
        return None

    elapsed = time.time() - sheep.log.stat().st_mtime
    # largest unit first; the first threshold strictly exceeded wins
    for seconds_per_unit, suffix in ((3600 * 24, 'd'), (3600, 'h'), (60, 'm')):
        if elapsed > seconds_per_unit:
            return f'{elapsed / seconds_per_unit:.1f}{suffix}'
    return f'{elapsed:.1f}s'
class BaseExplorer(ABC, Explorer):
    """Base explorer for AudioCraft grids.

    All task specific solvers are expected to implement the `get_grid_metrics`
    method to specify logic about metrics to display for a given task.

    If additional stages are used, the child explorer must define how to handle
    these new stages in the `process_history` and `process_sheep` methods.
    """
    def stages(self):
        """Return the names of the stages reported by this explorer."""
        return ["train", "valid", "evaluate"]

    def get_grid_meta(self):
        """Returns the list of Meta information to display for each XP/job.
        """
        return [
            tt.leaf("index", align=">"),
            tt.leaf("name", wrap=140),
            tt.leaf("state"),
            tt.leaf("sig", align=">"),
            tt.leaf("sid", align="<"),
        ]

    @abstractmethod
    def get_grid_metrics(self):
        """Return the metrics that should be displayed in the tracking table.
        """
        ...

    def process_sheep(self, sheep, history):
        """Collapse the history of a sheep into one dict of metrics per stage.

        Args:
            sheep: The sheep whose log is inspected to compute the `ping` entry.
            history: Sequence of per-epoch dicts mapping a stage name to its metrics.
                It is left unmodified.
        Returns:
            dict: Mapping from stage name to the merged metrics; later epochs
                override earlier ones. `train` always holds the `epoch` count.
        """
        train = {
            "epoch": len(history),
        }
        parts = {"train": train}
        for metrics in history:
            for key, sub in metrics.items():
                part = parts.get(key, {})
                if 'duration' in sub:
                    # Convert to minutes for readability. Work on a copy: rewriting the
                    # caller's history in place would divide again on every call.
                    sub = {**sub, 'duration': sub['duration'] / 60.}
                part.update(sub)
                parts[key] = part
        ping = get_sheep_ping(sheep)
        if ping is not None:
            for name in self.stages():
                if name not in parts:
                    parts[name] = {}
                # Add the ping to each part for convenience.
                parts[name]['ping'] = ping
        return parts
================================================
FILE: audiocraft/grids/audiogen/__init__.py
================================================
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
"""AudioGen grids."""
================================================
FILE: audiocraft/grids/audiogen/audiogen_base_16khz.py
================================================
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
from ..musicgen._explorers import LMExplorer
from ...environment import AudioCraftEnvironment
@LMExplorer
def explorer(launcher):
    """Schedule the medium AudioGen 16 kHz language model training on 64 GPUs."""
    slurm_partitions = AudioCraftEnvironment.get_slurm_partitions(['team', 'global'])
    launcher.slurm_(gpus=64, partition=slurm_partitions)
    launcher.bind_(solver='audiogen/audiogen_base_16khz')
    # replace this by the desired environmental sound dataset
    launcher.bind_(dset='internal/sounds_16khz')

    # FSDP without autocast is bound for every experiment of this grid
    launcher.bind_({'autocast': False, 'fsdp.use': True})

    # single experiment: medium-scale language model
    launcher({'model/lm/model_scale': 'medium'})
================================================
FILE: audiocraft/grids/audiogen/audiogen_pretrained_16khz_eval.py
================================================
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
"""
Evaluation with objective metrics for the pretrained AudioGen models.
This grid takes signature from the training grid and runs evaluation-only stage.
When running the grid for the first time, please use:
REGEN=1 dora grid audiogen.audiogen_pretrained_16khz_eval
and re-use the REGEN=1 option when the grid is changed to force regenerating it.
Note that you need the proper metrics external libraries setup to use all
the objective metrics activated in this grid. Refer to the README for more information.
"""
import os
from ..musicgen._explorers import GenerationEvalExplorer
from ...environment import AudioCraftEnvironment
from ... import train
def eval(launcher, batch_size: int = 32):
    """Schedule the objective evaluation experiment on top of `launcher`.

    Args:
        launcher: Launcher to derive the evaluation experiment from; it is
            bound into a new launcher rather than modified.
        batch_size (int): Batch size used for the evaluation dataset.
    """
    eval_launcher = launcher.bind({
        'dset': 'audio/audiocaps_16khz',
        'solver/audiogen/evaluation': 'objective_eval',
        'execute_only': 'evaluate',
        '+dataset.evaluate.batch_size': batch_size,
        '+metrics.fad.tf.batch_size': 32,
    })
    # binary for FAD computation: replace this path with your own path
    eval_launcher.bind_({
        'metrics.fad.tf.bin': '/data/home/jadecopet/local/usr/opt/google-research'
    })

    sampling_opts = {'generate.lm.use_sampling': True, 'generate.lm.top_k': 250, 'generate.lm.top_p': 0.}
    two_step_cfg_opts = {'transformer_lm.two_step_cfg': True}

    # base objective metrics
    eval_launcher(sampling_opts, two_step_cfg_opts)
@GenerationEvalExplorer
def explorer(launcher):
    """Schedule objective evaluation of the pretrained AudioGen medium model.

    Unless the REGEN environment variable is set, the experiments already
    linked in this grid's folder are re-scheduled instead of regenerating the grid.
    """
    slurm_partitions = AudioCraftEnvironment.get_slurm_partitions(['team', 'global'])
    launcher.slurm_(gpus=4, partition=slurm_partitions)

    if os.environ.get('REGEN') is None:
        # re-use the signatures symlinked in the grid folder
        grid_folder = train.main.dora.dir / 'grids' / __name__.split('.', 2)[-1]
        with launcher.job_array():
            for entry in grid_folder.iterdir():
                if entry.is_symlink():
                    xp = train.main.get_xp_from_sig(entry.name)
                    launcher(xp.argv)
        return

    base_launcher = launcher.bind(solver="audiogen/audiogen_base_16khz")
    base_launcher.bind_({'autocast': False, 'fsdp.use': True})

    medium_launcher = base_launcher.bind({'continue_from': '//pretrained/facebook/audiogen-medium'})
    medium_launcher.bind_({'model/lm/model_scale': 'medium'})
    eval(medium_launcher, batch_size=128)
================================================
FILE: audiocraft/grids/compression/__init__.py
================================================
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
"""EnCodec grids."""
================================================
FILE: audiocraft/grids/compression/_explorers.py
================================================
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import treetable as tt
from .._base_explorers import BaseExplorer
class CompressionExplorer(BaseExplorer):
    """Explorer for the compression grids, reporting train, valid and evaluate stages."""
    # Metrics shown in the `evaluate` group of the tracking table.
    eval_metrics = ["sisnr", "visqol"]

    def stages(self):
        """Return the names of the stages reported by this explorer."""
        return ["train", "valid", "evaluate"]

    def get_grid_meta(self):
        """Returns the list of Meta information to display for each XP/job.
        """
        index = tt.leaf("index", align=">")
        name = tt.leaf("name", wrap=140)
        state = tt.leaf("state")
        sig = tt.leaf("sig", align=">")
        return [index, name, state, sig]

    def get_grid_metrics(self):
        """Return the metrics that should be displayed in the tracking table.
        """
        train_group = tt.group(
            "train",
            [
                tt.leaf("epoch"),
                tt.leaf("bandwidth", ".2f"),
                tt.leaf("adv", ".4f"),
                tt.leaf("d_loss", ".4f"),
            ],
            align=">",
        )
        valid_group = tt.group(
            "valid",
            [
                tt.leaf("bandwidth", ".2f"),
                tt.leaf("adv", ".4f"),
                tt.leaf("msspec", ".4f"),
                tt.leaf("sisnr", ".2f"),
            ],
            align=">",
        )
        evaluate_leaves = [tt.leaf(metric, ".3f") for metric in self.eval_metrics]
        evaluate_group = tt.group("evaluate", evaluate_leaves, align=">")
        return [train_group, valid_group, evaluate_group]
================================================
FILE: audiocraft/grids/compression/debug.py
================================================
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
"""
Grid search file, simply list all the exp you want in `explorer`.
Any new exp added there will be scheduled.
You can cancel an experiment by commenting out its line.
This grid is a minimal example for debugging compression task
and how to override parameters directly in a grid.
Learn more about dora grids: https://github.com/facebookresearch/dora
"""
from ._explorers import CompressionExplorer
from ...environment import AudioCraftEnvironment
@CompressionExplorer
def explorer(launcher):
    """Schedule the debug compression experiment and a variant with overridden RVQ settings."""
    slurm_partitions = AudioCraftEnvironment.get_slurm_partitions(['team', 'global'])
    launcher.slurm_(gpus=2, partition=slurm_partitions)
    launcher.bind_(solver='compression/debug')

    rvq_overrides = {'rvq.bins': 2048, 'rvq.n_q': 4}
    with launcher.job_array():
        # base debug task using config from solver=compression/debug
        launcher()
        # we can override parameters in the grid to launch additional xps
        launcher(rvq_overrides)
================================================
FILE: audiocraft/grids/compression/encodec_audiogen_16khz.py
================================================
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
"""
Grid search file, simply list all the exp you want in `explorer`.
Any new exp added there will be scheduled.
You can cancel an experiment by commenting out its line.
This grid shows how to train the new AudioGen EnCodec model at 16 kHz.
"""
from ._explorers import CompressionExplorer
from ...environment import AudioCraftEnvironment
@CompressionExplorer
def explorer(launcher):
    """Schedule the training of AudioGen's EnCodec model at 16 kHz on 8 GPUs."""
    slurm_partitions = AudioCraftEnvironment.get_slurm_partitions(['team', 'global'])
    launcher.slurm_(gpus=8, partition=slurm_partitions)
    # use configuration for AudioGen's EnCodec model trained on monophonic audio sampled at 16 kHz
    # AudioGen's EnCodec is trained with a total stride of 320 leading to a frame rate of 50 hz
    launcher.bind_(solver='compression/encodec_audiogen_16khz')
    # replace this by the desired sound dataset
    launcher.bind_(dset='internal/sounds_16khz')

    # launch xp
    launcher()
================================================
FILE: audiocraft/grids/compression/encodec_base_24khz.py
================================================
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
"""
Grid search file, simply list all the exp you want in `explorer`.
Any new exp added there will be scheduled.
You can cancel an experiment by commenting out its line.
This grid shows how to train a base causal EnCodec model at 24 kHz.
"""
from ._explorers import CompressionExplorer
from ...environment import AudioCraftEnvironment
@CompressionExplorer
def explorer(launcher):
    """Schedule the training of the base causal EnCodec model at 24 kHz on 8 GPUs."""
    slurm_partitions = AudioCraftEnvironment.get_slurm_partitions(['team', 'global'])
    launcher.slurm_(gpus=8, partition=slurm_partitions)
    # base causal EnCodec trained on monophonic audio sampled at 24 kHz
    launcher.bind_(solver='compression/encodec_base_24khz')
    # replace this by the desired dataset
    launcher.bind_(dset='audio/example')

    # launch xp
    launcher()
================================================
FILE: audiocraft/grids/compression/encodec_musicgen_32khz.py
================================================
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
"""
Grid search file, simply list all the exp you want in `explorer`.
Any new exp added there will be scheduled.
You can cancel an experiment by commenting out its line.
This grid shows how to train a MusicGen EnCodec model at 32 kHz.
"""
from ._explorers import CompressionExplorer
from ...environment import AudioCraftEnvironment
@CompressionExplorer
def explorer(launcher):
    """Schedule the MusicGen EnCodec 32 kHz training, with and without ViSQOL evaluation."""
    slurm_partitions = AudioCraftEnvironment.get_slurm_partitions(['team', 'global'])
    launcher.slurm_(gpus=8, partition=slurm_partitions)
    # use configuration for MusicGen's EnCodec model trained on monophonic audio sampled at 32 kHz
    # MusicGen's EnCodec is trained with a total stride of 640 leading to a frame rate of 50 hz
    launcher.bind_(solver='compression/encodec_musicgen_32khz')
    # replace this by the desired music dataset
    launcher.bind_(dset='internal/music_400k_32khz')

    # launch xp
    launcher()

    # same xp with the ViSQOL metric enabled at evaluation
    visqol_overrides = {
        'metrics.visqol.bin': '/data/home/jadecopet/local/usr/opt/visqol',
        'label': 'visqol',
        'evaluate.metrics.visqol': True
    }
    launcher(visqol_overrides)
================================================
FILE: audiocraft/grids/diffusion/4_bands_base_32khz.py
================================================
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
"""
Training of the 4 diffusion models described in
"From Discrete Tokens to High-Fidelity Audio Using Multi-Band Diffusion"
(paper link).
"""
from ._explorers import DiffusionExplorer
@DiffusionExplorer
def explorer(launcher):
    """Schedule one diffusion model training per frequency band (4 bands) as a job array."""
    launcher.slurm_(gpus=4, partition='learnfair')

    launcher.bind_({'solver': 'diffusion/default',
                    'dset': 'internal/music_10k_32khz'})

    # (band index, whether the processor is used, processor power std)
    band_settings = (
        (0, False, 0.4),
        (1, False, 0.4),
        (2, True, 0.4),
        (3, True, 0.75),
    )
    with launcher.job_array():
        for idx_band, use_processor, power_std in band_settings:
            launcher({'filter.use': True, 'filter.idx_band': idx_band,
                      "processor.use": use_processor, 'processor.power_std': power_std})
================================================
FILE: audiocraft/grids/diffusion/__init__.py
================================================
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
"""Diffusion grids."""
================================================
FILE: audiocraft/grids/diffusion/_explorers.py
================================================
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import treetable as tt
from .._base_explorers import BaseExplorer
class DiffusionExplorer(BaseExplorer):
    """Explorer for the diffusion grids, reporting regular and EMA valid/evaluate stages."""
    eval_metrics = ["sisnr", "visqol"]

    def stages(self):
        """Return the names of the stages reported by this explorer."""
        return ["train", "valid", "valid_ema", "evaluate", "evaluate_ema"]

    def get_grid_meta(self):
        """Returns the list of Meta information to display for each XP/job.
        """
        index = tt.leaf("index", align=">")
        name = tt.leaf("name", wrap=140)
        state = tt.leaf("state")
        sig = tt.leaf("sig", align=">")
        return [index, name, state, sig]

    def get_grid_metrics(self):
        """Return the metrics that should be displayed in the tracking table.
        """
        def rvm_leaves():
            # a fresh list of leaves for each group: the overall rvm and one per band
            return [tt.leaf(name, ".4f") for name in ("rvm", "rvm_0", "rvm_1", "rvm_2", "rvm_3")]

        train_group = tt.group(
            "train",
            [
                tt.leaf("epoch"),
                tt.leaf("loss", ".3%"),
            ],
            align=">",
        )
        valid_group = tt.group("valid", [tt.leaf("loss", ".3%")], align=">")
        valid_ema_group = tt.group("valid_ema", [tt.leaf("loss", ".3%")], align=">")
        evaluate_group = tt.group("evaluate", rvm_leaves(), align=">")
        evaluate_ema_group = tt.group("evaluate_ema", rvm_leaves(), align=">")
        return [train_group, valid_group, valid_ema_group, evaluate_group, evaluate_ema_group]
================================================
FILE: audiocraft/grids/musicgen/__init__.py
================================================
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
"""MusicGen grids."""
================================================
FILE: audiocraft/grids/musicgen/_explorers.py
================================================
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import typing as tp
import treetable as tt
from .._base_explorers import BaseExplorer
class LMExplorer(BaseExplorer):
    """Explorer for language model training grids, reporting train and valid stages."""
    eval_metrics: tp.List[str] = []

    def stages(self) -> tp.List[str]:
        """Return the names of the stages reported by this explorer."""
        return ['train', 'valid']

    def get_grid_metrics(self):
        """Return the metrics that should be displayed in the tracking table."""
        train_leaves = [
            tt.leaf('epoch'),
            tt.leaf('duration', '.1f'),  # duration in minutes
            tt.leaf('ping'),
            tt.leaf('ce', '.4f'),  # cross entropy
            tt.leaf("ppl", '.3f'),  # perplexity
        ]
        valid_leaves = [
            tt.leaf('ce', '.4f'),
            tt.leaf('ppl', '.3f'),
            tt.leaf('best_ppl', '.3f'),
        ]
        return [
            tt.group('train', train_leaves, align='>'),
            tt.group('valid', valid_leaves, align='>'),
        ]

    def process_sheep(self, sheep, history):
        """Extend the base summary with the best validation value of each tracked metric.

        The best values are stored under `best_<metric>` in the `valid` part,
        which makes runs easy to compare in the grid.
        """
        parts = super().process_sheep(sheep, history)

        # metric name -> direction of improvement, one of 'lower' or 'higher'
        track_by = {'ppl': 'lower'}
        best_metrics = {}
        for name, mode in track_by.items():
            # start from the worst possible value for the given direction
            best_metrics[name] = float('inf') if mode == 'lower' else -float('inf')

        def improves(mode, candidate, incumbent):
            if mode == 'lower':
                return candidate < incumbent
            return candidate > incumbent

        for metrics in history:
            for stage, values in metrics.items():
                # only the validation set is tracked (ppl in this example)
                if stage != 'valid':
                    continue
                for name, mode in track_by.items():
                    if name in values and improves(mode, values[name], best_metrics[name]):
                        best_metrics[name] = values[name]

        if 'valid' in parts:
            for name, value in best_metrics.items():
                parts['valid'][f'best_{name}'] = value
        return parts
class GenerationEvalExplorer(BaseExplorer):
    """Explorer for evaluation-only generation grids, reporting the evaluate stage."""
    eval_metrics: tp.List[str] = []

    def stages(self) -> tp.List[str]:
        """Return the names of the stages reported by this explorer."""
        return ['evaluate']

    def get_grid_metrics(self):
        """Return the metrics that should be displayed in the tracking table."""
        # positional arguments of each leaf: name, then the optional format
        leaf_specs = (
            ('epoch', '.3f'),
            ('duration', '.1f'),
            ('ping',),
            ('ce', '.4f'),
            ('ppl', '.3f'),
            ('fad', '.3f'),
            ('kld', '.3f'),
            ('text_consistency', '.3f'),
            ('chroma_cosine', '.3f'),
        )
        evaluate_leaves = [tt.leaf(*spec) for spec in leaf_specs]
        return [tt.group('evaluate', evaluate_leaves, align='>')]
================================================
FILE: audiocraft/grids/musicgen/musicgen_base_32khz.py
================================================
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
from ._explorers import LMExplorer
from ...environment import AudioCraftEnvironment
@LMExplorer
def explorer(launcher):
    """Schedule MusicGen 32 kHz trainings: default scale on 32 GPUs, medium on 64, large on 96."""
    slurm_partitions = AudioCraftEnvironment.get_slurm_partitions(['team', 'global'])
    launcher.slurm_(gpus=32, partition=slurm_partitions)
    launcher.bind_(solver='musicgen/musicgen_base_32khz')
    # replace this by the desired music dataset
    launcher.bind_(dset='internal/music_400k_32khz')

    medium = {'model/lm/model_scale': 'medium'}
    large = {'model/lm/model_scale': 'large'}
    cfg_low = {'classifier_free_guidance.training_dropout': 0.2}
    wd_low = {'conditioners.description.t5.word_dropout': 0.2}
    adam = {'optim.optimizer': 'adamw', 'optim.lr': 1e-4}

    # FSDP without autocast is bound for every experiment of this grid
    launcher.bind_({'autocast': False, 'fsdp.use': True})

    # (number of gpus, label, overrides passed to the experiment)
    runs = (
        (32, '32gpus', ()),
        (64, '64gpus', (medium, adam)),
        (96, '96gpus', (large, cfg_low, wd_low, adam, {'optim.max_norm': 3})),
    )
    for gpus, label, overrides in runs:
        launcher.slurm_(gpus=gpus).bind_(label=label)
        with launcher.job_array():
            sub = launcher.bind()
            sub(*overrides)
================================================
FILE: audiocraft/grids/musicgen/musicgen_base_cached_32khz.py
================================================
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
from ._explorers import LMExplorer
from ...environment import AudioCraftEnvironment
@LMExplorer
def explorer(launcher):
    """Schedule the cache-writing jobs, then (once enabled) the trainings reading from that cache.

    The function currently returns right after scheduling the cache-writing jobs;
    the training experiments below the `return` only run once it is removed.
    """
    partitions = AudioCraftEnvironment.get_slurm_partitions(['team', 'global'])
    launcher.slurm_(gpus=32, partition=partitions)
    launcher.bind_(solver='musicgen/musicgen_base_32khz')
    # replace this by the desired music dataset
    launcher.bind_(dset='internal/music_400k_32khz')

    fsdp = {'autocast': False, 'fsdp.use': True}
    medium = {'model/lm/model_scale': 'medium'}
    large = {'model/lm/model_scale': 'large'}

    cfg_low = {'classifier_free_guidance.training_dropout': 0.2}
    wd_low = {'conditioners.description.t5.word_dropout': 0.2}

    adam = {'optim.optimizer': 'adamw', 'optim.lr': 1e-4}

    # BEGINNING OF CACHE WRITING JOBS.
    cache_write = {
        'cache.path': '/fsx-codegen/defossez/cache/interleave_stereo_nv_32k',
        'cache.write': True,
        'generate.every': 500,
        'evaluate.every': 500,
        'logging.log_updates': 50,
    }

    cache_sub = launcher.bind({'model/lm/model_scale': 'xsmall', 'conditioner': 'none'})
    cache_sub.bind_({'deadlock.use': True})
    cache_sub.slurm_(gpus=8)
    with launcher.job_array():
        num_shards = 10  # total number of jobs running in parallel.
        for shard in range(0, num_shards):
            # Launch from `cache_sub`: it was configured above (xsmall model, no conditioner,
            # 8 GPUs) but never used, the jobs being scheduled from the unconfigured `launcher`.
            cache_sub(cache_write, {'cache.write_num_shards': num_shards, 'cache.write_shard': shard})

    # REMOVE THE FOLLOWING RETURN STATEMENT ONCE THE ABOVE JOBS ARE DONE,
    # OR SUFFICIENTLY AHEAD.
    return

    cache = {
        'cache.path': '/fsx-codegen/defossez/cache/interleave_stereo_nv_32k',
    }
    launcher.bind_(fsdp, cache)

    launcher.slurm_(gpus=32).bind_(label='32gpus')
    with launcher.job_array():
        sub = launcher.bind()
        sub()

    launcher.slurm_(gpus=64).bind_(label='64gpus')
    with launcher.job_array():
        sub = launcher.bind()
        sub(medium, adam)

    launcher.slurm_(gpus=96).bind_(label='96gpus')
    with launcher.job_array():
        sub = launcher.bind()
        sub(large, cfg_low, wd_low, adam, {'optim.max_norm': 3})
================================================
FILE: audiocraft/grids/musicgen/musicgen_clapemb_32khz.py
================================================
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
from ._explorers import LMExplorer
from ...environment import AudioCraftEnvironment
@LMExplorer
def explorer(launcher):
    """Schedule the MusicGen 32 kHz grid using the `clapemb2music` conditioner.

    The launcher is bound to the `musicgen/musicgen_base_32khz` solver with FSDP
    enabled, then four jobs are submitted on 32 GPUs: the plain configuration,
    the `text_p=0.5` variant, the variant reading CLAP embeddings from a cache
    path, and the combination of both.

    Args:
        launcher: Grid launcher handed over by the `LMExplorer` decorator.
    """
    slurm_partitions = AudioCraftEnvironment.get_slurm_partitions(['team', 'global'])
    launcher.slurm_(gpus=32, partition=slurm_partitions)
    launcher.bind_(solver='musicgen/musicgen_base_32khz')
    # replace this by the desired music dataset
    launcher.bind_(dset='internal/music_400k_32khz')
    launcher.bind_(conditioner='clapemb2music')

    fsdp_opts = {'autocast': False, 'fsdp.use': True}
    clap_cache_opts = {
        'conditioners.description.clap.cache_path':
            '/fsx-audio-craft-llm/jadecopet/experiments/audiocraft/caches/clap_embed_music',
    }
    text_wav_opts = {'conditioners.description.clap.text_p': 0.5}

    launcher.bind_(fsdp_opts)
    launcher.slurm_(gpus=32).bind_(label='32gpus')

    # One entry per job, in submission order: the overrides passed to the launcher.
    job_overrides = [
        (),
        (text_wav_opts,),
        (clap_cache_opts,),
        (clap_cache_opts, text_wav_opts),
    ]
    with launcher.job_array():
        for overrides in job_overrides:
            launcher(*overrides)
================================================
FILE: audiocraft/grids/musicgen/musicgen_melody_32khz.py
================================================
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
from ._explorers import LMExplorer
from ...environment import AudioCraftEnvironment
@LMExplorer
def explorer(launcher):
    """Schedule the MusicGen melody 32 kHz grid.

    Two groups of jobs are submitted:

    1. Single-GPU jobs with an `xsmall` model, each with a different shuffle seed,
       bound to the chroma stem cache path (the cache generation jobs).
    2. The actual FSDP training jobs on 32, 64 and 96 GPUs for the default,
       medium and large model scales respectively.

    Args:
        launcher: Grid launcher handed over by the `LMExplorer` decorator.
    """
    slurm_partitions = AudioCraftEnvironment.get_slurm_partitions(['team', 'global'])
    launcher.slurm_(gpus=32, partition=slurm_partitions)
    launcher.bind_(solver='musicgen/musicgen_melody_32khz')
    # replace this by the desired music dataset
    launcher.bind_(dset='internal/music_400k_32khz')

    fsdp_opts = {'autocast': False, 'fsdp.use': True}
    medium_opts = {'model/lm/model_scale': 'medium'}
    large_opts = {'model/lm/model_scale': 'large'}

    low_cfg_dropout = {'classifier_free_guidance.training_dropout': 0.2}
    low_word_dropout = {'conditioners.description.t5.word_dropout': 0.2}

    adamw_opts = {'optim.optimizer': 'adamw', 'optim.lr': 1e-4}

    chroma_cache_opts = {
        'conditioners.self_wav.chroma_stem.cache_path':
            '/fsx-audio-craft-llm/jadecopet/experiments/audiocraft/caches/chroma_stem',
    }

    # --- Cache generation jobs ---
    num_cache_jobs = 4
    cache_launcher = launcher.slurm(gpus=1)
    cache_job_opts = {
        # the cache is always computed over the whole file, so duration doesn't matter here.
        'dataset.segment_duration': 2.,
        'dataset.batch_size': 8,
        'dataset.train.permutation_on_files': True,  # try to not repeat files.
        'optim.epochs': 10,
        'model/lm/model_scale': 'xsmall',
    }
    cache_launcher.bind_(chroma_cache_opts, cache_job_opts)
    with cache_launcher.job_array():
        for seed in range(num_cache_jobs):
            cache_launcher({'dataset.train.shuffle_seed': seed})

    # --- Actual training jobs ---
    launcher.bind_(fsdp_opts)

    launcher.slurm_(gpus=32).bind_(label='32gpus')
    with launcher.job_array():
        small_runs = launcher.bind()
        small_runs()
        small_runs(chroma_cache_opts)

    launcher.slurm_(gpus=64).bind_(label='64gpus')
    with launcher.job_array():
        medium_runs = launcher.bind()
        medium_runs(medium_opts, adamw_opts)

    launcher.slurm_(gpus=96).bind_(label='96gpus')
    with launcher.job_array():
        large_runs = launcher.bind()
        large_runs(large_opts, low_cfg_dropout, low_word_dropout, adamw_opts, {'optim.max_norm': 3})
================================================
FILE: audiocraft/grids/musicgen/musicgen_pretrained_32khz_eval.py
================================================
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
"""
Evaluation with objective metrics for the pretrained MusicGen models.
This grid takes signature from the training grid and runs evaluation-only stage.
When running the grid for the first time, please use:
REGEN=1 dora grid musicgen.musicgen_pretrained_32khz_eval
and re-use the REGEN=1 option when the grid is changed to force regenerating it.
Note that you need the proper metrics external libraries setup to use all
the objective metrics activated in this grid. Refer to the README for more information.
"""
import os
from ._explorers import GenerationEvalExplorer
from ...environment import AudioCraftEnvironment
from ... import train
def eval(launcher, batch_size: int = 32, eval_melody: bool = False):
    """Submit evaluation-only runs for one pretrained model.

    A sub-launcher is derived from `launcher` with the evaluation options and the
    FAD binary location bound. It is then called once for the base objective
    metrics and, when `eval_melody` is set, a second time with the
    chroma-specific overrides.

    Args:
        launcher: Launcher already bound to the model to evaluate.
        batch_size (int): Value used for `+dataset.evaluate.batch_size`.
        eval_melody (bool): Whether to also run the chroma-specific evaluation.
    """
    eval_opts = {
        'dset': 'audio/musiccaps_32khz',
        'solver/musicgen/evaluation': 'objective_eval',
        'execute_only': 'evaluate',
        '+dataset.evaluate.batch_size': batch_size,
        '+metrics.fad.tf.batch_size': 16,
    }
    # chroma-specific evaluation
    chroma_eval_opts = {
        'dset': 'internal/music_400k_32khz',
        'dataset.evaluate.segment_duration': 30,
        'dataset.evaluate.num_samples': 1000,
        'evaluate.metrics.chroma_cosine': True,
        'evaluate.metrics.fad': False,
        'evaluate.metrics.kld': False,
        'evaluate.metrics.text_consistency': False,
    }
    # binary for FAD computation: replace this path with your own path
    fad_opts = {
        'metrics.fad.tf.bin': '/data/home/jadecopet/local/usr/opt/google-research'
    }
    sampling_opts = {'generate.lm.use_sampling': True, 'generate.lm.top_k': 250, 'generate.lm.top_p': 0.}
    two_step_cfg_opts = {'transformer_lm.two_step_cfg': True}

    eval_launcher = launcher.bind(eval_opts)
    eval_launcher.bind_(fad_opts)

    # The base objective metrics always run; the chroma metrics are optional.
    runs = [(sampling_opts, two_step_cfg_opts)]
    if eval_melody:
        runs.append((sampling_opts, two_step_cfg_opts, chroma_eval_opts))
    for overrides in runs:
        eval_launcher(*overrides)
@GenerationEvalExplorer
def explorer(launcher):
    """Schedule the objective evaluation of the pretrained MusicGen models.

    Unless the `REGEN` environment variable is set, the experiments already
    present (as symlinks named after their signature) in this grid's folder are
    re-launched from their stored arguments. Otherwise the grid is built from
    the pretrained small / medium / large base checkpoints and the melody one.

    Args:
        launcher: Grid launcher handed over by the `GenerationEvalExplorer` decorator.
    """
    slurm_partitions = AudioCraftEnvironment.get_slurm_partitions(['team', 'global'])
    launcher.slurm_(gpus=4, partition=slurm_partitions)

    if 'REGEN' not in os.environ:
        # Re-use the signatures of the experiments already attached to this grid.
        grid_folder = train.main.dora.dir / 'grids' / __name__.split('.', 2)[-1]
        with launcher.job_array():
            for entry in grid_folder.iterdir():
                if entry.is_symlink():
                    xp = train.main.get_xp_from_sig(entry.name)
                    launcher(xp.argv)
        return

    fsdp_opts = {'autocast': False, 'fsdp.use': True}
    # (pretrained checkpoint, model scale override or None to keep the default scale)
    base_models = [
        ('//pretrained/facebook/musicgen-small', None),
        ('//pretrained/facebook/musicgen-medium', 'medium'),
        ('//pretrained/facebook/musicgen-large', 'large'),
    ]

    with launcher.job_array():
        # base musicgen models
        base_launcher = launcher.bind(solver="musicgen/musicgen_base_32khz")
        base_launcher.bind_(fsdp_opts)
        for checkpoint, scale in base_models:
            model_launcher = base_launcher.bind({'continue_from': checkpoint})
            if scale is not None:
                model_launcher.bind_({'model/lm/model_scale': scale})
            eval(model_launcher, batch_size=128)

        # melody musicgen model
        melody_launcher = launcher.bind(solver="musicgen/musicgen_melody_32khz")
        melody_launcher.bind_({'autocast': False, 'fsdp.use': True})
        melody_medium = melody_launcher.bind({'continue_from': '//pretrained/facebook/musicgen-melody'})
        melody_medium.bind_({'model/lm/model_scale': 'medium'})
        eval(melody_medium, batch_size=128, eval_melody=True)
================================================
FILE: audiocraft/losses/__init__.py
================================================
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
"""Loss related classes and functions. In particular the loss balancer from
EnCodec, and the usual spectral losses."""
# flake8: noqa
from .balancer import Balancer
from .sisnr import SISNR
from .stftloss import (
LogSTFTMagnitudeLoss,
MRSTFTLoss,
SpectralConvergenceLoss,
STFTLoss
)
from .specloss import (
MelSpectrogramL1Loss,
MultiScaleMelSpectrogramLoss,
)
================================================
FILE: audiocraft/losses/balancer.py
================================================
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import typing as tp
import flashy
import torch
from torch import autograd
class Balancer:
    """Loss balancer.

    The balancer merges several losses into a single gradient for the backward pass.
    Given `y = f(...)` and losses `l1(y, ...)`, `l2(y, ...)` whose other arguments do not
    depend on `f`, it normalizes the partial gradients `d l1 / d y`, `d l2 / d y` before
    summing them, so that the weights describe the share of the gradient coming from each
    loss. For instance with `weights = {'l1': 2, 'l2': 1}`, on average 66% of the gradient
    flowing into `f(...)` comes from `l1` and 33% from `l2`. The weights therefore stay easy
    to interpret even when the intrinsic scale of each loss is unknown.

    Writing `g1 = d l1 / d y`, etc., the balanced gradient `G` is
    (with `avg` an exponential moving average over the updates),

        G = sum_i total_norm * g_i / avg(||g_i||) * w_i / sum(w_i)

    When `balance_grads` is False, no rescaling happens and the gradient is the plain
    weighted sum of the partial gradients.

    Calling `backward` computes the partial gradients, combines them as described above
    and back-propagates the resulting gradient with respect to `y` into `f(...)`.

    Expected usage:

        weights = {'loss_a': 1, 'loss_b': 4}
        balancer = Balancer(weights, ...)
        y = model(x)
        losses: dict = {}
        losses['loss_a'] = compute_loss_a(y, target)
        losses['loss_b'] = compute_loss_b(y, target)
        if model.training:
            effective_loss = balancer.backward(losses, y)

    Args:
        weights (dict[str, float]): Weight coefficient for each loss. The keys of the losses given
            to `backward` are looked up in this dict.
        balance_grads (bool): Whether to rescale gradients so that weights reflect the fraction of the
            overall gradient, rather than a constant multiplier.
        total_norm (float): Reference norm when rescaling gradients, ignored otherwise.
        ema_decay (float): EMA decay for averaging the norms.
        per_batch_item (bool): Whether to compute the averaged norm per batch item or not. This only holds
            when rescaling the gradients.
        epsilon (float): Epsilon value for numerical stability.
        monitor (bool): If True, stores in `self.metrics` the relative ratio between the norm of the gradients
            coming from each loss, when calling `backward()`.
    """

    def __init__(self, weights: tp.Dict[str, float], balance_grads: bool = True, total_norm: float = 1.,
                 ema_decay: float = 0.999, per_batch_item: bool = True, epsilon: float = 1e-12,
                 monitor: bool = False):
        self.weights = weights
        self.balance_grads = balance_grads
        # A falsy `total_norm` / `ema_decay` (0 or None) falls back to 1.
        self.total_norm = total_norm or 1.
        self.averager = flashy.averager(ema_decay or 1.)
        self.per_batch_item = per_batch_item
        self.epsilon = epsilon
        self.monitor = monitor
        self._metrics: tp.Dict[str, tp.Any] = {}

    @property
    def metrics(self):
        """Metrics stored by the last call to `backward` (empty unless `monitor` is set)."""
        return self._metrics

    def backward(self, losses: tp.Dict[str, torch.Tensor], input: torch.Tensor) -> torch.Tensor:
        """Back-propagate the combined gradient and return the effective train loss, i.e. the
        sum of the (detached) losses scaled by the effective weights. With `balance_grads`,
        the effective weights are the scales applied to each partial gradient so that the
        desired relative contribution of every loss is respected.

        Args:
            losses (Dict[str, torch.Tensor]): dictionary with the same keys as `self.weights`.
            input (torch.Tensor): the input of the losses, typically the output of the model.
                This should be the single point of dependence between the losses
                and the model being trained.
        Returns:
            torch.Tensor: Scalar effective loss, detached from the graph.
        """
        partial_grads = {}
        grad_norms = {}
        for name, loss in losses.items():
            # Partial derivative of this loss with respect to the input.
            grad, = autograd.grad(loss, [input], retain_graph=True)
            if self.per_batch_item:
                # One norm per batch item (no reduction over dim 0), then the mean over items.
                item_dims = tuple(range(1, grad.dim()))
                grad_norms[name] = grad.norm(dim=item_dims, p=2).mean()
            else:
                grad_norms[name] = grad.norm(p=2)
            partial_grads[name] = grad

        # Weight given to this worker when averaging: its batch size if norms are per item.
        count = len(grad) if self.per_batch_item else 1
        # Average norms across workers. Theoretically we should average the
        # squared norm, then take the sqrt, but it worked fine like that.
        avg_norms = flashy.distrib.average_metrics(self.averager(grad_norms), count)
        # We approximate the total norm of the gradient as the sums of the norms.
        # Obviously this can be very incorrect if all gradients are aligned, but it works fine.
        norm_sum = sum(avg_norms.values())

        self._metrics = {}
        if self.monitor:
            # Share of the total gradient norm represented by each loss.
            self._metrics.update({f'ratio_{key}': value / norm_sum for key, value in avg_norms.items()})

        weight_sum = sum(self.weights[key] for key in avg_norms)
        assert weight_sum > 0.
        target_ratios = {key: weight / weight_sum for key, weight in self.weights.items()}

        combined_grad = torch.zeros_like(input)
        effective_loss = torch.tensor(0., device=input.device, dtype=input.dtype)
        for name, avg_norm in avg_norms.items():
            if not self.balance_grads:
                # Regular weighted sum of the gradients.
                scale = self.weights[name]
            else:
                # g_balanced = g / avg(||g||) * total_norm * desired_ratio
                scale = target_ratios[name] * self.total_norm / (self.epsilon + avg_norm)
            combined_grad.add_(partial_grads[name], alpha=scale)
            effective_loss += scale * losses[name].detach()

        # Push the combined gradient w.r.t. the model output back through the model.
        input.backward(combined_grad)
        return effective_loss
================================================
FILE: audiocraft/losses/sisnr.py
================================================
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import math
import typing as tp
import torch
from torch import nn
from torch.nn import functional as F
def _unfold(a: torch.Tensor, kernel_size: int, stride: int) -> torch.Tensor:
    """Given input of size [*OT, T], output Tensor of size [*OT, F, K]
    with K the kernel size, by extracting frames with the given stride.

    The input is zero-padded on the right so that `F = ceil(T / stride)` and
    the last frame is complete.

    see https://github.com/pytorch/pytorch/issues/60466
    """
    *shape, length = a.shape
    n_frames = math.ceil(length / stride)
    tgt_length = (n_frames - 1) * stride + kernel_size
    a = F.pad(a, (0, tgt_length - length))
    strides = list(a.stride())
    assert strides[-1] == 1, "data should be contiguous"
    # Frames are strided views over the padded signal: moving one frame ahead
    # moves `stride` samples ahead in the last dimension.
    strides = strides[:-1] + [stride, 1]
    return a.as_strided([*shape, n_frames, kernel_size], strides)


def _center(x: torch.Tensor) -> torch.Tensor:
    """Remove the mean over the last dimension."""
    return x - x.mean(-1, True)


def _norm2(x: torch.Tensor) -> torch.Tensor:
    """Squared L2 norm over the last dimension, which is kept with size 1."""
    return x.pow(2).sum(-1, True)


class SISNR(nn.Module):
    """SISNR loss.

    Input should be [B, C, T], output is scalar.

    Args:
        sample_rate (int): Sample rate.
        segment (float or None): Evaluate on chunks of that many seconds. If None, evaluate on
            entire audio only.
        overlap (float): Overlap between chunks, i.e. 0.5 = 50 % overlap.
        epsilon (float): Epsilon value for numerical stability.
    """
    def __init__(
        self,
        sample_rate: int = 16000,
        segment: tp.Optional[float] = 20,
        overlap: float = 0.5,
        epsilon: float = torch.finfo(torch.float32).eps,
    ):
        super().__init__()
        self.sample_rate = sample_rate
        self.segment = segment
        self.overlap = overlap
        self.epsilon = epsilon

    def forward(self, out_sig: torch.Tensor, ref_sig: torch.Tensor) -> torch.Tensor:
        """Compute the negated SI-SNR (in dB) averaged over batch, channels and frames.

        Args:
            out_sig (torch.Tensor): Estimated signal of shape [B, C, T].
            ref_sig (torch.Tensor): Reference signal, same shape as `out_sig`.
        Returns:
            torch.Tensor: Scalar loss, lower is better.
        """
        B, C, T = ref_sig.shape
        assert ref_sig.shape == out_sig.shape

        if self.segment is None:
            # A single frame covering the whole signal.
            frame = T
            stride = T
        else:
            frame = int(self.segment * self.sample_rate)
            stride = int(frame * (1 - self.overlap))

        epsilon = self.epsilon * frame  # make epsilon prop to frame size.

        gt = _unfold(ref_sig, frame, stride)
        est = _unfold(out_sig, frame, stride)
        if self.segment is None:
            # `_unfold` returns [B, C, n_frames, frame]: the frame count lives on
            # the second to last axis, and must be 1 when evaluating the full signal.
            assert gt.shape[-2] == 1

        gt = _center(gt)
        est = _center(est)
        dot = torch.einsum("bcft,bcft->bcf", gt, est)

        # Projection of the estimate on the reference; the remainder is counted as noise.
        proj = dot[:, :, :, None] * gt / (epsilon + _norm2(gt))
        noise = est - proj

        sisnr = 10 * (
            torch.log10(epsilon + _norm2(proj)) - torch.log10(epsilon + _norm2(noise))
        )
        return -1 * sisnr[..., 0].mean()
================================================
FILE: audiocraft/losses/specloss.py
================================================
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import typing as tp
import numpy as np
from torchaudio.transforms import MelSpectrogram
import torch
from torch import nn
from torch.nn import functional as F
from ..modules import pad_for_conv1d
class MelSpectrogramWrapper(nn.Module):
    """Thin layer over the torchaudio `MelSpectrogram` transform that pads the input
    before the transform and optionally applies a log10 scaling afterwards.

    Args:
        n_fft (int): Number of fft.
        hop_length (int): Hop size (converted with `int()`).
        win_length (int): Window length.
        n_mels (int): Number of mel bins.
        sample_rate (int): Sample rate.
        f_min (float or None): Minimum frequency.
        f_max (float or None): Maximum frequency.
        log (bool): Whether to scale with log.
        normalized (bool): Whether to normalize the melspectrogram.
        floor_level (float): Constant added to the mel spectrogram before taking the log (default=1e-5).
    """
    def __init__(self, n_fft: int = 1024, hop_length: int = 256, win_length: tp.Optional[int] = None,
                 n_mels: int = 80, sample_rate: float = 22050, f_min: float = 0.0, f_max: tp.Optional[float] = None,
                 log: bool = True, normalized: bool = False, floor_level: float = 1e-5):
        super().__init__()
        self.n_fft = n_fft
        self.hop_length = int(hop_length)
        self.log = log
        self.floor_level = floor_level
        # `center=False`: the padding is handled explicitly in `forward`.
        self.mel_transform = MelSpectrogram(
            n_mels=n_mels, sample_rate=sample_rate, n_fft=n_fft, hop_length=self.hop_length,
            win_length=win_length, f_min=f_min, f_max=f_max, normalized=normalized,
            window_fn=torch.hann_window, center=False)

    def forward(self, x):
        """Return the (log-)mel spectrogram of `x` reshaped to `[B, C * n_mels, frames]`.

        A 2-dim input gets a dimension inserted at position 1 before processing.
        """
        if x.dim() == 2:
            x = x.unsqueeze(1)
        side_padding = int((self.n_fft - self.hop_length) // 2)
        x = F.pad(x, (side_padding, side_padding), "reflect")
        # Make sure that all the frames are full.
        # The combination of `pad_for_conv1d` and the above padding
        # will make the output of size ceil(T / hop).
        x = pad_for_conv1d(x, self.n_fft, self.hop_length)
        self.mel_transform.to(x.device)
        mel = self.mel_transform(x)
        batch, channels, n_bins, n_frames = mel.shape
        if self.log:
            mel = torch.log10(self.floor_level + mel)
        # Fold the channel and mel-bin dimensions together.
        return mel.reshape(batch, channels * n_bins, n_frames)
class MelSpectrogramL1Loss(torch.nn.Module):
    """L1 distance between the mel spectrograms of two signals.

    Args:
        sample_rate (int): Sample rate.
        n_fft (int): Number of fft.
        hop_length (int): Hop size.
        win_length (int): Window length.
        n_mels (int): Number of mel bins.
        f_min (float or None): Minimum frequency.
        f_max (float or None): Maximum frequency.
        log (bool): Whether to scale with log.
        normalized (bool): Whether to normalize the melspectrogram.
        floor_level (float): Floor level forwarded to the mel spectrogram wrapper (default=1e-5).
    """
    def __init__(self, sample_rate: int, n_fft: int = 1024, hop_length: int = 256, win_length: int = 1024,
                 n_mels: int = 80, f_min: float = 0.0, f_max: tp.Optional[float] = None,
                 log: bool = True, normalized: bool = False, floor_level: float = 1e-5):
        super().__init__()
        self.l1 = torch.nn.L1Loss()
        mel_kwargs = dict(
            n_fft=n_fft, hop_length=hop_length, win_length=win_length,
            n_mels=n_mels, sample_rate=sample_rate, f_min=f_min, f_max=f_max,
            log=log, normalized=normalized, floor_level=floor_level,
        )
        self.melspec = MelSpectrogramWrapper(**mel_kwargs)

    def forward(self, x, y):
        """Return the L1 loss between the mel spectrograms of `x` and `y`."""
        self.melspec.to(x.device)
        mel_x = self.melspec(x)
        mel_y = self.melspec(y)
        return self.l1(mel_x, mel_y)
class MultiScaleMelSpectrogramLoss(nn.Module):
    """Multi-Scale spectrogram loss (msspec).

    For each scale `i` in `range(range_start, range_end)`, mel spectrograms are computed
    with an FFT and window size of `2 ** i` and a hop of a quarter of that. The loss adds
    the L1 distance between linear mel spectrograms and the (optionally weighted) MSE
    between log mel spectrograms.

    Args:
        sample_rate (int): Sample rate.
        range_start (int): Power of 2 to use for the first scale.
        range_end (int): Power of 2 at which to stop (exclusive).
        n_mels (int): Number of mel bins.
        f_min (float): Minimum frequency.
        f_max (float or None): Maximum frequency.
        normalized (bool): Whether to normalize the melspectrogram; when set, the final
            loss is also divided by the sum of all coefficients.
        alphas (bool): Whether to weight the log-mel term of scale `i` by `sqrt(2 ** i - 1)`
            (otherwise the weight is 1).
        floor_level (float): Floor level forwarded to the mel spectrogram wrappers (default=1e-5).
    """
    def __init__(self, sample_rate: int, range_start: int = 6, range_end: int = 11,
                 n_mels: int = 64, f_min: float = 0.0, f_max: tp.Optional[float] = None,
                 normalized: bool = False, alphas: bool = True, floor_level: float = 1e-5):
        super().__init__()
        self.normalized = normalized
        self.alphas = list()
        self.total = 0
        linear_mels = list()
        log_mels = list()
        for exponent in range(range_start, range_end):
            n_fft = 2 ** exponent
            scale_kwargs = dict(
                n_fft=n_fft, hop_length=n_fft / 4, win_length=n_fft,
                n_mels=n_mels, sample_rate=sample_rate, f_min=f_min, f_max=f_max,
                normalized=normalized, floor_level=floor_level,
            )
            linear_mels.append(MelSpectrogramWrapper(log=False, **scale_kwargs))
            log_mels.append(MelSpectrogramWrapper(log=True, **scale_kwargs))
            alpha = np.sqrt(n_fft - 1) if alphas else 1
            self.alphas.append(alpha)
            # Each scale contributes a coefficient of 1 (L1 term) plus alpha (MSE term).
            self.total += alpha + 1

        self.l1s = nn.ModuleList(linear_mels)
        self.l2s = nn.ModuleList(log_mels)

    def forward(self, x, y):
        """Return the multi-scale mel spectrogram loss between `x` and `y`."""
        loss = 0.0
        self.l1s.to(x.device)
        self.l2s.to(x.device)
        for alpha, linear_mel, log_mel in zip(self.alphas, self.l1s, self.l2s):
            lin_x = linear_mel(x)
            lin_y = linear_mel(y)
            log_x = log_mel(x)
            log_y = log_mel(y)
            loss += F.l1_loss(lin_x, lin_y) + alpha * F.mse_loss(log_x, log_y)
        if self.normalized:
            loss = loss / self.total
        return loss
================================================
FILE: audiocraft/losses/stftloss.py
================================================
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
# Adapted from MIT code under the original license
# Copyright 2019 Tomoki Hayashi
# MIT License (https://opensource.org/licenses/MIT)
import typing as tp
import torch
from torch import nn
from torch.nn import functional as F
# TODO: Replace with torchaudio.STFT?
def _stft(x: torch.Tensor, fft_size: int, hop_length: int, win_length: int,
          window: tp.Optional[torch.Tensor], normalized: bool) -> torch.Tensor:
    """Perform STFT and convert to magnitude spectrogram.

    Args:
        x: Input signal tensor (B, C, T).
        fft_size (int): FFT size.
        hop_length (int): Hop size.
        win_length (int): Window length.
        window (torch.Tensor or None): Window tensor passed to `torch.stft`.
        normalized (bool): Whether to normalize the STFT or not.
    Returns:
        torch.Tensor: Magnitude spectrogram (B, C, #frames, fft_size // 2 + 1).
    """
    B, C, T = x.shape
    # `reshape` (rather than `view`) also accepts non-contiguous inputs.
    x_stft = torch.stft(
        x.reshape(-1, T), fft_size, hop_length, win_length, window,
        normalized=normalized, return_complex=True,
    )
    # (B * C, #freq, #frames) -> (B, C, #freq, #frames)
    x_stft = x_stft.view(B, C, *x_stft.shape[1:])
    real = x_stft.real
    imag = x_stft.imag

    # NOTE(kan-bayashi): clamp is needed to avoid nan or inf
    magnitude = torch.sqrt(torch.clamp(real ** 2 + imag ** 2, min=1e-7))
    # Swap the frequency and frame axes to get the documented (B, C, #frames, #freq) layout.
    return magnitude.transpose(-1, -2)
class SpectralConvergenceLoss(nn.Module):
    """Spectral convergence loss: Frobenius norm of the spectrogram error
    relative to the Frobenius norm of the ground truth spectrogram.

    Args:
        epsilon (float): Epsilon added to the denominator for numerical stability.
    """
    def __init__(self, epsilon: float = torch.finfo(torch.float32).eps):
        super().__init__()
        self.epsilon = epsilon

    def forward(self, x_mag: torch.Tensor, y_mag: torch.Tensor):
        """Calculate forward propagation.

        Both norms are taken over all the elements of the tensors.

        Args:
            x_mag: Magnitude spectrogram of predicted signal.
            y_mag: Magnitude spectrogram of groundtruth signal, same shape as `x_mag`.
        Returns:
            torch.Tensor: Spectral convergence loss value.
        """
        error_norm = torch.norm(y_mag - x_mag, p="fro")
        reference_norm = torch.norm(y_mag, p="fro")
        return error_norm / (reference_norm + self.epsilon)
class LogSTFTMagnitudeLoss(nn.Module):
    """Log STFT magnitude loss: L1 distance between log magnitude spectrograms.

    Args:
        epsilon (float): Epsilon value added before the log for numerical stability.
    """
    def __init__(self, epsilon: float = torch.finfo(torch.float32).eps):
        super().__init__()
        self.epsilon = epsilon

    def forward(self, x_mag: torch.Tensor, y_mag: torch.Tensor):
        """Calculate forward propagation.

        Args:
            x_mag (torch.Tensor): Magnitude spectrogram of predicted signal.
            y_mag (torch.Tensor): Magnitude spectrogram of groundtruth signal, same shape as `x_mag`.
        Returns:
            torch.Tensor: Log STFT magnitude loss value.
        """
        log_y = torch.log(self.epsilon + y_mag)
        log_x = torch.log(self.epsilon + x_mag)
        return F.l1_loss(log_y, log_x)
class STFTLosses(nn.Module):
    """Pair of STFT losses (spectral convergence and log magnitude) at one resolution.

    Args:
        n_fft (int): Size of FFT.
        hop_length (int): Hop length.
        win_length (int): Window length.
        window (str): Name of the `torch` function building the window, e.g. "hann_window".
        normalized (bool): Whether to use normalized STFT or not.
        epsilon (float): Epsilon for numerical stability.
    """
    def __init__(self, n_fft: int = 1024, hop_length: int = 120, win_length: int = 600,
                 window: str = "hann_window", normalized: bool = False,
                 epsilon: float = torch.finfo(torch.float32).eps):
        super().__init__()
        self.n_fft = n_fft
        self.hop_length = hop_length
        self.win_length = win_length
        self.normalized = normalized
        # The window is built once and stored as a buffer.
        window_fn = getattr(torch, window)
        self.register_buffer("window", window_fn(win_length))
        self.spectral_convergenge_loss = SpectralConvergenceLoss(epsilon)
        self.log_stft_magnitude_loss = LogSTFTMagnitudeLoss(epsilon)

    def forward(self, x: torch.Tensor, y: torch.Tensor) -> tp.Tuple[torch.Tensor, torch.Tensor]:
        """Calculate forward propagation.

        Args:
            x (torch.Tensor): Predicted signal (B, C, T).
            y (torch.Tensor): Groundtruth signal (B, C, T).
        Returns:
            torch.Tensor: Spectral convergence loss value.
            torch.Tensor: Log STFT magnitude loss value.
        """
        stft_args = (self.n_fft, self.hop_length, self.win_length, self.window, self.normalized)
        x_mag = _stft(x, *stft_args)  # type: ignore
        y_mag = _stft(y, *stft_args)  # type: ignore
        sc_loss = self.spectral_convergenge_loss(x_mag, y_mag)
        mag_loss = self.log_stft_magnitude_loss(x_mag, y_mag)
        return sc_loss, mag_loss
class STFTLoss(nn.Module):
    """Single Resolution STFT loss.

    Args:
        n_fft (int): Nb of FFT.
        hop_length (int): Hop length.
        win_length (int): Window length.
        window (str): Window function type.
        normalized (bool): Whether to use normalized STFT or not.
        epsilon (float): Epsilon for numerical stability.
        factor_sc (float): Coefficient for the spectral loss.
        factor_mag (float): Coefficient for the magnitude loss.
    """
    def __init__(self, n_fft: int = 1024, hop_length: int = 120, win_length: int = 600,
                 window: str = "hann_window", normalized: bool = False,
                 factor_sc: float = 0.1, factor_mag: float = 0.1,
                 epsilon: float = torch.finfo(torch.float32).eps):
        super().__init__()
        self.loss = STFTLosses(n_fft, hop_length, win_length, window, normalized, epsilon)
        self.factor_sc = factor_sc
        self.factor_mag = factor_mag

    def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
        """Calculate forward propagation.

        Args:
            x (torch.Tensor): Predicted signal (B, C, T).
            y (torch.Tensor): Groundtruth signal (B, C, T).
        Returns:
            torch.Tensor: Single resolution STFT loss, i.e. the weighted sum of the
                spectral convergence and log magnitude losses.
        """
        sc_loss, mag_loss = self.loss(x, y)
        return self.factor_sc * sc_loss + self.factor_mag * mag_loss
class MRSTFTLoss(nn.Module):
    """Multi resolution STFT loss.

    Args:
        n_ffts (Sequence[int]): Sequence of FFT sizes.
        hop_lengths (Sequence[int]): Sequence of hop sizes.
        win_lengths (Sequence[int]): Sequence of window lengths.
        window (str): Window function type.
        factor_sc (float): Coefficient for the spectral loss.
        factor_mag (float): Coefficient for the magnitude loss.
        normalized (bool): Whether to use normalized STFT or not.
        epsilon (float): Epsilon for numerical stability.
    """
    def __init__(self, n_ffts: tp.Sequence[int] = (1024, 2048, 512),
                 hop_lengths: tp.Sequence[int] = (120, 240, 50),
                 win_lengths: tp.Sequence[int] = (600, 1200, 240), window: str = "hann_window",
                 factor_sc: float = 0.1, factor_mag: float = 0.1,
                 normalized: bool = False, epsilon: float = torch.finfo(torch.float32).eps):
        super().__init__()
        assert len(n_ffts) == len(hop_lengths) == len(win_lengths)
        # One pair of STFT losses per (fft size, hop size, window length) resolution.
        self.stft_losses = torch.nn.ModuleList()
        for fs, ss, wl in zip(n_ffts, hop_lengths, win_lengths):
            self.stft_losses.append(STFTLosses(fs, ss, wl, window, normalized, epsilon))
        self.factor_sc = factor_sc
        self.factor_mag = factor_mag

    def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
        """Calculate forward propagation.

        Args:
            x (torch.Tensor): Predicted signal (B, C, T).
            y (torch.Tensor): Groundtruth signal (B, C, T).
        Returns:
            torch.Tensor: Multi resolution STFT loss, a tensor with a single element (shape `(1,)`).
        """
        # The accumulators must live on the same device as the inputs, otherwise the
        # in-place accumulation below fails as soon as the signals are not on CPU.
        sc_loss = torch.zeros(1, device=x.device)
        mag_loss = torch.zeros(1, device=x.device)
        for f in self.stft_losses:
            sc_l, mag_l = f(x, y)
            sc_loss += sc_l
            mag_loss += mag_l
        # Average over the resolutions.
        sc_loss /= len(self.stft_losses)
        mag_loss /= len(self.stft_losses)

        return self.factor_sc * sc_loss + self.factor_mag * mag_loss
================================================
FILE: audiocraft/metrics/__init__.py
================================================
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
"""Metrics like CLAP score, FAD, KLD, Visqol, Chroma similarity, etc.
"""
# flake8: noqa
from .clap_consistency import CLAPTextConsistencyMetric, TextConsistencyMetric
from .chroma_cosinesim import ChromaCosineSimilarityMetric
from .fad import FrechetAudioDistanceMetric
from .kld import KLDivergenceMetric, PasstKLDivergenceMetric
from .rvm import RelativeVolumeMel
from .visqol import ViSQOL
================================================
FILE: audiocraft/metrics/chroma_cosinesim.py
================================================
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import torch
import torchmetrics
from ..data.audio_utils import convert_audio
from ..modules.chroma import ChromaExtractor
class ChromaCosineSimilarityMetric(torchmetrics.Metric):
"""Chroma cosine similarity metric.
This metric extracts a chromagram for a reference waveform and
a generated waveform and compares each frame using the cosine similarity
function. The output is the mean cosine similarity.
Args:
sample_rate (int): Sample rate used by the chroma extractor.
n_chroma (int): Number of chroma used by the chroma extractor.
radix2_exp (int): Exponent for the chroma extractor.
argmax (bool): Whether the chroma extractor uses argmax.
eps (float): Epsilon for cosine similarity computation.
"""
def __init__(self, sample_rate: int, n_chroma: int, radix2_exp: int, argmax: bool, eps: float = 1e-8):
super().__init__()
self.chroma_sample_rate = sample_rate
self.n_chroma = n_chroma
self.eps = eps
self.chroma_extractor = ChromaExtractor(sample_rate=self.chroma_sample_rate, n_chroma=self.n_chroma,
radix2_exp=radix2_exp, argmax=argmax)
self.add_state("cosine_sum", default=torch.tensor(0.), dist_reduce_fx="sum")
self.add_state("weight", default=torch.tensor(0.), dist_reduce_fx="sum")
def update(self, preds: torch.Tensor, targets: torch.Tensor,
sizes: torch.Tensor, sample_rates: torch.Tensor) -> None:
"""Compute cosine similarity between chromagrams and accumulate scores over the dataset."""
if preds.size(0) == 0:
return
assert preds.shape == targets.shape, (
f"Preds and target shapes mismatch: preds={preds.shape}, targets={targets.shape}")
assert preds.size(0) == sizes.size(0), (
f"Number of items in preds ({preds.shape}) mismatch ",
f"with sizes ({sizes.shape})")
assert preds.size(0) == sample_rates.size(0), (
f"Number of items in preds ({preds.shape}) mismatch ",
f"with sample_rates ({sample_rates.shape})")
assert torch.all(sample_rates == sample_rates[0].item()), "All sample rates are not the same in the batch"
device = self.weight.device
preds, targets = preds.to(device), targets.to(device) # type: ignore
sample_rate = sample_rates[0].item()
preds = convert_audio(preds, from_rate=sample_rate, to_rate=self.chroma_sample_rate, to_channels=1)
targets = convert_audio(targets, from_rate=sample_rate, to_rate=self.chroma_sample_rate, to_channels=1)
gt_chroma = self.chroma_extractor(targets)
gen_chroma = self.chroma_extractor(preds)
chroma_lens = (sizes / self.chroma_extractor.winhop).ceil().int()
for i in range(len(gt_chroma)):
t = int(chroma_lens[i].item())
cosine_sim = torch.nn.functional.cosine_similarity(
gt_chroma[i, :t], gen_chroma[i, :t], dim=1, eps=self.eps)
self.cosine_sum += cosine_sim.sum(dim=0) # type: ignore
self.weight += torch.tensor(t) # type: ignore
def compute(self) -> float:
    """Return the average cosine similarity over all accumulated generated/target chroma frames."""
    total_weight = self.weight
    assert total_weight.item() > 0, "Unable to compute with total number of comparisons <= 0"  # type: ignore
    mean_similarity = self.cosine_sum / total_weight
    return mean_similarity.item()  # type: ignore
================================================
FILE: audiocraft/metrics/clap_consistency.py
================================================
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
from pathlib import Path
import typing as tp
import torch
import torchmetrics
from transformers import RobertaTokenizer # type: ignore
from ..data.audio_utils import convert_audio
from ..environment import AudioCraftEnvironment
from ..utils.utils import load_clap_state_dict
try:
import laion_clap # type: ignore
except ImportError:
laion_clap = None
class TextConsistencyMetric(torchmetrics.Metric):
    """Interface for metrics scoring the consistency between audio and text pairs.

    Both `update` and `compute` raise `NotImplementedError` here and must be overridden.
    """

    def update(self, audio: torch.Tensor, text: tp.List[str], sizes: torch.Tensor, sample_rates: torch.Tensor) -> None:
        """Accumulate statistics from a batch of audio/text pairs (must be overridden)."""
        message = "implement how to update the metric from the audio and text pairs."
        raise NotImplementedError(message)

    def compute(self):
        """Return the final metric score (must be overridden)."""
        message = "implement how to compute the final metric score."
        raise NotImplementedError(message)
class CLAPTextConsistencyMetric(TextConsistencyMetric):
"""Text consistency metric relying on Contrastive Language-Audio Pretraining (CLAP).
This metric is similar to the MuLan Cycle Consistency from MusicLM (https://arxiv.org/pdf/2301.11325.pdf)
or the CLAP score used in Make-An-Audio (https://arxiv.org/pdf/2301.12661v1.pdf).
As a joint audio-text embedding model, a pretrained CLAP model can be used to quantify the
similarity between audio-text pairs. We compute the CLAP embeddings from the text descriptions as
well as the generated audio based on them, and define the MCC metric as the average cosine similarity
between these embeddings.
Model implementation & pre-trained checkpoints: https://github.com/LAION-AI/CLAP
"""
def __init__(self, model_path: tp.Union[str, Path], model_arch: str = 'HTSAT-tiny', enable_fusion: bool = False):
    """Register the metric accumulators and load the CLAP model.

    Args:
        model_path (str or Path): Reference to the CLAP checkpoint to load.
        model_arch (str): Audio model architecture name given to CLAP.
        enable_fusion (bool): Forwarded to the CLAP module as `enable_fusion`.
    Raises:
        ImportError: If `laion_clap` could not be imported.
    """
    super().__init__()
    clap_missing = laion_clap is None
    if clap_missing:
        raise ImportError("Please install CLAP to compute text consistency: 'pip install laion_clap'")
    # Running sum of similarities and number of compared pairs.
    for state_name in ("cosine_sum", "weight"):
        self.add_state(state_name, default=torch.tensor(0.), dist_reduce_fx="sum")
    self._initialize_model(model_path, model_arch, enable_fusion)
def _initialize_model(self, model_path: tp.Union[str, Path], model_arch: str, enable_fusion: bool):
    """Create the tokenizer and the CLAP model, load its weights and switch it to eval mode.

    Args:
        model_path (str or Path): Checkpoint reference, resolved through `AudioCraftEnvironment`.
        model_arch (str): Audio model architecture name given to CLAP as `amodel`.
        enable_fusion (bool): Forwarded to the CLAP module as `enable_fusion`.
    """
    checkpoint_path = AudioCraftEnvironment.resolve_reference_path(model_path)
    self.tokenize = RobertaTokenizer.from_pretrained('roberta-base')
    clap_model = laion_clap.CLAP_Module(enable_fusion=enable_fusion, amodel=model_arch)
    self.model = clap_model
    # Audio is resampled to this rate before being embedded.
    self.model_sample_rate = 48_000
    load_clap_state_dict(clap_model, checkpoint_path)
    clap_model.eval()
def _tokenizer(self, texts: tp.Union[str, tp.List[str]]) -> dict:
# we use the default params from CLAP module here as well
return self.tokenize(texts, padding="max_length", truncation=True, max_length=77, return_tensors="pt")
def update(self, audio: torch.Tensor, text: tp.List[str], sizes: torch.Tensor, sample_rates: torch.Tensor) -> None:
    """Embed a batch of audio/text pairs with CLAP and accumulate their cosine similarities.

    Args:
        audio (torch.Tensor): Audio batch; the code comment below expects [B, C, T].
        text (list of str): One description per audio item.
        sizes (torch.Tensor): Not used by this implementation.
        sample_rates (torch.Tensor): One sample rate per item; all must be equal.
    """
    assert audio.size(0) == len(text), "Number of audio and text samples should match"
    batch_rate = sample_rates[0].item()
    assert torch.all(sample_rates == batch_rate), "All items in batch should have the same sample rate"
    sample_rate = int(batch_rate)
    # convert audio batch to 48kHz monophonic audio with no channel dimension: [B, C, T] -> [B, T]
    mono_audio = convert_audio(audio, from_rate=sample_rate, to_rate=self.model_sample_rate, to_channels=1)
    audio = mono_audio.mean(dim=1)
    audio_embeddings = self.model.get_audio_embedding_from_data(audio, use_tensor=True)
    text_embeddings = self.model.get_text_embedding(text, tokenizer=self._tokenizer, use_tensor=True)
    # cosine similarity between the text and the audio embedding
    similarities = torch.nn.functional.cosine_similarity(audio_embeddings, text_embeddings, dim=1, eps=1e-8)
    self.cosine_sum += similarities.sum(dim=0)
    self.weight += torch.tensor(similarities.size(0))
def compute(self):
    """Return the average cosine similarity across all accumulated audio/text pairs."""
    pair_count = self.weight
    assert pair_count.item() > 0, "Unable to compute with total number of comparisons <= 0"  # type: ignore
    average = self.cosine_sum / pair_count
    return average.item()  # type: ignore
================================================
FILE: audiocraft/metrics/fad.py
================================================
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import logging
from pathlib import Path
import os
import subprocess
import tempfile
import typing as tp
from audiocraft.data.audio import audio_write
from audiocraft.data.audio_utils import convert_audio
import flashy
import torch
import torchmetrics
from ..environment import AudioCraftEnvironment
# Module-level logger used by the FAD metric.
logger = logging.getLogger(__name__)
# Sample rate stored as the metric's `model_sample_rate`; audio is converted to it before being dumped.
VGGISH_SAMPLE_RATE = 16_000
# Channel count stored as the metric's `model_channels`.
VGGISH_CHANNELS = 1
class FrechetAudioDistanceMetric(torchmetrics.Metric):
"""Fréchet Audio Distance computation based on official TensorFlow implementation from Google Research.
From: D.C. Dowson & B.V. Landau The Fréchet distance between
multivariate normal distributions
https://doi.org/10.1016/0047-259X(82)90077-X
The Fréchet distance between two multivariate gaussians,
`X ~ N(mu_x, sigma_x)` and `Y ~ N(mu_y, sigma_y)`, is `d^2`.
d^2 = (mu_x - mu_y)^2 + Tr(sigma_x + sigma_y - 2 * sqrt(sigma_x*sigma_y))
= (mu_x - mu_y)^2 + Tr(sigma_x) + Tr(sigma_y)
- 2 * Tr(sqrt(sigma_x*sigma_y)))
To use this FAD computation metric, you need to have the proper Frechet Audio Distance tool setup
from: https://github.com/google-research/google-research/tree/master/frechet_audio_distance
We provide the instructions below as a reference, but we do not guarantee further support
for the frechet_audio_distance installation. This was tested with python 3.10, cuda 11.8, tensorflow 2.12.0.
We recommend installing the frechet_audio_distance library in a dedicated env (e.g. conda).
1. Get the code and models following the repository instructions. We used the steps below:
git clone git@github.com:google-research/google-research.git
git clone git@github.com:tensorflow/models.git
mkdir google-research/tensorflow_models
touch google-research/tensorflow_models/__init__.py
cp -r models/research/audioset google-research/tensorflow_models/
touch google-research/tensorflow_models/audioset/__init__.py
echo "from .vggish import mel_features, vggish_params, vggish_slim" > \
google-research/tensorflow_models/audioset/__init__.py
# we can now remove the tensorflow models repository
# rm -r models
cd google-research
Follow the instructions to download the vggish checkpoint. AudioCraft base configuration
assumes it is placed in the AudioCraft reference dir.
Note that we operate the following changes for the code to work with TensorFlow 2.X and python 3:
- Update xrange for range in:
https://github.com/google-research/google-research/blob/master/frechet_audio_distance/audioset_model.py
- Update `tf_record = tf.python_io.tf_record_iterator(filename).next()` to
`tf_record = tf.python_io.tf_record_iterator(filename).__next__()` in
https://github.com/google-research/google-research/blob/master/frechet_audio_distance/fad_utils.py
- Update `import vggish_params as params` to `from . import vggish_params as params` in:
https://github.com/tensorflow/models/blob/master/research/audioset/vggish/vggish_slim.py
- Add flag to provide a given batch size for running the AudioSet model in:
https://github.com/google-research/google-research/blob/master/frechet_audio_distance/create_embeddings_main.py
```
flags.DEFINE_integer('batch_size', 64,
'Number of samples in the batch for AudioSet model.')
```
Ensure you pass the flag to the create_embeddings_beam.create_pipeline function, adding:
`batch_size=FLAGS.batch_size` to the provided parameters.
2. Follow instructions for the library installation and a valid TensorFlow installation
```
# e.g. instructions from: https://www.tensorflow.org/install/pip
conda install -c conda-forge cudatoolkit=11.8.0
python3 -m pip install nvidia-cudnn-cu11==8.6.0.163 tensorflow==2.12.*
mkdir -p $CONDA_PREFIX/etc/conda/activate.d
echo 'CUDNN_PATH=$(dirname $(python -c "import nvidia.cudnn;print(nvidia.cudnn.__file__)"))' \
>> $CONDA_PREFIX/etc/conda/activate.d/env_vars.sh
echo 'export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:$CONDA_PREFIX/lib/:$CUDNN_PATH/lib' \
>> $CONDA_PREFIX/etc/conda/activate.d/env_vars.sh
source $CONDA_PREFIX/etc/conda/activate.d/env_vars.sh
# Verify install: on a machine with GPU device
python3 -c "import tensorflow as tf; print(tf.config.list_physical_devices('GPU'))"
```
Now install frechet_audio_distance required dependencies:
```
# We assume we already have TensorFlow installed from the above steps
pip install apache-beam numpy scipy tf_slim
```
Finally, follow remaining library instructions to ensure you have a working frechet_audio_distance setup
(you may want to specify --model_ckpt flag pointing to the model's path).
3. AudioCraft's FrechetAudioDistanceMetric requires 2 environment variables pointing to the python executable
and Tensorflow library path from the above installation steps:
export TF_PYTHON_EXE=""
export TF_LIBRARY_PATH=""
e.g. assuming we have installed everything in a dedicated conda env
with python 3.10 that is currently active:
export TF_PYTHON_EXE="$CONDA_PREFIX/bin/python"
export TF_LIBRARY_PATH="$CONDA_PREFIX/lib/python3.10/site-packages/nvidia/cudnn/lib"
Finally you may want to export the following variable:
export TF_FORCE_GPU_ALLOW_GROWTH=true
See: https://www.tensorflow.org/guide/gpu#limiting_gpu_memory_growth
You can save those environment variables in your training conda env, when currently active:
`$CONDA_PREFIX/etc/conda/activate.d/env_vars.sh`
e.g. assuming the env with TensorFlow and frechet_audio_distance install is named ac_eval,
and the training conda env is named audiocraft:
```
# activate training env
conda activate audiocraft
# get path to all envs
CONDA_ENV_DIR=$(dirname $CONDA_PREFIX)
# export pointers to evaluation env for using TensorFlow in FrechetAudioDistanceMetric
touch $CONDA_PREFIX/etc/conda/activate.d/env_vars.sh
echo 'export TF_PYTHON_EXE="$CONDA_ENV_DIR/ac_eval/bin/python"' >> \
$CONDA_PREFIX/etc/conda/activate.d/env_vars.sh
echo 'export TF_LIBRARY_PATH="$CONDA_ENV_DIR/ac_eval/lib/python3.10/site-packages/nvidia/cudnn/lib"' >> \
$CONDA_PREFIX/etc/conda/activate.d/env_vars.sh
# optionally:
echo 'export TF_FORCE_GPU_ALLOW_GROWTH=true' >> $CONDA_PREFIX/etc/conda/activate.d/env_vars.sh
# you may need to reactivate the audiocraft env for this to take effect
```
Args:
bin (Path or str): Path to installed frechet audio distance code.
model_path (Path or str): Path to Tensorflow checkpoint for the model
used to compute statistics over the embedding beams.
format (str): Audio format used to save files.
log_folder (Path or str, optional): Path where to write process logs.
"""
def __init__(self, bin: tp.Union[Path, str], model_path: tp.Union[Path, str],
             format: str = "wav", batch_size: tp.Optional[int] = None,
             log_folder: tp.Optional[tp.Union[Path, str]] = None):
    """Resolve the checkpoint, prepare the subprocess environment and the working folders.

    Args:
        bin (Path or str): Path to the frechet audio distance code; exported as `PYTHONPATH`.
        model_path (Path or str): Checkpoint reference; it must exist once resolved.
        format (str): Audio format used to save files.
        batch_size (int, optional): When set, passed as `--batch_size` to the embedding tool.
        log_folder (Path or str, optional): Folder passed to `reset` for the working files.
    """
    super().__init__()
    self.model_sample_rate = VGGISH_SAMPLE_RATE
    self.model_channels = VGGISH_CHANNELS
    resolved_checkpoint = AudioCraftEnvironment.resolve_reference_path(model_path)
    self.model_path = resolved_checkpoint
    assert Path(resolved_checkpoint).exists(), f"Could not find provided model checkpoint path at: {self.model_path}"
    self.format = format
    self.batch_size = batch_size
    self.bin = bin
    self.python_path = os.environ.get('TF_PYTHON_EXE') or 'python'
    logger.info("Python exe for TF is %s", self.python_path)
    # Environment overrides handed to the TensorFlow subprocesses.
    tf_env = {"PYTHONPATH": str(bin)}
    forwarded_vars = (
        ('TF_LIBRARY_PATH', 'LD_LIBRARY_PATH'),
        ('TF_FORCE_GPU_ALLOW_GROWTH', 'TF_FORCE_GPU_ALLOW_GROWTH'),
    )
    for source_var, target_var in forwarded_vars:
        if source_var in os.environ:
            tf_env[target_var] = os.environ[source_var]
    self.tf_env = tf_env
    logger.info("Env for TF is %r", self.tf_env)
    self.reset(log_folder)
    self.add_state("total_files", default=torch.tensor(0.), dist_reduce_fx="sum")
def reset(self, log_folder: tp.Optional[tp.Union[Path, str]] = None):
    """Point the metric at a fresh working directory and restart the sample counter.

    Args:
        log_folder (Path or str, optional): Folder under which the `fad` working directory
            is created. A new temporary directory is used when it is not provided.
    """
    # NOTE(review): the parent class reset is not invoked here, so registered states such
    # as `total_files` keep their value across resets; confirm this is intended.
    log_folder = Path(log_folder or tempfile.mkdtemp())
    self.tmp_dir = log_folder / 'fad'
    # parents=True: a user-provided log folder may not exist yet.
    self.tmp_dir.mkdir(parents=True, exist_ok=True)
    self.samples_tests_dir = self.tmp_dir / 'tests'
    self.samples_tests_dir.mkdir(exist_ok=True)
    self.samples_background_dir = self.tmp_dir / 'background'
    self.samples_background_dir.mkdir(exist_ok=True)
    # Manifests listing the dumped files, and output locations for the embedding statistics.
    self.manifest_tests = self.tmp_dir / 'files_tests.cvs'
    self.manifest_background = self.tmp_dir / 'files_background.cvs'
    self.stats_tests_dir = self.tmp_dir / 'stats_tests'
    self.stats_background_dir = self.tmp_dir / 'stats_background'
    self.counter = 0
def update(self, preds: torch.Tensor, targets: torch.Tensor,
           sizes: torch.Tensor, sample_rates: torch.Tensor,
           stems: tp.Optional[tp.List[str]] = None):
    """Dump every generated/reference pair of the batch as audio files for the FAD tool.

    Args:
        preds (torch.Tensor): Generated audio batch.
        targets (torch.Tensor): Reference audio batch, same shape as `preds`.
        sizes (torch.Tensor): One length per item, used to trim the last dimension.
        sample_rates (torch.Tensor): One sample rate per item.
        stems (list of str, optional): Unique file stems, one per item; generated when omitted.
    """
    assert preds.shape == targets.shape, f"preds={preds.shape} != targets={targets.shape}"
    num_samples = preds.shape[0]
    assert num_samples == sizes.size(0) and num_samples == sample_rates.size(0)
    assert stems is None or num_samples == len(set(stems))
    for idx in range(num_samples):
        self.total_files += 1  # type: ignore
        self.counter += 1
        wav_len = int(sizes[idx].item())
        sample_rate = int(sample_rates[idx].item())
        pred_wav = preds[idx][..., :wav_len]
        target_wav = targets[idx][..., :wav_len]
        if stems is None:
            stem_name = f'sample_{self.counter}_{flashy.distrib.rank()}'
        else:
            stem_name = stems[idx]
        # dump audio files; a failure on one file is logged and does not stop the others
        try:
            resampled_pred = convert_audio(
                pred_wav.unsqueeze(0), from_rate=sample_rate,
                to_rate=self.model_sample_rate, to_channels=1)
            pred_wav = resampled_pred.squeeze(0)
            audio_write(
                self.samples_tests_dir / stem_name, pred_wav, sample_rate=self.model_sample_rate,
                format=self.format, strategy="peak")
        except Exception as e:
            logger.error(f"Exception occured when saving tests files for FAD computation: {repr(e)} - {e}")
        try:
            # for the ground truth audio, we enforce the 'peak' strategy to avoid modifying
            # the original audio when writing it
            resampled_target = convert_audio(
                target_wav.unsqueeze(0), from_rate=sample_rate,
                to_rate=self.model_sample_rate, to_channels=1)
            target_wav = resampled_target.squeeze(0)
            audio_write(
                self.samples_background_dir / stem_name, target_wav, sample_rate=self.model_sample_rate,
                format=self.format, strategy="peak")
        except Exception as e:
            logger.error(f"Exception occured when saving background files for FAD computation: {repr(e)} - {e}")
def _get_samples_name(self, is_background: bool):
return 'background' if is_background else 'tests'
def _create_embedding_beams(self, is_background: bool, gpu_index: tp.Optional[int] = None):
    """Write the manifest of dumped samples and launch the embedding extraction subprocess.

    Args:
        is_background (bool): Select the background samples when true, the tests samples otherwise.
        gpu_index (int, optional): When set, exported to the child process as `CUDA_VISIBLE_DEVICES`.
    Returns:
        tuple: The started `subprocess.Popen` object and the path of its log file.
    """
    if is_background:
        input_samples_dir = self.samples_background_dir
        input_filename = self.manifest_background
        stats_name = self.stats_background_dir
    else:
        input_samples_dir = self.samples_tests_dir
        input_filename = self.manifest_tests
        stats_name = self.stats_tests_dir
    beams_name = self._get_samples_name(is_background)
    log_file = self.tmp_dir / f'fad_logs_create_beams_{beams_name}.log'

    logger.info(f"Scanning samples folder to fetch list of files: {input_samples_dir}")
    with open(input_filename, "w") as fout:
        for path in Path(input_samples_dir).glob(f"*.{self.format}"):
            fout.write(f"{str(path)}\n")

    cmd = [
        self.python_path, "-m",
        "frechet_audio_distance.create_embeddings_main",
        "--model_ckpt", f"{self.model_path}",
        "--input_files", f"{str(input_filename)}",
        "--stats", f"{str(stats_name)}",
    ]
    if self.batch_size is not None:
        cmd += ["--batch_size", str(self.batch_size)]
    logger.info(f"Launching frechet_audio_distance embeddings main method: {' '.join(cmd)} on {beams_name}")
    # Work on a copy: assigning into os.environ itself would change CUDA_VISIBLE_DEVICES
    # for the current process and for everything it spawns afterwards.
    env = dict(os.environ)
    if gpu_index is not None:
        env["CUDA_VISIBLE_DEVICES"] = str(gpu_index)
    # The child gets its own copy of the descriptor, so the parent handle can be
    # closed as soon as the process is started instead of being leaked.
    with open(log_file, "w") as log_handle:
        process = subprocess.Popen(
            cmd, stdout=log_handle, env={**env, **self.tf_env}, stderr=subprocess.STDOUT)
    return process, log_file
def _compute_fad_score(self, gpu_index: tp.Optional[int] = None):
    """Run the FAD computation on the two statistics files and parse the resulting score.

    Args:
        gpu_index (int, optional): When set, exported to the child process as `CUDA_VISIBLE_DEVICES`.
    Returns:
        float: The FAD score parsed from the command standard output.
    Raises:
        RuntimeError: If the command exits with a non-zero code or its output cannot be parsed.
    """
    cmd = [
        self.python_path, "-m", "frechet_audio_distance.compute_fad",
        "--test_stats", f"{str(self.stats_tests_dir)}",
        "--background_stats", f"{str(self.stats_background_dir)}",
    ]
    logger.info(f"Launching frechet_audio_distance compute fad method: {' '.join(cmd)}")
    # Work on a copy so that the current process environment is left untouched.
    env = dict(os.environ)
    if gpu_index is not None:
        env["CUDA_VISIBLE_DEVICES"] = str(gpu_index)
    result = subprocess.run(cmd, env={**env, **self.tf_env}, capture_output=True)
    if result.returncode:
        logger.error(
            "Error with FAD computation from stats: \n %s \n %s",
            result.stdout.decode(), result.stderr.decode()
        )
        raise RuntimeError("Error while executing FAD computation from stats")
    try:
        # result is "FAD: (d+).(d+)" hence we remove the prefix with (d+) being one digit or more
        fad_score = float(result.stdout[4:])
        return fad_score
    except Exception as e:
        # Chain the original error so that the parsing failure stays visible in the traceback.
        raise RuntimeError(f"Error parsing FAD score from command stdout: {e}") from e
def _log_process_result(self, returncode: int, log_file: tp.Union[Path, str], is_background: bool) -> None:
    """Log the outcome of an embedding subprocess.

    On a non-zero `returncode` the content of `log_file` is logged as an error and the
    whole process is terminated through `os._exit(1)`.
    """
    beams_name = self._get_samples_name(is_background)
    if not returncode:
        logger.info(f"Successfully computed embedding beams on {beams_name} samples.")
        return
    with open(log_file, "r") as f:
        error_log = f.read()
    logger.error(error_log)
    os._exit(1)
def _parallel_create_embedding_beams(self, num_of_gpus: int):
    """Extract tests and background embeddings concurrently, on GPU 0 and GPU 1 respectively."""
    assert num_of_gpus > 0
    logger.info("Creating embeddings beams in a parallel manner on different GPUs")
    # Start both subprocesses before waiting on either of them.
    launched = [
        (is_background, *self._create_embedding_beams(is_background=is_background, gpu_index=gpu_index))
        for gpu_index, is_background in enumerate((False, True))
    ]
    exit_codes = [process.wait() for _, process, _ in launched]
    for (is_background, _, log_file), exit_code in zip(launched, exit_codes):
        self._log_process_result(exit_code, log_file, is_background=is_background)
def _sequential_create_embedding_beams(self):
    """Extract the tests embeddings, then the background ones, one subprocess at a time."""
    logger.info("Creating embeddings beams in a sequential manner")
    for is_background in (False, True):
        process, log_file = self._create_embedding_beams(is_background=is_background)
        exit_code = process.wait()
        self._log_process_result(exit_code, log_file, is_background=is_background)
@flashy.distrib.rank_zero_only
def _local_compute_frechet_audio_distance(self):
    """Extract both sets of embeddings, then return the FAD score computed from them."""
    if torch.cuda.is_available():
        num_of_gpus = torch.cuda.device_count()
    else:
        num_of_gpus = 0
    # The parallel path is only taken when more than one GPU is reported.
    if num_of_gpus <= 1:
        self._sequential_create_embedding_beams()
    else:
        self._parallel_create_embedding_beams(num_of_gpus)
    return self._compute_fad_score(gpu_index=0)
def compute(self) -> float:
    """Compute the FAD score locally and return the value broadcast from rank 0."""
    dumped_files = self.total_files.item()  # type: ignore
    assert dumped_files > 0, "No files dumped for FAD computation!"
    fad_score = self._local_compute_frechet_audio_distance()
    logger.warning(f"FAD score = {fad_score}")
    return flashy.distrib.broadcast_object(fad_score, src=0)
================================================
FILE: audiocraft/metrics/kld.py
================================================
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import contextlib
from functools import partial
import logging
import os
import typing as tp
import torch
import torchmetrics
from ..data.audio_utils import convert_audio
logger = logging.getLogger(__name__)
class _patch_passt_stft:
"""Decorator to patch torch.stft in PaSST."""
def __init__(self):
self.old_stft = torch.stft
def __enter__(self):
# return_complex is a mandatory parameter in latest torch versions
# torch is throwing RuntimeErrors when not set
torch.stft = partial(torch.stft, return_complex=False)
def __exit__(self, *exc):
torch.stft = self.old_stft
def kl_divergence(pred_probs: torch.Tensor, target_probs: torch.Tensor, epsilon: float = 1e-6) -> torch.Tensor:
    """Compute the KL-Divergence between predicted and target probabilities, summed over the last dim.

    Args:
        pred_probs (torch.Tensor): Probabilities for each label obtained
            from a classifier on generated audio. Expected shape is [B, num_classes].
        target_probs (torch.Tensor): Probabilities for each label obtained
            from a classifier on target audio. Expected shape is [B, num_classes].
        epsilon (float): Value added to `pred_probs` before taking the logarithm.
    Returns:
        kld (torch.Tensor): KLD loss between each generated sample and target pair.
    """
    log_preds = torch.log(pred_probs + epsilon)
    elementwise = torch.nn.functional.kl_div(log_preds, target_probs, reduction="none")
    return elementwise.sum(dim=-1)
class KLDivergenceMetric(torchmetrics.Metric):
"""Base implementation for KL Divergence metric.
The KL divergence is measured between probability distributions
of class predictions returned by a pre-trained audio classification model.
When the KL-divergence is low, the generated audio is expected to
have similar acoustic characteristics as the reference audio,
according to the classifier.
"""
def __init__(self):
    """Register the running KLD sums and the counter of compared items."""
    super().__init__()
    for state_name in ("kld_pq_sum", "kld_qp_sum", "kld_all_sum"):
        self.add_state(state_name, default=torch.tensor(0.), dist_reduce_fx="sum")
    # The counter starts from an integer zero, unlike the float sums above.
    self.add_state("weight", default=torch.tensor(0), dist_reduce_fx="sum")
def _get_label_distribution(self, x: torch.Tensor, sizes: torch.Tensor,
sample_rates: torch.Tensor) -> tp.Optional[torch.Tensor]:
"""Get model output given provided input tensor.
Args:
x (torch.Tensor): Input audio tensor of shape [B, C, T].
sizes (torch.Tensor): Actual audio sample length, of shape [B].
sample_rates (torch.Tensor): Actual audio sample rate, of shape [B].
Returns:
probs (torch.Tensor): Probabilities over labels, of shape [B, num_classes].
"""
raise NotImplementedError("implement method to extract label distributions from the model.")
def update(self, preds: torch.Tensor, targets: torch.Tensor,
           sizes: torch.Tensor, sample_rates: torch.Tensor) -> None:
    """Accumulate the KL-Divergence, in both directions, between generated and reference audio.

    Nothing is accumulated when either label distribution is unavailable (`None`).

    Args:
        preds (torch.Tensor): Audio samples to evaluate, of shape [B, C, T].
        targets (torch.Tensor): Target samples to compare against, of shape [B, C, T].
        sizes (torch.Tensor): Actual audio sample length, of shape [B].
        sample_rates (torch.Tensor): Actual audio sample rate, of shape [B].
    """
    assert preds.shape == targets.shape
    assert preds.size(0) > 0, "Cannot update the loss with empty tensors"
    preds_probs = self._get_label_distribution(preds, sizes, sample_rates)
    targets_probs = self._get_label_distribution(targets, sizes, sample_rates)
    if preds_probs is None or targets_probs is None:
        return
    assert preds_probs.shape == targets_probs.shape
    pq_scores = kl_divergence(preds_probs, targets_probs)
    assert not torch.isnan(pq_scores).any(), "kld_scores contains NaN value(s)!"
    self.kld_pq_sum += torch.sum(pq_scores)
    # Same computation with the roles of the two distributions swapped.
    qp_scores = kl_divergence(targets_probs, preds_probs)
    self.kld_qp_sum += torch.sum(qp_scores)
    self.weight += torch.tensor(pq_scores.size(0))
def compute(self) -> dict:
    """Return the averaged KL-Divergences over all evaluated pred/target pairs.

    Returns:
        dict: Keys `kld` and `kld_pq` (same value), `kld_qp`, and `kld_both` (their sum).
    """
    weight: float = float(self.weight.item())  # type: ignore
    assert weight > 0, "Unable to compute with total number of comparisons <= 0"
    logger.info(f"Computing KL divergence on a total of {weight} samples")
    kld_pq, kld_qp = (
        total.item() / weight  # type: ignore
        for total in (self.kld_pq_sum, self.kld_qp_sum)
    )
    return {'kld': kld_pq, 'kld_pq': kld_pq, 'kld_qp': kld_qp, 'kld_both': kld_pq + kld_qp}
class PasstKLDivergenceMetric(KLDivergenceMetric):
"""KL-Divergence metric based on pre-trained PASST classifier on AudioSet.
From: PaSST: Efficient Training of Audio Transformers with Patchout
Paper: https://arxiv.org/abs/2110.05069
Implementation: https://github.com/kkoutini/PaSST
Follow instructions from the github repo:
```
pip install 'git+https://github.com/kkoutini/passt_hear21@0.0.19#egg=hear21passt'
```
Args:
pretrained_length (float, optional): Audio duration used for the pretrained model.
"""
def __init__(self, pretrained_length: tp.Optional[float] = None):
    """Initialize the metric states, then load the PaSST classifier.

    Args:
        pretrained_length (float, optional): Audio duration used for the pretrained model.
    """
    super().__init__()
    self._initialize_model(pretrained_length=pretrained_length)
def _initialize_model(self, pretrained_length: tp.Optional[float] = None):
    """Load the PaSST classifier, store its input limits, and move it to the metric device in eval mode."""
    loaded = self._load_base_model(pretrained_length)
    self.model, self.model_sample_rate, self.max_input_frames, self.min_input_frames = loaded
    self.model.eval()
    self.model.to(self.device)
def _load_base_model(self, pretrained_length: tp.Optional[float]):
"""Load pretrained model from PaSST."""
try:
if pretrained_length == 30:
from hear21passt.base30sec import get_basic_model # type: ignore
max_duration = 30
elif pretrained_length == 20:
from hear21passt.base20sec import get_basic_model # type: ignore
max_duration = 20
else:
from hear21passt.base import get_basic_model # type: ignore
# Original PASST was trained on AudioSet with 10s-long audio samples
max_duration = 10
min_duration = 0.15
min_duration = 0.15
except ModuleNotFoundError:
raise ModuleNotFoundError(
"Please install hear21passt to compute KL divergence: ",
"pip install 'git+https://github.com/kkoutini/passt_hear21@0.0.19#egg=hear21passt'"
)
model_sample_rate = 32_000
max_input_frames = int(max_duration * model_sample_rate)
min_input_frames = int(min_duration * model_sample_rate)
with open(os.devnull, 'w') as f, contextlib.redirect_stdout(f):
model = get_basic_model(mode='logits')
return model, model_sample_rate, max_input_frames, min_input_frames
def _process_audio(self, wav: torch.Tensor, sample_rate: int, wav_len: int) -> tp.List[torch.Tensor]:
    """Trim, resample and split one audio item into segments accepted by the model.

    Args:
        wav (torch.Tensor): Audio of a single item.
        sample_rate (int): Sample rate of `wav`.
        wav_len (int): Number of samples of `wav` to keep.
    Returns:
        list of torch.Tensor: Segments with a leading dimension of size 1 added; segments not
            longer than `min_input_frames` are dropped.
    """
    trimmed = wav.unsqueeze(0)[..., :wav_len]
    resampled = convert_audio(trimmed, from_rate=sample_rate, to_rate=self.model_sample_rate, to_channels=1)
    mono = resampled.squeeze(0)
    # we don't pad but return a list of audio segments as this otherwise affects the KLD computation
    segments = torch.split(mono, self.max_input_frames, dim=-1)
    # ignoring too small segments that are breaking the model inference
    return [segment[None] for segment in segments if segment.size(-1) > self.min_input_frames]
def _get_model_preds(self, wav: torch.Tensor) -> torch.Tensor:
    """Run the pretrained model on a 3-dim input and return softmax probabilities over its logits."""
    assert wav.dim() == 3, f"Unexpected number of dims for preprocessed wav: {wav.shape}"
    mono = wav.mean(dim=1)
    # PaSST is printing a lot of garbage that we are not interested in
    with open(os.devnull, "w") as devnull, contextlib.redirect_stdout(devnull), \
            torch.no_grad(), _patch_passt_stft():
        logits = self.model(mono.to(self.device))
        probs = torch.softmax(logits, dim=-1)
    return probs
def _get_label_distribution(self, x: torch.Tensor, sizes: torch.Tensor,
sample_rates: torch.Tensor) -> tp.Optional[torch.Tensor]:
"""Get model output given provided input tensor.
Args:
x (torch.Tensor): Input audio tensor of shape [B, C, T].
sizes (torch.Tensor): Actual audio sample length, of shape [B].
sample_rates (torch.Tensor): Actual audio sample rate, of shape [B].
Returns:
probs (torch.Tensor, optional): Probabilities over labels, of shape [B, num_classes].
"""
all_probs: tp.List[torch.Tensor] = []
for i, wav in enumerate(x):
sample_rate = int(sample_rates[i].item())
wav_len = int(sizes[i].item())
wav_segments = self._process_audio(wav, sample_rate, wav_len)
for segment in wav_segments:
probs = self._get_model_preds(segment).mean(dim=0)
all_probs.append(probs)
if len(all_probs) > 0:
return torch.stack(all_probs, dim=0)
else:
return None
================================================
FILE: audiocraft/metrics/rvm.py
================================================
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import typing as tp
import torch
from torch import nn
import torchaudio
def db_to_scale(volume: tp.Union[float, torch.Tensor]):
    """Convert a volume in dB to a linear scale factor, computed as `10 ** (volume / 20)`."""
    exponent = volume / 20
    return 10 ** exponent
def scale_to_db(scale: torch.Tensor, min_volume: float = -120):
    """Convert a linear scale tensor to dB, after clamping it from below at the scale of `min_volume`."""
    floor = db_to_scale(min_volume)
    clamped = torch.clamp(scale, min=floor)
    return 20 * torch.log10(clamped)
class RelativeVolumeMel(nn.Module):
"""Relative volume melspectrogram measure.
Computes a measure of distance over two mel spectrogram that is interpretable in terms
of decibels. Given `x_ref` and `x_est` two waveforms of shape `[*, T]`, it will
first renormalize both by the ground truth of `x_ref`.
..Warning:: This class returns the volume of the distortion at the spectrogram level,
e.g. low negative values reflects lower distortion levels. For a SNR (like reported
in the MultiBandDiffusion paper), just take `-rvm`.
Then it computes the mel spectrogram `z_ref` and `z_est` and compute volume of the difference
relative to the volume of `z_ref` for each time-frequency bin. It further adds some limits, e.g.
clamping the values between -25 and 25 dB (controlled by `min_relative_volume` and `max_relative_volume`)
with the goal of avoiding the loss being dominated by parts where the reference is almost silent.
Indeed, volumes in dB can take unbounded values both towards -oo and +oo, which can make the final
average metric harder to interpret. Besides, anything below -30 dB of attenuation would sound extremely
good (for a neural network output, although sound engineers typically aim for much lower attenuations).
Similarly, anything above +30 dB would just be completely missing the target, and there is no point
in measuring by exactly how much it missed it. -25, 25 is a more conservative range, but also more
in line with what neural nets currently can achieve.
For instance, a Relative Volume Mel (RVM) score of -10 dB means that on average, the delta between
the target and reference mel-spec is 10 dB lower than the reference mel-spec value.
The metric can be aggregated over a given frequency band in order to have different insights for
different regions of the spectrum. `num_aggregated_bands` controls the number of bands.
..Warning:: While this function is optimized for interpretability, nothing was done to ensure it
is numerically stable when computing its gradient. We thus advise against using it as a training loss.
Args:
sample_rate (int): Sample rate of the input audio.
n_mels (int): Number of mel bands to use.
n_fft (int): Number of frequency bins for the STFT.
hop_length (int): Hop length of the STFT and the mel-spectrogram.
min_relative_volume (float): The error `z_ref - z_est` volume is given relative to
the volume of `z_ref`. If error is smaller than -25 dB of `z_ref`, then it is clamped.
max_relative_volume (float): Same as `min_relative_volume` but clamping if the error is larger than that.
max_initial_gain (float): When rescaling the audio at the very beginning, we will limit the gain
to that amount, to avoid rescaling near silence. Given in dB.
min_activity_volume (float): When computing the reference level from `z_ref`, will clamp low volume
bins to that amount. This is effectively our "zero" level for the reference mel-spectrogram,
and anything below that will be considered equally.
num_aggregated_bands (int): Number of bands to keep when computing the average RVM value.
For instance, a value of 3 would give 3 scores, roughly for low, mid and high freqs.
"""
def __init__(self, sample_rate: int = 24000, n_mels: int = 80, n_fft: int = 512,
             hop_length: int = 128, min_relative_volume: float = -25,
             max_relative_volume: float = 25, max_initial_gain: float = 25,
             min_activity_volume: float = -25,
             num_aggregated_bands: int = 4) -> None:
    """Store the metric settings and build the mel-spectrogram transform.

    The meaning of every argument is given in the class docstring.
    """
    super().__init__()
    # Settings read back by `forward` (all volumes are expressed in dB).
    self.num_aggregated_bands = num_aggregated_bands
    self.min_activity_volume = min_activity_volume
    self.max_initial_gain = max_initial_gain
    self.max_relative_volume = max_relative_volume
    self.min_relative_volume = min_relative_volume
    # Power (power=2) mel-spectrogram; `forward` takes its square root afterwards.
    mel_settings = dict(
        sample_rate=sample_rate,
        n_fft=n_fft,
        hop_length=hop_length,
        n_mels=n_mels,
        normalized=True,
        power=2,
    )
    self.melspec = torchaudio.transforms.MelSpectrogram(**mel_settings)
def forward(self, estimate: torch.Tensor, ground_truth: torch.Tensor) -> tp.Dict[str, torch.Tensor]:
    """Compute RVM metric between estimate and reference samples.

    Args:
        estimate (torch.Tensor): Estimate sample.
        ground_truth (torch.Tensor): Reference sample.

    Returns:
        dict[str, torch.Tensor]: Metrics with keys `rvm` for the overall average, and `rvm_{k}`
            for the RVM over the k-th band (k=0..num_aggregated_bands - 1).
    """
    # Both signals are divided by the same factor: the RMS of the reference,
    # floored so that the applied gain never exceeds `max_initial_gain` dB.
    gain_floor = db_to_scale(-self.max_initial_gain)
    reference_rms = ground_truth.pow(2).mean().sqrt().clamp(min=gain_floor)

    mel_ref = self.melspec(ground_truth / reference_rms).sqrt()
    mel_est = self.melspec(estimate / reference_rms).sqrt()

    # Error volume expressed relative to the (floored) reference volume, then clamped.
    error_db = scale_to_db((mel_ref - mel_est).abs(), min_volume=-120)
    reference_db = scale_to_db(mel_ref, self.min_activity_volume)
    relative_db = (error_db - reference_db).clamp(self.min_relative_volume, self.max_relative_volume)

    # Average over every axis except the second to last one, which is kept as the band axis.
    n_dims = relative_db.dim()
    band_axis = n_dims - 2
    reduced_axes = [axis for axis in range(n_dims) if axis != band_axis]
    losses_per_band = relative_db.mean(dim=reduced_axes)

    metrics = {}
    band_groups = losses_per_band.chunk(self.num_aggregated_bands, dim=0)
    for index, group in enumerate(band_groups):
        metrics[f'rvm_{index}'] = group.mean()
    metrics['rvm'] = losses_per_band.mean()
    return metrics
================================================
FILE: audiocraft/metrics/visqol.py
================================================
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import csv
import json
import logging
from pathlib import Path
import tempfile
import typing as tp
import subprocess
import shutil
import torch
import torchaudio
logger = logging.getLogger(__name__)
class ViSQOL:
    """ViSQOL wrapper to run ViSQOL from Python using a pre-installed binary.

    To learn more about ViSQOL and how to build ViSQOL binary using bazel, please refer to the
    instructions available in the open source repository: https://github.com/google/visqol

    ViSQOL is capable of running in two modes:

    Audio Mode:
        When running in audio mode, input signals must have a 48kHz sample rate. Input should be resampled to 48kHz.
        Input signals can be multi-channel, but they will be down-mixed to mono for performing the comparison.
        Audio mode uses support vector regression, with the maximum range at ~4.75.

    Speech Mode:
        When running in speech mode, ViSQOL uses a wideband model. It therefore expects input sample rates of 16kHz.
        Input should be resampled to 16kHz.
        As part of the speech mode processing, a root mean square implementation for voice activity detection
        is performed on the reference signal to determine what parts of the signal have voice activity and
        should therefore be included in the comparison. The signal is normalized before performing the voice
        activity detection.
        Input signals can be multi-channel, but they will be down-mixed to mono for performing the comparison.
        Speech mode is scaled to have a maximum MOS of 5.0 to match previous version behavior.

    For more details, check the guidelines: https://github.com/google/visqol#general-guidelines-for-input

    Args:
        bin (Path or str): Path to the ViSQOL install directory. The executable is looked up at
            `<bin>/bazel-bin/visqol` and the models under `<bin>/model/`.
        mode (str): ViSQOL computation mode, expecting "audio" or "speech".
        model (str): Name of the model to use for similarity to quality model.
        debug (bool): Whether to also get debug metrics from ViSQOL or not.
    """
    SAMPLE_RATES_MODES = {"audio": 48_000, "speech": 16_000}
    ALLOWED_SAMPLE_RATES = frozenset(SAMPLE_RATES_MODES.values())

    def __init__(self, bin: tp.Union[Path, str], mode: str = "audio",
                 model: str = "libsvm_nu_svr_model.txt", debug: bool = False):
        assert bin is not None and Path(bin).exists(), f"Could not find ViSQOL binary in specified path: {bin}"
        self.visqol_bin = str(bin)
        self.visqol_mode = mode
        self.target_sr = self._get_target_sr(self.visqol_mode)
        self.model = model
        self.debug = debug
        assert Path(self.visqol_model).exists(), \
            f"Could not find the specified model in ViSQOL install: {self.visqol_model}"

    def _get_target_sr(self, mode: str) -> int:
        """Return the sample rate expected by ViSQOL for `mode`.

        Raises:
            ValueError: If `mode` is not one of the keys of `SAMPLE_RATES_MODES`.
        """
        if mode not in ViSQOL.SAMPLE_RATES_MODES:
            raise ValueError(
                f"Unsupported mode! Allowed are: {', '.join(ViSQOL.SAMPLE_RATES_MODES.keys())}"
            )
        return ViSQOL.SAMPLE_RATES_MODES[mode]

    def _prepare_files(
        self, ref_sig: torch.Tensor, deg_sig: torch.Tensor, sr: int, target_sr: int, pad_with_silence: bool = False
    ) -> tp.Tuple[Path, tp.Optional[Path], tp.Optional[Path], tp.Optional[Path]]:
        """Write the audio pairs and the batch CSV consumed by ViSQOL into a fresh temporary folder.

        Args:
            ref_sig (torch.Tensor): Reference signals, indexed along the first dimension.
            deg_sig (torch.Tensor): Degraded signals, same number of items as `ref_sig`.
            sr (int): Sample rate of the provided signals.
            target_sr (int): Sample rate to write the files with (must be an allowed ViSQOL rate).
            pad_with_silence (bool): Whether to add 0.5 second of silence on both sides of each signal.

        Returns:
            tuple: `(tmp_dir, input_csv, results_csv, debug_json)`. The three file paths are `None`
                if writing the files failed; `tmp_dir` is always returned so the caller can delete it.
        """
        assert target_sr in ViSQOL.ALLOWED_SAMPLE_RATES
        assert len(ref_sig) == len(deg_sig), (
            "Expects same number of ref and degraded inputs"
            f" but ref len {len(ref_sig)} != deg len {len(deg_sig)}"
        )
        num_pairs = len(ref_sig)

        # Resample audio if needed.
        if sr != target_sr:
            transform = torchaudio.transforms.Resample(sr, target_sr)
            ref_sig = [transform(ref_sig[i]) for i in range(num_pairs)]
            deg_sig = [transform(deg_sig[i]) for i in range(num_pairs)]

        # Padding is applied whether or not resampling was required: the caller asked for it
        # and the ViSQOL guidelines do not tie it to the input sample rate.
        if pad_with_silence:
            pad = int(0.5 * target_sr)
            ref_sig = [
                torch.nn.functional.pad(ref_sig[i], (pad, pad), mode='constant', value=0)
                for i in range(num_pairs)
            ]
            deg_sig = [
                torch.nn.functional.pad(deg_sig[i], (pad, pad), mode='constant', value=0)
                for i in range(num_pairs)
            ]

        # Save audio chunks to tmp dir and create csv.
        tmp_dir = Path(tempfile.mkdtemp())
        try:
            tmp_input_csv_path = tmp_dir / "input.csv"
            tmp_results_csv_path = tmp_dir / "results.csv"
            tmp_debug_json_path = tmp_dir / "debug.json"
            # newline="" is what the csv module expects for files it writes.
            with open(tmp_input_csv_path, "w", newline="") as csv_file:
                csv_writer = csv.writer(csv_file)
                csv_writer.writerow(["reference", "degraded"])
                for i in range(num_pairs):
                    tmp_ref_filename = tmp_dir / f"ref_{i}.wav"
                    tmp_deg_filename = tmp_dir / f"deg_{i}.wav"
                    # Clamp slightly inside [-1, 1] before the 16 bits PCM conversion.
                    torchaudio.save(
                        tmp_ref_filename,
                        torch.clamp(ref_sig[i], min=-0.99, max=0.99),
                        sample_rate=target_sr,
                        bits_per_sample=16,
                        encoding="PCM_S"
                    )
                    torchaudio.save(
                        tmp_deg_filename,
                        torch.clamp(deg_sig[i], min=-0.99, max=0.99),
                        sample_rate=target_sr,
                        bits_per_sample=16,
                        encoding="PCM_S"
                    )
                    csv_writer.writerow([str(tmp_ref_filename), str(tmp_deg_filename)])
            return tmp_dir, tmp_input_csv_path, tmp_results_csv_path, tmp_debug_json_path
        except Exception as e:
            # Best effort: report the failure through `None` paths, the caller handles it.
            logger.error("Exception occurred when preparing files for ViSQOL: %s", e)
            return tmp_dir, None, None, None

    def _flush_files(self, tmp_dir: tp.Union[Path, str]):
        """Delete the temporary folder used to compute ViSQOL."""
        shutil.rmtree(str(tmp_dir))

    def _collect_moslqo_score(self, results_csv_path: tp.Union[Path, str]) -> float:
        """Return the average of the `moslqo` column of the results CSV (0.0 if it has no row)."""
        with open(results_csv_path, "r") as csv_file:
            reader = csv.DictReader(csv_file)
            moslqo_scores = [float(row["moslqo"]) for row in reader]
        if not moslqo_scores:
            return 0.0
        return sum(moslqo_scores) / len(moslqo_scores)

    def _collect_debug_data(self, debug_json_path: tp.Union[Path, str]) -> dict:
        """Load the debug JSON file written by ViSQOL."""
        with open(debug_json_path, "r") as f:
            return json.load(f)

    @property
    def visqol_model(self):
        """Path of the similarity to quality model inside the ViSQOL install."""
        return f'{self.visqol_bin}/model/{self.model}'

    def _run_visqol(
        self,
        input_csv_path: tp.Union[Path, str],
        results_csv_path: tp.Union[Path, str],
        debug_csv_path: tp.Optional[tp.Union[Path, str]],
    ):
        """Run the ViSQOL executable in batch mode.

        Args:
            input_csv_path (Path or str): CSV listing the (reference, degraded) file pairs.
            results_csv_path (Path or str): Where ViSQOL should write its results CSV.
            debug_csv_path (Path or str, optional): Where to write debug output. When None,
                no debug output is requested.

        Raises:
            RuntimeError: If the executable exits with a non-zero return code.
        """
        cmd = [
            f'{self.visqol_bin}/bazel-bin/visqol',
            '--batch_input_csv', str(input_csv_path),
            '--results_csv', str(results_csv_path),
        ]
        # Test for None BEFORE converting to a string, otherwise None becomes the
        # literal "None" and the debug output would always be requested.
        if debug_csv_path is not None:
            cmd += ['--output_debug', str(debug_csv_path)]
        if self.visqol_mode == "speech":
            cmd += ['--use_speech_mode']
        cmd += ['--similarity_to_quality_model', f'{self.visqol_model}']
        result = subprocess.run(cmd, capture_output=True)
        if result.returncode:
            logger.error("Error with visqol: \n %s \n %s", result.stdout.decode(), result.stderr.decode())
            raise RuntimeError("Error while executing visqol")

    def __call__(
        self,
        ref_sig: torch.Tensor,
        deg_sig: torch.Tensor,
        sr: int,
        pad_with_silence: bool = False,
    ) -> tp.Optional[float]:
        """Calculate the ViSQOL metric for a pair of audio signals at a given sample rate.

        Args:
            ref_sig (torch.Tensor): Reference signals as [B, C, T].
            deg_sig (torch.Tensor): Degraded signals as [B, C, T].
            sr (int): Sample rate of the two audio signals.
            pad_with_silence (bool): Whether to pad the file with silences as recommended
                in visqol guidelines (see: https://github.com/google/visqol#general-guidelines-for-input).

        Returns:
            float: The ViSQOL score or mean score for the batch. `None` is returned when
                preparing the files or running ViSQOL failed; the error is logged, not raised.
        """
        logger.debug("Calculating visqol with mode=%s on %d samples", self.visqol_mode, len(ref_sig))
        tmp_dir, input_csv, results_csv, debug_json = self._prepare_files(
            ref_sig, deg_sig, sr, self.target_sr, pad_with_silence
        )
        try:
            if not (input_csv and results_csv):
                raise RuntimeError("Something unexpected happened when running VISQOL!")
            self._run_visqol(
                input_csv,
                results_csv,
                debug_json if self.debug else None,
            )
            return self._collect_moslqo_score(results_csv)
        except Exception as e:
            # Deliberate best effort: a failing metric must not interrupt the caller.
            logger.error("Exception occurred when running ViSQOL: %s", e)
            return None
        finally:
            self._flush_files(tmp_dir)
================================================
FILE: audiocraft/models/__init__.py
================================================
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
"""
Models for EnCodec, AudioGen, MusicGen, as well as the generic LMModel.
"""
# flake8: noqa
from . import builders, loaders
from .encodec import (
CompressionModel, EncodecModel, DAC,
HFEncodecModel, HFEncodecCompressionModel)
from .audiogen import AudioGen
from .lm import LMModel
from .multibanddiffusion import MultiBandDiffusion
from .musicgen import MusicGen
from .unet import DiffusionUnet
================================================
FILE: audiocraft/models/audiogen.py
================================================
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
"""
Main model for using AudioGen. This will combine all the required components
and provide easy access to the generation API.
"""
import typing as tp
import torch
from .encodec import CompressionModel
from .lm import LMModel
from .builders import get_debug_compression_model, get_debug_lm_model
from .loaders import load_compression_model, load_lm_model
from ..data.audio_utils import convert_audio
from ..modules.conditioners import ConditioningAttributes
from ..utils.autocast import TorchAutocast
class AudioGen:
    """AudioGen main model with convenient generation API.

    Args:
        name (str): name of the model.
        compression_model (CompressionModel): Compression model
            used to map audio to invertible discrete representations.
        lm (LMModel): Language model over discrete representations.
        max_duration (float, optional): maximum duration the model can produce,
            otherwise, inferred from the training params.
    """
    def __init__(self, name: str, compression_model: CompressionModel, lm: LMModel,
                 max_duration: tp.Optional[float] = None):
        self.name = name
        self.compression_model = compression_model
        self.lm = lm
        if max_duration is None:
            if hasattr(lm, 'cfg'):
                max_duration = lm.cfg.dataset.segment_duration  # type: ignore
            else:
                raise ValueError("You must provide max_duration when building directly AudioGen")
        assert max_duration is not None
        self.max_duration: float = max_duration
        self.device = next(iter(lm.parameters())).device
        self.generation_params: dict = {}
        self.set_generation_params(duration=5)  # 5 seconds by default
        self._progress_callback: tp.Optional[tp.Callable[[int, int], None]] = None
        self.autocast = self._build_autocast(self.device)

    @staticmethod
    def _build_autocast(device: torch.device) -> TorchAutocast:
        """Return the autocast context for `device`: disabled on CPU, float16 otherwise."""
        if device.type == 'cpu':
            return TorchAutocast(enabled=False)
        return TorchAutocast(enabled=True, device_type=device.type, dtype=torch.float16)

    @staticmethod
    def _description_schedule(description: tp.Any) -> list:
        """Normalize one `description` entry into a non-empty list with one prompt per segment.

        `None` and plain strings become a one-element list, so that they are used as-is
        for every generated segment; an empty sequence is treated like `None`.
        """
        if description is None or isinstance(description, str):
            return [description]
        schedule = list(description)
        return schedule if schedule else [None]

    @property
    def frame_rate(self) -> float:
        """Roughly the number of AR steps per seconds."""
        return self.compression_model.frame_rate

    @property
    def sample_rate(self) -> int:
        """Sample rate of the generated audio."""
        return self.compression_model.sample_rate

    @property
    def audio_channels(self) -> int:
        """Audio channels of the generated audio."""
        return self.compression_model.channels

    @staticmethod
    def get_pretrained(name: str = 'facebook/audiogen-medium', device=None):
        """Return pretrained model, we provide a single model for now:
        - facebook/audiogen-medium (1.5B), text to sound,
          # see: https://huggingface.co/facebook/audiogen-medium
        """
        if device is None:
            device = 'cuda' if torch.cuda.device_count() else 'cpu'

        if name == 'debug':
            # used only for unit tests
            compression_model = get_debug_compression_model(device, sample_rate=16000)
            lm = get_debug_lm_model(device)
            return AudioGen(name, compression_model, lm, max_duration=10)

        compression_model = load_compression_model(name, device=device)
        lm = load_lm_model(name, device=device)
        assert 'self_wav' not in lm.condition_provider.conditioners, \
            "AudioGen do not support waveform conditioning for now"
        return AudioGen(name, compression_model, lm)

    def set_generation_params(self, use_sampling: bool = True, top_k: int = 250,
                              top_p: float = 0.0, temperature: float = 1.0,
                              duration: float = 10.0, cfg_coef: float = 3.0,
                              two_step_cfg: bool = False, extend_stride: float = 2):
        """Set the generation parameters for AudioGen.

        Args:
            use_sampling (bool, optional): Use sampling if True, else do argmax decoding. Defaults to True.
            top_k (int, optional): top_k used for sampling. Defaults to 250.
            top_p (float, optional): top_p used for sampling, when set to 0 top_k is used. Defaults to 0.0.
            temperature (float, optional): Softmax temperature parameter. Defaults to 1.0.
            duration (float, optional): Duration of the generated waveform. Defaults to 10.0.
            cfg_coef (float, optional): Coefficient used for classifier free guidance. Defaults to 3.0.
            two_step_cfg (bool, optional): If True, performs 2 forward for Classifier Free Guidance,
                instead of batching together the two. This has some impact on how things
                are padded but seems to have little impact in practice.
            extend_stride: when doing extended generation (i.e. more than 10 seconds), by how much
                should we extend the audio each time. Larger values will mean less context is
                preserved, and shorter value will require extra computations.
        """
        assert extend_stride < self.max_duration, "Cannot stride by more than max generation duration."
        self.extend_stride = extend_stride
        self.duration = duration
        self.generation_params = {
            'use_sampling': use_sampling,
            'temp': temperature,
            'top_k': top_k,
            'top_p': top_p,
            'cfg_coef': cfg_coef,
            'two_step_cfg': two_step_cfg,
        }

    def set_custom_progress_callback(self, progress_callback: tp.Optional[tp.Callable[[int, int], None]] = None):
        """Override the default progress callback."""
        self._progress_callback = progress_callback

    def generate(self, descriptions: tp.List[str], progress: bool = False) -> torch.Tensor:
        """Generate samples conditioned on text.

        Args:
            descriptions (list of str): A list of strings used as text conditioning. An entry may
                also be a list of strings, in which case one string is used per generated segment
                when the requested duration exceeds `max_duration`.
            progress (bool, optional): Flag to display progress of the generation process. Defaults to False.
        """
        attributes, prompt_tokens = self._prepare_tokens_and_attributes(descriptions, None)
        assert prompt_tokens is None
        return self._generate_tokens(attributes, prompt_tokens, progress)

    def generate_continuation(self, prompt: torch.Tensor, prompt_sample_rate: int,
                              descriptions: tp.Optional[tp.List[tp.Optional[str]]] = None,
                              progress: bool = False) -> torch.Tensor:
        """Generate samples conditioned on audio prompts.

        Args:
            prompt (torch.Tensor): A batch of waveforms used for continuation.
                Prompt should be [B, C, T], or [C, T] if only one sample is generated.
            prompt_sample_rate (int): Sampling rate of the given audio waveforms.
            descriptions (list of str, optional): A list of strings used as text conditioning. Defaults to None.
            progress (bool, optional): Flag to display progress of the generation process. Defaults to False.
        """
        if prompt.dim() == 2:
            prompt = prompt[None]
        if prompt.dim() != 3:
            raise ValueError("prompt should have 3 dimensions: [B, C, T] (C = 1).")
        prompt = convert_audio(prompt, prompt_sample_rate, self.sample_rate, self.audio_channels)
        if descriptions is None:
            descriptions = [None] * len(prompt)
        attributes, prompt_tokens = self._prepare_tokens_and_attributes(descriptions, prompt)
        assert prompt_tokens is not None
        return self._generate_tokens(attributes, prompt_tokens, progress)

    @torch.no_grad()
    def _prepare_tokens_and_attributes(
            self,
            descriptions: tp.Sequence[tp.Optional[str]],
            prompt: tp.Optional[torch.Tensor],
    ) -> tp.Tuple[tp.List[ConditioningAttributes], tp.Optional[torch.Tensor]]:
        """Prepare model inputs.

        Args:
            descriptions (list of str): A list of strings used as text conditioning.
            prompt (torch.Tensor): A batch of waveforms used for continuation.

        Returns:
            tuple: The conditioning attributes (one per description) and the prompt tokens
                obtained from the compression model, or None when no prompt is given.
        """
        attributes = [
            ConditioningAttributes(text={'description': description})
            for description in descriptions]

        if prompt is None:
            return attributes, None

        if descriptions is not None:
            assert len(descriptions) == len(prompt), "Prompt and nb. descriptions doesn't match"
        prompt = prompt.to(self.device)
        prompt_tokens, scale = self.compression_model.encode(prompt)
        assert scale is None
        return attributes, prompt_tokens

    def _generate_tokens(self, attributes: tp.List[ConditioningAttributes],
                         prompt_tokens: tp.Optional[torch.Tensor], progress: bool = False) -> torch.Tensor:
        """Generate audio given audio prompt and/or conditions.

        Tokens are sampled from the LM, in several overlapping segments when the requested
        duration exceeds `max_duration`, then decoded with the compression model.

        Args:
            attributes (list of ConditioningAttributes): Conditions used for generation (here text).
                The `description` of an attribute may be a list of prompts: the i-th generated
                segment then uses the i-th prompt (the last one is repeated once the list is
                exhausted). The attributes are updated in place with the prompt in use.
            prompt_tokens (torch.Tensor, optional): Audio prompt used for continuation.
            progress (bool, optional): Flag to display progress of the generation process. Defaults to False.
        Returns:
            torch.Tensor: Generated audio, of shape [B, C, T], T is defined by the generation params.
        """
        # One prompt schedule per batch item. Normalizing here makes `None` (continuation
        # without text) and plain strings valid, instead of failing or being indexed
        # character by character.
        schedules = [
            (attr, self._description_schedule(attr.text['description']))
            for attr in attributes if 'description' in attr.text
        ]

        def _select_descriptions(segment_index: int) -> None:
            for attr, schedule in schedules:
                attr.text['description'] = schedule[min(segment_index, len(schedule) - 1)]

        total_gen_len = int(self.duration * self.frame_rate)
        max_prompt_len = int(min(self.duration, self.max_duration) * self.frame_rate)
        current_gen_offset: int = 0

        def _progress_callback(generated_tokens: int, tokens_to_generate: int):
            generated_tokens += current_gen_offset
            if self._progress_callback is not None:
                # Note that total_gen_len might be quite wrong depending on the
                # codebook pattern used, but with delay it is almost accurate.
                self._progress_callback(generated_tokens, total_gen_len)
            else:
                print(f'{generated_tokens: 6d} / {total_gen_len: 6d}', end='\r')

        if prompt_tokens is not None:
            assert max_prompt_len >= prompt_tokens.shape[-1], \
                "Prompt is longer than audio to generate"

        callback = None
        if progress:
            callback = _progress_callback

        if self.duration <= self.max_duration:
            # generate by sampling from LM, simple case.
            _select_descriptions(0)
            with self.autocast:
                gen_tokens = self.lm.generate(
                    prompt_tokens, attributes,
                    callback=callback, max_gen_len=total_gen_len, **self.generation_params)
        else:
            all_tokens = []
            if prompt_tokens is None:
                prompt_length = 0
            else:
                all_tokens.append(prompt_tokens)
                prompt_length = prompt_tokens.shape[-1]

            stride_tokens = int(self.frame_rate * self.extend_stride)
            segment_index = 0
            while current_gen_offset + prompt_length < total_gen_len:
                time_offset = current_gen_offset / self.frame_rate
                chunk_duration = min(self.duration - time_offset, self.max_duration)
                max_gen_len = int(chunk_duration * self.frame_rate)
                _select_descriptions(segment_index)
                segment_index += 1
                with self.autocast:
                    gen_tokens = self.lm.generate(
                        prompt_tokens, attributes,
                        callback=callback, max_gen_len=max_gen_len, **self.generation_params)
                if prompt_tokens is None:
                    all_tokens.append(gen_tokens)
                else:
                    # Only keep what was generated after the prompt.
                    all_tokens.append(gen_tokens[:, :, prompt_tokens.shape[-1]:])
                # Everything after the stride becomes the prompt of the next segment.
                prompt_tokens = gen_tokens[:, :, stride_tokens:]
                prompt_length = prompt_tokens.shape[-1]
                current_gen_offset += stride_tokens

            gen_tokens = torch.cat(all_tokens, dim=-1)

        # generate audio
        assert gen_tokens.dim() == 3
        with torch.no_grad():
            gen_audio = self.compression_model.decode(gen_tokens, None)
        return gen_audio

    def to(self, device: str):
        """Move the compression model and the LM to `device` and return `self`.

        `self.device` and the autocast context are refreshed so that prompts are sent
        to the device the models now live on.
        """
        self.compression_model.to(device)
        self.lm.to(device)
        self.device = next(iter(self.lm.parameters())).device
        self.autocast = self._build_autocast(self.device)
        return self
================================================
FILE: audiocraft/models/builders.py
================================================
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
"""
All the functions to build the relevant models and modules
from the Hydra config.
"""
import typing as tp
import audiocraft
import omegaconf
import torch
from .encodec import CompressionModel, EncodecModel
from .lm import LMModel
from ..modules.codebooks_patterns import (
CodebooksPatternProvider,
DelayedPatternProvider,
MusicLMPattern,
ParallelPatternProvider,
UnrolledPatternProvider,
VALLEPattern,
)
from ..modules.conditioners import (
BaseConditioner,
ChromaStemConditioner,
CLAPEmbeddingConditioner,
ConditionFuser,
ConditioningProvider,
LUTConditioner,
T5Conditioner,
)
from .unet import DiffusionUnet
from .. import quantization as qt
from ..utils.utils import dict_from_config
from ..modules.diffusion_schedule import MultiBandProcessor, SampleProcessor
def get_quantizer(quantizer: str, cfg: omegaconf.DictConfig, dimension: int) -> qt.BaseQuantizer:
    """Instantiate the quantizer called `quantizer` from the config section of the same name.

    Args:
        quantizer (str): Either 'no_quant' or 'rvq'.
        cfg (omegaconf.DictConfig): Config holding a section named after `quantizer`.
        dimension (int): Passed as `dimension` to every quantizer except 'no_quant'.

    Raises:
        KeyError: If `quantizer` is not a known quantizer name.
    """
    quantizer_classes = {
        'no_quant': qt.DummyQuantizer,
        'rvq': qt.ResidualVectorQuantizer,
    }
    quantizer_cls = quantizer_classes[quantizer]
    params = dict_from_config(getattr(cfg, quantizer))
    if quantizer == 'no_quant':
        # the dummy quantizer does not receive the dimension.
        return quantizer_cls(**params)
    params['dimension'] = dimension
    return quantizer_cls(**params)
def get_encodec_autoencoder(encoder_name: str, cfg: omegaconf.DictConfig):
    """Instantiate the encoder and decoder of an EnCodec model.

    Args:
        encoder_name (str): Name of the autoencoder architecture, only 'seanet' is supported.
        cfg (omegaconf.DictConfig): Config with a `seanet` section. Its `encoder` and `decoder`
            sub-sections override the shared parameters for the encoder and decoder respectively.

    Returns:
        tuple: The `(encoder, decoder)` modules.

    Raises:
        KeyError: If `encoder_name` is not supported.
    """
    if encoder_name != 'seanet':
        # Report the name that was actually rejected rather than `cfg.compression_model`.
        raise KeyError(f"Unexpected autoencoder {encoder_name}")
    kwargs = dict_from_config(getattr(cfg, 'seanet'))
    encoder_override_kwargs = kwargs.pop('encoder')
    decoder_override_kwargs = kwargs.pop('decoder')
    encoder_kwargs = {**kwargs, **encoder_override_kwargs}
    decoder_kwargs = {**kwargs, **decoder_override_kwargs}
    encoder = audiocraft.modules.SEANetEncoder(**encoder_kwargs)
    decoder = audiocraft.modules.SEANetDecoder(**decoder_kwargs)
    return encoder, decoder
def get_compression_model(cfg: omegaconf.DictConfig) -> CompressionModel:
    """Instantiate a compression model.

    Only `cfg.compression_model == 'encodec'` is supported; the model is built from the
    `encodec` config section and moved to `cfg.device`.

    Raises:
        KeyError: For any other value of `cfg.compression_model`.
    """
    if cfg.compression_model != 'encodec':
        raise KeyError(f"Unexpected compression model {cfg.compression_model}")

    encodec_kwargs = dict_from_config(getattr(cfg, 'encodec'))
    autoencoder_name = encodec_kwargs.pop('autoencoder')
    quantizer_name = encodec_kwargs.pop('quantizer')
    encoder, decoder = get_encodec_autoencoder(autoencoder_name, cfg)
    quantizer = get_quantizer(quantizer_name, cfg, encoder.dimension)
    # `sample_rate` is only read here, it is still forwarded to the model below.
    frame_rate = encodec_kwargs['sample_rate'] // encoder.hop_length
    renormalize = encodec_kwargs.pop('renormalize', False)
    # deprecated params
    encodec_kwargs.pop('renorm', None)
    model = EncodecModel(
        encoder, decoder, quantizer,
        frame_rate=frame_rate, renormalize=renormalize, **encodec_kwargs)
    return model.to(cfg.device)
def get_lm_model(cfg: omegaconf.DictConfig) -> LMModel:
    """Instantiate a transformer LM.

    Only `cfg.lm_model == 'transformer_lm'` is supported; the model is built from the
    `transformer_lm` config section together with its condition provider, condition fuser
    and codebooks pattern provider, then moved to `cfg.device`.

    Raises:
        KeyError: For any other value of `cfg.lm_model`.
    """
    if cfg.lm_model != 'transformer_lm':
        raise KeyError(f"Unexpected LM model {cfg.lm_model}")

    lm_kwargs = dict_from_config(getattr(cfg, 'transformer_lm'))
    n_q = lm_kwargs['n_q']
    q_modeling = lm_kwargs.pop('q_modeling', None)
    pattern_cfg = getattr(cfg, 'codebooks_pattern')
    attribute_dropout = dict_from_config(getattr(cfg, 'attribute_dropout'))
    guidance_cfg = dict_from_config(getattr(cfg, 'classifier_free_guidance'))
    cfg_prob = guidance_cfg['training_dropout']
    cfg_coef = guidance_cfg['inference_coef']
    fuser = get_condition_fuser(cfg)
    condition_provider = get_conditioner_provider(lm_kwargs["dim"], cfg).to(cfg.device)

    if len(fuser.fuse2cond['cross']) > 0:  # enforce cross-att programmatically
        lm_kwargs['cross_attention'] = True

    if pattern_cfg.modeling is None:
        # No codebook pattern configured: fall back to `transformer_lm.q_modeling`
        # with one delay per codebook.
        assert q_modeling is not None, \
            "LM model should either have a codebook pattern defined or transformer_lm.q_modeling"
        fallback_pattern = {'modeling': q_modeling, 'delay': {'delays': list(range(n_q))}}
        pattern_cfg = omegaconf.OmegaConf.create(fallback_pattern)

    pattern_provider = get_codebooks_pattern_provider(n_q, pattern_cfg)
    lm = LMModel(
        pattern_provider=pattern_provider,
        condition_provider=condition_provider,
        fuser=fuser,
        cfg_dropout=cfg_prob,
        cfg_coef=cfg_coef,
        attribute_dropout=attribute_dropout,
        dtype=getattr(torch, cfg.dtype),
        device=cfg.device,
        **lm_kwargs
    )
    return lm.to(cfg.device)
def get_conditioner_provider(output_dim: int, cfg: omegaconf.DictConfig) -> ConditioningProvider:
    """Instantiate a conditioning model.

    Args:
        output_dim (int): Output dimension given to every conditioner.
        cfg (omegaconf.DictConfig): Full config; `device`, `dataset.segment_duration`
            and the `conditioners` section are read from it.

    Raises:
        ValueError: If a conditioner declares an unknown `model` type.
    """
    device = cfg.device
    duration = cfg.dataset.segment_duration
    conditioners_cfg = getattr(cfg, 'conditioners')
    if conditioners_cfg is None:
        dict_cfg = {}
    else:
        dict_cfg = dict_from_config(conditioners_cfg)

    # `args` holds the provider arguments, every other entry describes one conditioner.
    provider_args = dict_cfg.pop('args', {})
    for dropped_key in ('merge_text_conditions_p', 'drop_desc_p'):
        # these two entries are not forwarded to the provider.
        provider_args.pop(dropped_key, None)

    conditioners: tp.Dict[str, BaseConditioner] = {}
    for cond, cond_cfg in dict_cfg.items():
        model_type = cond_cfg['model']
        model_args = cond_cfg[model_type]
        if model_type == 't5':
            built = T5Conditioner(output_dim=output_dim, device=device, **model_args)
        elif model_type == 'lut':
            built = LUTConditioner(output_dim=output_dim, **model_args)
        elif model_type == 'chroma_stem':
            built = ChromaStemConditioner(
                output_dim=output_dim,
                duration=duration,
                device=device,
                **model_args
            )
        elif model_type == 'clap':
            built = CLAPEmbeddingConditioner(
                output_dim=output_dim,
                device=device,
                **model_args
            )
        else:
            raise ValueError(f"Unrecognized conditioning model: {model_type}")
        conditioners[str(cond)] = built

    return ConditioningProvider(conditioners, device=device, **provider_args)
def get_condition_fuser(cfg: omegaconf.DictConfig) -> ConditionFuser:
    """Instantiate a condition fuser object.

    The `fuser` config section must have one entry per fuse method ('sum', 'cross',
    'prepend', 'input_interpolate'); any other entry is forwarded as a keyword argument.
    """
    fuser_cfg = getattr(cfg, 'fuser')
    fuser_methods = ['sum', 'cross', 'prepend', 'input_interpolate']
    fuse2cond = {}
    for method in fuser_methods:
        fuse2cond[method] = fuser_cfg[method]
    extra_kwargs = {
        key: value for key, value in fuser_cfg.items()
        if key not in fuser_methods
    }
    return ConditionFuser(fuse2cond=fuse2cond, **extra_kwargs)
def get_codebooks_pattern_provider(n_q: int, cfg: omegaconf.DictConfig) -> CodebooksPatternProvider:
    """Instantiate a codebooks pattern provider object.

    Args:
        n_q (int): Number of codebooks, given as first argument to the provider.
        cfg (omegaconf.DictConfig): Pattern config. `cfg.modeling` selects the provider and the
            section named after it, when present, provides its keyword arguments.

    Raises:
        KeyError: If `cfg.modeling` is not a known pattern name.
    """
    pattern_providers = {
        'parallel': ParallelPatternProvider,
        'delay': DelayedPatternProvider,
        'unroll': UnrolledPatternProvider,
        'valle': VALLEPattern,
        'musiclm': MusicLMPattern,
    }
    name = cfg.modeling
    if hasattr(cfg, name):
        provider_kwargs = dict_from_config(cfg.get(name))
    else:
        provider_kwargs = {}
    provider_cls = pattern_providers[name]
    return provider_cls(n_q, **provider_kwargs)
def get_debug_compression_model(device='cpu', sample_rate: int = 32000):
    """Instantiate a debug compression model to be used for unit tests.

    Args:
        device: Device the model is moved to.
        sample_rate (int): Audio sample rate, either 16000 or 32000.

    Returns:
        EncodecModel: A small mono model with a 25 Hz frame rate, in eval mode.
    """
    assert sample_rate in [16000, 32000], "unsupported sample rate for debug compression model"
    model_ratios = {
        16000: [10, 8, 8],  # 25 Hz at 16kHz
        32000: [10, 8, 16]  # 25 Hz at 32kHz
    }
    ratios: tp.List[int] = model_ratios[sample_rate]
    frame_rate = 25
    seanet_kwargs: dict = {
        'n_filters': 4,
        'n_residual_layers': 1,
        'dimension': 32,
        'ratios': ratios,
    }
    # NOTE: the kwargs used to be printed here, a debugging leftover that wrote to
    # stdout every time the debug model was built.
    encoder = audiocraft.modules.SEANetEncoder(**seanet_kwargs)
    decoder = audiocraft.modules.SEANetDecoder(**seanet_kwargs)
    quantizer = qt.ResidualVectorQuantizer(dimension=32, bins=400, n_q=4)
    init_x = torch.randn(8, 32, 128)
    quantizer(init_x, 1)  # initialize kmeans etc.
    compression_model = EncodecModel(
        encoder, decoder, quantizer,
        frame_rate=frame_rate, sample_rate=sample_rate, channels=1).to(device)
    return compression_model.eval()
def get_diffusion_model(cfg: omegaconf.DictConfig):
    """Instantiate the diffusion U-Net.

    `cfg.channels` is used as the number of input channels, `cfg.schedule.num_steps` as the
    number of steps, and the `diffusion_unet` section provides the other keyword arguments.
    """
    # TODO Find a way to infer the channels from dset
    in_channels = cfg.channels
    schedule_steps = cfg.schedule.num_steps
    unet_kwargs = cfg.diffusion_unet
    return DiffusionUnet(chin=in_channels, num_steps=schedule_steps, **unet_kwargs)
def get_processor(cfg, sample_rate: int = 24000):
    """Instantiate the sample processor described by `cfg`.

    Args:
        cfg: Processor config with at least the `use` and `name` entries; every other
            entry is forwarded to the processor as a keyword argument.
        sample_rate (int): Sample rate given to the multi-band processor.

    Returns:
        A `MultiBandProcessor` when `cfg.use` is truthy and `cfg.name` is
        "multi_band_processor", a plain `SampleProcessor` otherwise.
    """
    default_processor = SampleProcessor()
    if not cfg.use:
        return default_processor

    processor_kwargs = dict(cfg)
    for reserved_key in ('use', 'name'):
        # these entries select the processor, they are not arguments of it.
        processor_kwargs.pop(reserved_key)
    if cfg.name == "multi_band_processor":
        return MultiBandProcessor(sample_rate=sample_rate, **processor_kwargs)
    return default_processor
def get_debug_lm_model(device='cpu'):
    """Instantiate a debug LM to be used for unit tests.

    The model is a 2-layer transformer LM over 4 codebooks, conditioned through
    cross-attention on a whitespace-tokenized `description`. It is returned on
    `device` and in eval mode.
    """
    dim = 16
    n_q = 4
    pattern = DelayedPatternProvider(n_q=n_q)
    description_conditioner = LUTConditioner(
        n_bins=128, dim=dim, output_dim=dim, tokenizer="whitespace")
    condition_provider = ConditioningProvider({'description': description_conditioner})
    fuse2cond = {
        'cross': ['description'],
        'prepend': [],
        'sum': [],
        'input_interpolate': [],
    }
    fuser = ConditionFuser(fuse2cond)
    lm = LMModel(
        pattern, condition_provider, fuser,
        n_q=n_q, card=400, dim=dim, num_heads=4, custom=True, num_layers=2,
        cross_attention=True, causal=True)
    lm = lm.to(device)
    return lm.eval()
def get_wrapped_compression_model(
        compression_model: CompressionModel,
        cfg: omegaconf.DictConfig) -> CompressionModel:
    """Return `compression_model` unchanged.

    `cfg` is currently unused: this is a placeholder hook, no wrapping is applied yet.
    """
    # more to come.
    return compression_model
================================================
FILE: audiocraft/models/encodec.py
================================================
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
"""Compression models or wrapper around existing models.
Also defines the main interface that a model must follow to be usable as an audio tokenizer.
"""
from abc import ABC, abstractmethod
import logging
import math
from pathlib import Path
import typing as tp
import numpy as np
import torch
from torch import nn
from transformers import EncodecModel as HFEncodecModel
from .. import quantization as qt
logger = logging.getLogger()
class CompressionModel(ABC, nn.Module):
    """Base API for all compression model that aim at being used as audio tokenizers
    with a language model.

    Subclasses must implement every abstract method and property declared below.
    """

    @abstractmethod
    def forward(self, x: torch.Tensor) -> qt.QuantizedResult:
        ...

    @abstractmethod
    def encode(self, x: torch.Tensor) -> tp.Tuple[torch.Tensor, tp.Optional[torch.Tensor]]:
        """See `EncodecModel.encode`."""
        ...

    @abstractmethod
    def decode(self, codes: torch.Tensor, scale: tp.Optional[torch.Tensor] = None):
        """See `EncodecModel.decode`."""
        ...

    @abstractmethod
    def decode_latent(self, codes: torch.Tensor):
        """Decode from the discrete codes to continuous latent space."""
        ...

    @property
    @abstractmethod
    def channels(self) -> int:
        ...

    @property
    @abstractmethod
    def frame_rate(self) -> float:
        ...

    @property
    @abstractmethod
    def sample_rate(self) -> int:
        ...

    @property
    @abstractmethod
    def cardinality(self) -> int:
        ...

    @property
    @abstractmethod
    def num_codebooks(self) -> int:
        ...

    @property
    @abstractmethod
    def total_codebooks(self) -> int:
        ...

    @abstractmethod
    def set_num_codebooks(self, n: int):
        """Set the active number of codebooks used by the quantizer."""
        ...

    @staticmethod
    def get_pretrained(
            name: str, device: tp.Union[torch.device, str] = 'cpu'
    ) -> 'CompressionModel':
        """Instantiate a CompressionModel from a given pretrained model.

        Args:
            name (Path or str): name of the pretrained model. See after.
            device (torch.device or str): Device on which the model is loaded.

        Returns:
            CompressionModel: The model, moved to `device` and set to eval mode.

        Pretrained models:
            - dac_44khz (https://github.com/descriptinc/descript-audio-codec)
            - dac_24khz (same)
            - facebook/encodec_24khz (https://huggingface.co/facebook/encodec_24khz)
            - facebook/encodec_32khz (https://huggingface.co/facebook/encodec_32khz)
            - your own model on HuggingFace. Export instructions to come...
        """
        # Imported here rather than at module level (NOTE(review): presumably to avoid
        # a circular import with `builders` / `loaders` — confirm).
        from . import builders, loaders

        dac_names = ('dac_44khz', 'dac_24khz')
        pretrained: CompressionModel
        if name in dac_names:
            # the part after the underscore, e.g. '44khz', selects the DAC variant.
            _, _, dac_variant = name.partition('_')
            logger.info("Getting pretrained compression model from DAC %s", dac_variant)
            pretrained = DAC(dac_variant)
        elif name == 'debug_compression_model':
            logger.info("Getting pretrained compression model for debug")
            pretrained = builders.get_debug_compression_model()
        elif Path(name).exists():
            # We assume here if the paths exist that it is in fact an AC checkpoint
            # that was exported using `audiocraft.utils.export` functions.
            pretrained = loaders.load_compression_model(name, device=device)
        else:
            logger.info("Getting pretrained compression model from HF %s", name)
            hf_model = HFEncodecModel.from_pretrained(name)
            pretrained = HFEncodecCompressionModel(hf_model).to(device)
        pretrained = pretrained.to(device)
        return pretrained.eval()
class EncodecModel(CompressionModel):
    """EnCodec model working directly on the raw waveform: encoder -> quantizer -> decoder.

    Args:
        encoder (nn.Module): Encoder network.
        decoder (nn.Module): Decoder network.
        quantizer (qt.BaseQuantizer): Quantizer network.
        frame_rate (int): Frame rate for the latent representation.
        sample_rate (int): Audio sample rate.
        channels (int): Number of audio channels.
        causal (bool): Whether to use a causal version of the model.
        renormalize (bool): Whether to renormalize the audio before running the model.
    """
    # Plain class-level assignments are needed to override the abstract properties
    # declared on the base class, so that instances can store regular attributes.
    frame_rate: float = 0
    sample_rate: int = 0
    channels: int = 0

    def __init__(self,
                 encoder: nn.Module,
                 decoder: nn.Module,
                 quantizer: qt.BaseQuantizer,
                 frame_rate: int,
                 sample_rate: int,
                 channels: int,
                 causal: bool = False,
                 renormalize: bool = False):
        super().__init__()
        self.encoder = encoder
        self.decoder = decoder
        self.quantizer = quantizer
        self.frame_rate = frame_rate
        self.sample_rate = sample_rate
        self.channels = channels
        self.causal = causal
        self.renormalize = renormalize
        # Renormalization is refused for causal models to avoid handling the linear
        # overlap of segments supported in the original EnCodec codebase.
        assert not (self.causal and self.renormalize), 'Causal model does not support renormalize'

    @property
    def total_codebooks(self):
        """Total number of codebooks the quantizer provides."""
        return self.quantizer.total_codebooks

    @property
    def num_codebooks(self):
        """Number of codebooks the quantizer currently uses."""
        return self.quantizer.num_codebooks

    @property
    def cardinality(self):
        """Number of entries in each codebook."""
        return self.quantizer.bins

    def set_num_codebooks(self, n: int):
        """Select how many codebooks the quantizer actively uses."""
        self.quantizer.set_num_codebooks(n)

    def preprocess(self, x: torch.Tensor) -> tp.Tuple[torch.Tensor, tp.Optional[torch.Tensor]]:
        """Optionally normalize `x` by its volume.

        Returns:
            tuple: `(x, scale)`. When `renormalize` is off, `x` is returned untouched and
                `scale` is None; otherwise `x` is divided by the scale and the scale is
                returned reshaped with `view(-1, 1)`.
        """
        if not self.renormalize:
            return x, None
        # RMS of the channel-averaged signal, computed over the last dimension.
        rms = x.mean(dim=1, keepdim=True).pow(2).mean(dim=2, keepdim=True).sqrt()
        gain = 1e-8 + rms  # small offset keeps the division well defined on silence
        return x / gain, gain.view(-1, 1)

    def postprocess(self,
                    x: torch.Tensor,
                    scale: tp.Optional[torch.Tensor] = None) -> torch.Tensor:
        """Undo the normalization applied by `preprocess` when a scale is provided."""
        if scale is None:
            return x
        assert self.renormalize
        return x * scale.view(-1, 1, 1)

    def forward(self, x: torch.Tensor) -> qt.QuantizedResult:
        """Encode, quantize and decode `x`, returning the quantizer result whose `x`
        field is replaced by the reconstructed (and de-normalized) signal.
        """
        assert x.dim() == 3
        n_samples = x.shape[-1]
        normalized, scale = self.preprocess(x)
        latent = self.encoder(normalized)
        result = self.quantizer(latent, self.frame_rate)
        reconstructed = self.decoder(result.x)

        # The encoder and decoder may add padding: trim back to the input length.
        assert reconstructed.shape[-1] >= n_samples, (reconstructed.shape[-1], n_samples)
        reconstructed = reconstructed[..., :n_samples]

        result.x = self.postprocess(reconstructed, scale)
        return result

    def encode(self, x: torch.Tensor) -> tp.Tuple[torch.Tensor, tp.Optional[torch.Tensor]]:
        """Encode the given input tensor to quantized representation along with scale parameter.

        Args:
            x (torch.Tensor): Float tensor of shape [B, C, T]

        Returns:
            codes, scale (tuple of torch.Tensor, torch.Tensor): Tuple composed of:
                codes: tensor of shape [B, K, T] with K the number of codebooks used and T the timestep.
                scale: float tensor containing the scale used for audio renormalization,
                    or None when renormalization is disabled.
        """
        assert x.dim() == 3
        normalized, scale = self.preprocess(x)
        latent = self.encoder(normalized)
        return self.quantizer.encode(latent), scale

    def decode(self, codes: torch.Tensor, scale: tp.Optional[torch.Tensor] = None):
        """Decode the given codes to a reconstructed representation, using the scale to perform
        audio denormalization if needed.

        Args:
            codes (torch.Tensor): Int tensor of shape [B, K, T]
            scale (torch.Tensor, optional): Float tensor containing the scale value.

        Returns:
            out (torch.Tensor): Float tensor of shape [B, C, T], the reconstructed audio.
        """
        decoded = self.decoder(self.decode_latent(codes))
        # Note: the output still contains the extra padding added by the encoder and decoder.
        return self.postprocess(decoded, scale)

    def decode_latent(self, codes: torch.Tensor):
        """Map discrete codes back to the continuous latent space."""
        return self.quantizer.decode(codes)
class DAC(CompressionModel):
    """Compression model backed by a pretrained model loaded with `dac.utils.load_model`
    (package `descript-audio-codec`). Inference only: `forward` is not supported.

    Args:
        model_type (str): Model variant forwarded to `dac.utils.load_model`.
    """

    def __init__(self, model_type: str = "44khz"):
        super().__init__()
        # `dac` is an optional dependency, hence the import at construction time.
        try:
            import dac.utils
        except ImportError:
            raise RuntimeError("Could not import dac, make sure it is installed, "
                               "please run `pip install descript-audio-codec`")
        self.model = dac.utils.load_model(model_type=model_type)
        self.model.eval()
        # Start with every available codebook active.
        self.n_quantizers = self.total_codebooks

    def forward(self, x: torch.Tensor) -> qt.QuantizedResult:
        """Not available: training with this wrapper is not supported."""
        raise NotImplementedError("Forward and training with DAC not supported.")

    def encode(self, x: torch.Tensor) -> tp.Tuple[torch.Tensor, tp.Optional[torch.Tensor]]:
        """Encode `x` with the active number of codebooks; no scale is ever returned."""
        encoded = self.model.encode(x, self.n_quantizers)
        # The codes are the second item returned by the wrapped model's `encode`.
        return encoded[1], None

    def decode(self, codes: torch.Tensor, scale: tp.Optional[torch.Tensor] = None):
        """Decode `codes` back to audio. `scale` must be None for this model."""
        assert scale is None
        return self.model.decode(self.decode_latent(codes))

    def decode_latent(self, codes: torch.Tensor):
        """Map discrete codes back to the continuous latent space."""
        from_codes_out = self.model.quantizer.from_codes(codes)
        return from_codes_out[0]

    @property
    def channels(self) -> int:
        return 1

    @property
    def frame_rate(self) -> float:
        return self.model.sample_rate / self.model.hop_length

    @property
    def sample_rate(self) -> int:
        return self.model.sample_rate

    @property
    def cardinality(self) -> int:
        return self.model.codebook_size

    @property
    def num_codebooks(self) -> int:
        return self.n_quantizers

    @property
    def total_codebooks(self) -> int:
        return self.model.n_codebooks

    def set_num_codebooks(self, n: int):
        """Select how many codebooks are actively used; `n` must lie in
        [1, total_codebooks].
        """
        assert n >= 1
        assert n <= self.total_codebooks
        self.n_quantizers = n
class HFEncodecCompressionModel(CompressionModel):
    """Wrapper around HuggingFace Encodec.

    The selectable numbers of codebooks are derived from the target bandwidths listed
    in the wrapped model's config; the largest one is active after construction.

    Args:
        model (HFEncodecModel): The HuggingFace Encodec model to wrap.
    """

    def __init__(self, model: HFEncodecModel):
        super().__init__()
        self.model = model
        bws = self.model.config.target_bandwidths
        # For each bandwidth: bandwidth * 1000 divided by (frame rate * bits per code).
        num_codebooks = [
            bw * 1000 / (self.frame_rate * math.log2(self.cardinality))
            for bw in bws
        ]
        # Checking we didn't do some bad maths and we indeed have integers!
        # Each value is compared against its nearest integer, and every delta is checked
        # individually (`all(deltas) <= 1e-3` would compare a single bool with 1e-3).
        deltas = [abs(nc - round(nc)) for nc in num_codebooks]
        assert all(delta <= 1e-3 for delta in deltas), deltas
        # Round rather than truncate so that e.g. 3.9999999 maps to 4 and not 3.
        self.possible_num_codebooks = [int(round(nc)) for nc in num_codebooks]
        self.set_num_codebooks(max(self.possible_num_codebooks))

    def forward(self, x: torch.Tensor) -> qt.QuantizedResult:
        """Not available: training with this wrapper is not supported."""
        # We don't support training with this.
        raise NotImplementedError("Forward and training with HF EncodecModel not supported.")

    def encode(self, x: torch.Tensor) -> tp.Tuple[torch.Tensor, tp.Optional[torch.Tensor]]:
        """Encode `x` at the bandwidth matching the active number of codebooks.

        Returns:
            tuple: First element of the codes and first element of the scales returned
                by the wrapped model's `encode`.
        """
        # Bandwidths and possible codebook counts share the same ordering.
        bandwidth_index = self.possible_num_codebooks.index(self.num_codebooks)
        bandwidth = self.model.config.target_bandwidths[bandwidth_index]
        res = self.model.encode(x, None, bandwidth)
        # Exactly one entry is expected for both the codes and the scales.
        assert len(res[0]) == 1
        assert len(res[1]) == 1
        return res[0][0], res[1][0]

    def decode(self, codes: torch.Tensor, scale: tp.Optional[torch.Tensor] = None):
        """Decode `codes` (and optional `scale`) with the wrapped model and return the
        first element of its result.
        """
        if scale is None:
            scales = [None]  # type: ignore
        else:
            # NOTE(review): `encode` returns `res[1][0]`, so `[scale]` may be the
            # symmetric form expected here -- confirm against the HF decode API.
            scales = scale  # type: ignore
        # `codes[None]` adds back the leading dimension that `encode` stripped.
        res = self.model.decode(codes[None], scales)
        return res[0]

    def decode_latent(self, codes: torch.Tensor):
        """Decode from the discrete codes to continuous latent space."""
        # The wrapped quantizer receives the codes with their first two dims swapped.
        return self.model.quantizer.decode(codes.transpose(0, 1))

    @property
    def channels(self) -> int:
        return self.model.config.audio_channels

    @property
    def frame_rate(self) -> float:
        # The hop length is the product of all upsampling ratios.
        hop_length = int(np.prod(self.model.config.upsampling_ratios))
        return self.sample_rate / hop_length

    @property
    def sample_rate(self) -> int:
        return self.model.config.sampling_rate

    @property
    def cardinality(self) -> int:
        return self.model.config.codebook_size

    @property
    def num_codebooks(self) -> int:
        return self._num_codebooks

    @property
    def total_codebooks(self) -> int:
        return max(self.possible_num_codebooks)

    def set_num_codebooks(self, n: int):
        """Set the active number of codebooks used by the quantizer.

        Raises:
            ValueError: If `n` is not one of `possible_num_codebooks`.
        """
        if n not in self.possible_num_codebooks:
            raise ValueError(f"Allowed values for num codebooks: {self.possible_num_codebooks}")
        self._num_codebooks = n
================================================
FILE: audiocraft/models/lm.py
================================================
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
from dataclasses import dataclass
from functools import partial
import logging
import math
import typing as tp
import torch
from torch import nn
from ..utils import utils
from ..modules.streaming import StreamingModule, State
from ..modules.transformer import StreamingTransformer, create_norm_fn
from ..modules.conditioners import (
ConditionFuser,
ClassifierFreeGuidanceDropout,
AttributeDropout,
ConditioningProvider,
ConditioningAttributes,
ConditionType,
)
from ..modules.codebooks_patterns import CodebooksPatternProvider
from ..modules.activations import get_activation_fn
# Module-level logger named after this module.
logger = logging.getLogger(__name__)
# Mapping from a condition name to its `ConditionType` value.
ConditionTensors = tp.Dict[str, ConditionType]
# Conditions used for classifier-free guidance: either a single mapping, or a pair of
# mappings (`LMModel.generate` builds the pair as (conditional, null) for two-step CFG).
CFGConditions = tp.Union[ConditionTensors, tp.Tuple[ConditionTensors, ConditionTensors]]
def get_init_fn(method: str, input_dim: int, init_depth: tp.Optional[int] = None):
    """Build the in-place initializer used for LM layers.

    Inspired from xlformers: https://github.com/fairinternal/xlformers

    Args:
        method (str): Name of the init scheme, either 'gaussian' or 'uniform'.
        input_dim (int): Input dimension of the module being initialized.
        init_depth (int, optional): When given, the standard deviation is further
            divided by sqrt(2 * init_depth).

    Returns:
        Callable: A `functools.partial` around the matching `torch.nn.init` function.

    Raises:
        ValueError: If `method` is not one of the supported names.
    """
    # Base standard deviation, optionally rescaled with depth.
    sigma = 1 / math.sqrt(input_dim)
    if init_depth is not None:
        sigma = sigma / math.sqrt(2 * init_depth)

    if method == 'gaussian':
        # Truncated at three standard deviations on each side.
        return partial(
            torch.nn.init.trunc_normal_, mean=0.0, std=sigma, a=-3 * sigma, b=3 * sigma
        )
    if method == 'uniform':
        # A uniform law on [-limit, limit] has standard deviation `sigma`.
        limit = math.sqrt(3) * sigma
        return partial(torch.nn.init.uniform_, a=-limit, b=limit)
    raise ValueError("Unsupported layer initialization method")
def init_layer(m: nn.Module,
               method: str,
               init_depth: tp.Optional[int] = None,
               zero_bias_init: bool = False):
    """Initialize a single LM module in place, using ``get_init_fn``.

    Only `nn.Linear` and `nn.Embedding` modules are touched; any other module is left as is.

    Args:
        m (nn.Module): Module to initialize.
        method (str): Method name for the init function.
        init_depth (int, optional): Optional init depth value used to rescale
            the standard deviation if defined. Only used for `nn.Linear`.
        zero_bias_init (bool): Whether to set the bias of `nn.Linear` modules to 0.
    """
    def _fill_weight(init_fn):
        # float16 weights living on CPU are initialized through a float32 copy
        # that is then written back as half precision.
        param = m.weight
        if param.device.type == 'cpu' and param.dtype == torch.float16:
            staging = param.float()
            init_fn(staging)
            param.data[:] = staging.half()
        else:
            init_fn(param)

    if isinstance(m, nn.Linear):
        _fill_weight(get_init_fn(method, m.in_features, init_depth=init_depth))
        if zero_bias_init and m.bias is not None:
            nn.init.constant_(m.bias, 0)
    elif isinstance(m, nn.Embedding):
        # Embeddings never use the depth rescaling.
        _fill_weight(get_init_fn(method, m.embedding_dim, init_depth=None))
class ScaledEmbedding(nn.Embedding):
    """Embedding that can carry its own learning rate (`lr`), exposed to the optimizer
    through `make_optim_group`.
    """

    def __init__(self, *args, lr=None, **kwargs):
        super().__init__(*args, **kwargs)
        self.lr = lr

    def make_optim_group(self):
        """Return an optimizer parameter group for this embedding.

        The group always holds the module parameters under "params"; the "lr" entry is
        only present when a learning rate was given at construction.
        """
        params = list(self.parameters())
        if self.lr is None:
            return {"params": params}
        return {"params": params, "lr": self.lr}
@dataclass
class LMOutput:
    """Output of the language model.

    The logits are already re-aligned with the input codes, so no extra shift is
    needed, e.g. when computing a cross-entropy.
    """
    logits: torch.Tensor  # [B, K, T, card]
    mask: torch.Tensor  # [B, K, T]
class LMModel(StreamingModule):
    """Transformer-based language model on multiple streams of codes.

    Args:
        pattern_provider (CodebooksPatternProvider): Pattern provider for codebook interleaving.
        condition_provider (MusicConditioningProvider): Conditioning provider from metadata.
        fuser (ConditionFuser): Fuser handling the fusing of conditions with language model input.
        n_q (int): Number of parallel streams to model.
        card (int): Cardinality, vocabulary size.
        dim (int): Dimension of the transformer encoder.
        num_heads (int): Number of heads for the transformer encoder.
        hidden_scale (int): Scale for hidden feed forward dimension of the transformer encoder.
        norm (str): Normalization method.
        norm_first (bool): Use pre-norm instead of post-norm.
        emb_lr (float, optional): Embedding-specific learning rate.
        bias_proj (bool): Use bias for output projections.
        weight_init (str, optional): Method for weight initialization.
        depthwise_init (str, optional): Method for depthwise weight initialization.
        zero_bias_init (bool): If true and bias in Linears, initialize bias to zeros.
        cfg_dropout (float): Classifier-free guidance dropout.
        cfg_coef (float): Classifier-free guidance coefficient.
        attribute_dropout (dict): Attribute dropout probabilities.
        two_step_cfg (bool): Whether to run classifier free-guidance with 2 distinct steps.
        **kwargs: Additional parameters for the transformer encoder.
    """
    def __init__(self, pattern_provider: CodebooksPatternProvider, condition_provider: ConditioningProvider,
                 fuser: ConditionFuser, n_q: int = 8, card: int = 1024, dim: int = 128, num_heads: int = 8,
                 hidden_scale: int = 4, norm: str = 'layer_norm', norm_first: bool = False,
                 emb_lr: tp.Optional[float] = None, bias_proj: bool = True,
                 weight_init: tp.Optional[str] = None, depthwise_init: tp.Optional[str] = None,
                 zero_bias_init: bool = False, cfg_dropout: float = 0, cfg_coef: float = 1.0,
                 attribute_dropout: tp.Dict[str, tp.Dict[str, float]] = {}, two_step_cfg: bool = False,
                 **kwargs):
        super().__init__()
        self.cfg_coef = cfg_coef
        self.cfg_dropout = ClassifierFreeGuidanceDropout(p=cfg_dropout)
        self.att_dropout = AttributeDropout(p=attribute_dropout)
        self.condition_provider = condition_provider
        self.fuser = fuser
        self.card = card
        # One extra embedding row for the special token, whose id is `card`.
        embed_dim = self.card + 1
        self.n_q = n_q
        self.dim = dim
        self.pattern_provider = pattern_provider
        self.two_step_cfg = two_step_cfg
        # One embedding table and one output projection per codebook stream.
        self.emb = nn.ModuleList([ScaledEmbedding(embed_dim, dim, lr=emb_lr) for _ in range(n_q)])
        if 'activation' in kwargs:
            # The activation is given by name and resolved to a callable here.
            kwargs['activation'] = get_activation_fn(kwargs['activation'])
        self.transformer = StreamingTransformer(
            d_model=dim, num_heads=num_heads, dim_feedforward=int(hidden_scale * dim),
            norm=norm, norm_first=norm_first, **kwargs)
        self.out_norm: tp.Optional[nn.Module] = None
        if norm_first:
            self.out_norm = create_norm_fn(norm, dim)
        self.linears = nn.ModuleList([nn.Linear(dim, self.card, bias=bias_proj) for _ in range(n_q)])
        self._init_weights(weight_init, depthwise_init, zero_bias_init)
        self._fsdp: tp.Optional[nn.Module]
        # Written straight into `__dict__`, bypassing the regular attribute assignment
        # (presumably so that a wrapper stored here is not registered as a submodule).
        self.__dict__['_fsdp'] = None

    def _init_weights(self, weight_init: tp.Optional[str], depthwise_init: tp.Optional[str], zero_bias_init: bool):
        """Initialization of the transformer module weights.

        Args:
            weight_init (str, optional): Weight initialization strategy. See ``get_init_fn`` for valid options.
            depthwise_init (str, optional): Depthwise initialization strategy. The following options are valid:
                'current' where the depth corresponds to the current layer index or 'global' where the total number
                of layer is used as depth. If not set, no depthwise initialization strategy is used.
            zero_bias_init (bool): Whether to initialize bias to zero or not.
        """
        assert depthwise_init is None or depthwise_init in ['current', 'global']
        assert depthwise_init is None or weight_init is not None, \
            "If 'depthwise_init' is defined, a 'weight_init' method should be provided."
        assert not zero_bias_init or weight_init is not None, \
            "If 'zero_bias_init', a 'weight_init' method should be provided"

        # Without a weight init method, modules keep their default initialization.
        if weight_init is None:
            return

        for emb_layer in self.emb:
            init_layer(emb_layer, method=weight_init, init_depth=None, zero_bias_init=zero_bias_init)

        for layer_idx, tr_layer in enumerate(self.transformer.layers):
            depth = None
            if depthwise_init == 'current':
                depth = layer_idx + 1
            elif depthwise_init == 'global':
                depth = len(self.transformer.layers)
            init_fn = partial(init_layer, method=weight_init, init_depth=depth, zero_bias_init=zero_bias_init)
            tr_layer.apply(init_fn)

        for linear in self.linears:
            init_layer(linear, method=weight_init, init_depth=None, zero_bias_init=zero_bias_init)

    @property
    def special_token_id(self) -> int:
        """Id of the special token: one past the last valid code."""
        return self.card

    @property
    def num_codebooks(self) -> int:
        """Number of parallel code streams modeled."""
        return self.n_q

    def forward(self, sequence: torch.Tensor,
                conditions: tp.List[ConditioningAttributes],
                condition_tensors: tp.Optional[ConditionTensors] = None) -> torch.Tensor:
        """Apply language model on sequence and conditions.
        Given a tensor of sequence of shape [B, K, S] with K the number of codebooks and
        S the sequence steps, return the logits with shape [B, K, S, card].

        Args:
            sequence (torch.Tensor): Indices of the codes to model, of shape [B, K, S].
            conditions (list of ConditioningAttributes): Conditions to use when modeling
                the given codes. Note that when evaluating multiple time with the same conditioning
                you should pre-compute those and pass them as `condition_tensors`.
            condition_tensors (dict[str, ConditionType], optional): Pre-computed conditioning
                tensors, see `conditions`.
        Returns:
            torch.Tensor: Logits of shape [B, K, S, card].
        """
        B, K, S = sequence.shape
        assert K == self.num_codebooks, "Sequence shape must match the specified number of codebooks"
        # The model input is the sum of the per-codebook embeddings.
        input_ = sum([self.emb[k](sequence[:, k]) for k in range(K)])
        if condition_tensors is None:
            assert not self._is_streaming, "Conditions tensors should be precomputed when streaming."
            # apply dropout modules
            conditions = self.cfg_dropout(conditions)
            conditions = self.att_dropout(conditions)
            tokenized = self.condition_provider.tokenize(conditions)
            # encode conditions and fuse, both have a streaming cache to not recompute when generating.
            condition_tensors = self.condition_provider(tokenized)
        else:
            assert not conditions, "Shouldn't pass both conditions and condition_tensors."

        input_, cross_attention_input = self.fuser(input_, condition_tensors)

        out = self.transformer(input_, cross_attention_src=cross_attention_input)
        if self.out_norm:
            out = self.out_norm(out)
        logits = torch.stack([self.linears[k](out) for k in range(K)], dim=1)  # [B, K, S, card]

        # remove the prefix from the model outputs
        if len(self.fuser.fuse2cond['prepend']) > 0:
            logits = logits[:, :, -S:]

        return logits  # [B, K, S, card]

    def compute_predictions(
            self, codes: torch.Tensor,
            conditions: tp.List[ConditioningAttributes],
            condition_tensors: tp.Optional[ConditionTensors] = None) -> LMOutput:
        """Given an input tensor of codes [B, K, T] and list of conditions, runs the model
        forward using the specified codes interleaving pattern.

        Args:
            codes (torch.Tensor): Input codes of shape [B, K, T] with B the batch size,
                K the number of codebooks and T the number of timesteps.
            conditions (list of ConditioningAttributes): conditionings to use when modeling
                the given codes. Note that when evaluating multiple time with the same conditioning
                you should pre-compute those and pass them as `condition_tensors`.
            condition_tensors (dict[str, ConditionType], optional): pre-computed conditioning
                tensors, see `conditions`.
        Returns:
            LMOutput: Language model outputs
                logits (torch.Tensor) of shape [B, K, T, card] corresponding to the provided codes,
                    i.e. the first item corresponds to logits to predict the first code, meaning that
                    no additional shifting of codes and logits is required.
                mask (torch.Tensor) of shape [B, K, T], mask over valid and invalid positions.
                    Given the specified interleaving strategies, parts of the logits and codes should
                    not be considered as valid predictions because of invalid context.
        """
        B, K, T = codes.shape
        codes = codes.contiguous()
        # map codes [B, K, T] into pattern sequence [B, K, S] using special_token_id for masked tokens
        pattern = self.pattern_provider.get_pattern(T)
        sequence_codes, sequence_indexes, sequence_mask = pattern.build_pattern_sequence(
            codes, self.special_token_id, keep_only_valid_steps=True
        )
        # apply model on pattern sequence
        model = self if self._fsdp is None else self._fsdp
        logits = model(sequence_codes, conditions, condition_tensors)  # [B, K, S, card]
        # map back the logits on pattern sequence to logits on original codes: [B, K, S, card] -> [B, K, T, card]
        # and provide the corresponding mask over invalid positions of tokens
        logits = logits.permute(0, 3, 1, 2)  # [B, card, K, S]
        # note: we use nans as special token to make it obvious if we feed unexpected logits
        logits, logits_indexes, logits_mask = pattern.revert_pattern_logits(
            logits, float('nan'), keep_only_valid_steps=True
        )
        logits = logits.permute(0, 2, 3, 1)  # [B, K, T, card]
        logits_mask = logits_mask[None, :, :].expand(B, -1, -1)  # [K, T] -> [B, K, T]
        return LMOutput(logits, logits_mask)

    def _sample_next_token(self,
                           sequence: torch.Tensor,
                           cfg_conditions: CFGConditions,
                           unconditional_state: State,
                           use_sampling: bool = False,
                           temp: float = 1.0,
                           top_k: int = 0,
                           top_p: float = 0.0,
                           cfg_coef: tp.Optional[float] = None,
                           two_step_cfg: tp.Optional[bool] = None) -> torch.Tensor:
        """Sample next token from the model given a sequence and a set of conditions. The model supports
        multiple sampling strategies (greedy sampling, softmax, top-k, top-p...).

        Args:
            sequence (torch.Tensor): Current sequence of shape [B, K, S]
                with K corresponding to the number of codebooks and S the number of sequence steps.
                S = 1 in streaming mode, except for the first step that contains a bigger prompt.
            cfg_conditions (CFGConditions): Set of conditions. With single-pass CFG, a dict that
                should be twice the batch size, being the concatenation of the conditions + null
                conditions. With two-step CFG, a tuple (conditions, null conditions).
            unconditional_state (State): Streaming state used for the unconditional pass
                of two-step CFG; updated in place.
            use_sampling (bool): Whether to use a sampling strategy or not.
            temp (float): Sampling temperature.
            top_k (int): K for "top-k" sampling.
            top_p (float): P for "top-p" sampling.
            cfg_coef (float, optional): classifier free guidance coefficient. Defaults to
                the coefficient given at construction.
            two_step_cfg (bool, optional): Whether `cfg_conditions` was built for two-step CFG.
                Defaults to the value given at construction.
        Returns:
            next_token (torch.Tensor): Next token tensor of shape [B, K, 1].
        """
        B = sequence.shape[0]
        cfg_coef = self.cfg_coef if cfg_coef is None else cfg_coef
        # Honor a per-call override so that this method agrees with how the caller
        # built `cfg_conditions` (tuple for two-step CFG, dict otherwise).
        two_step_cfg = self.two_step_cfg if two_step_cfg is None else two_step_cfg
        model = self if self._fsdp is None else self._fsdp
        if two_step_cfg and cfg_conditions != {}:
            assert isinstance(cfg_conditions, tuple), type(cfg_conditions)
            condition_tensors, null_condition_tensors = cfg_conditions
            cond_logits = model(sequence, conditions=[], condition_tensors=condition_tensors)
            # Swap in the unconditional streaming state for the second pass, then
            # restore the conditional one.
            state = self.get_streaming_state()
            self.set_streaming_state(unconditional_state)
            uncond_logits = model(sequence, conditions=[], condition_tensors=null_condition_tensors)
            unconditional_state.update(self.get_streaming_state())
            self.set_streaming_state(state)
            # Use the resolved coefficient so that a per-call `cfg_coef` is not ignored.
            logits = uncond_logits + (cond_logits - uncond_logits) * cfg_coef
        else:
            assert isinstance(cfg_conditions, dict)
            condition_tensors = cfg_conditions
            if condition_tensors:
                # Preparing for CFG, predicting both conditional and unconditional logits.
                sequence = torch.cat([sequence, sequence], dim=0)
            all_logits = model(
                sequence,
                conditions=[], condition_tensors=condition_tensors)
            if condition_tensors:
                cond_logits, uncond_logits = all_logits.split(B, dim=0)  # [B, K, T, card]
                logits = uncond_logits + (cond_logits - uncond_logits) * cfg_coef
            else:
                logits = all_logits

        logits = logits.permute(0, 1, 3, 2)  # [B, K, card, T]
        logits = logits[..., -1]  # [B x K x card]

        # Apply softmax for sampling if temp > 0. Else, do greedy sampling to avoid zero division error.
        if use_sampling and temp > 0.0:
            probs = torch.softmax(logits / temp, dim=-1)
            # top-p takes precedence over top-k when both are set.
            if top_p > 0.0:
                next_token = utils.sample_top_p(probs, p=top_p)
            elif top_k > 0:
                next_token = utils.sample_top_k(probs, k=top_k)
            else:
                next_token = utils.multinomial(probs, num_samples=1)
        else:
            next_token = torch.argmax(logits, dim=-1, keepdim=True)

        return next_token

    @torch.no_grad()
    def generate(self,
                 prompt: tp.Optional[torch.Tensor] = None,
                 conditions: tp.List[ConditioningAttributes] = [],
                 num_samples: tp.Optional[int] = None,
                 max_gen_len: int = 256,
                 use_sampling: bool = True,
                 temp: float = 1.0,
                 top_k: int = 250,
                 top_p: float = 0.0,
                 cfg_coef: tp.Optional[float] = None,
                 two_step_cfg: tp.Optional[bool] = None,
                 remove_prompts: bool = False,
                 check: bool = False,
                 callback: tp.Optional[tp.Callable[[int, int], None]] = None) -> torch.Tensor:
        """Generate tokens sampling from the model given a prompt or unconditionally. Generation can
        be perform in a greedy fashion or using sampling with top K and top P strategies.

        Args:
            prompt (torch.Tensor, optional): Prompt tokens of shape [B, K, T].
            conditions (list of ConditioningAttributes, optional): List of conditions.
            num_samples (int, optional): Number of samples to generate when no prompt and no conditions are given.
            max_gen_len (int): Maximum generation length.
            use_sampling (bool): Whether to use a sampling strategy or not.
            temp (float): Sampling temperature.
            top_k (int): K for "top-k" sampling.
            top_p (float): P for "top-p" sampling.
            cfg_coef (float, optional): Classifier-free guidance coefficient.
            two_step_cfg (bool, optional): Whether to perform classifier-free guidance with two steps generation.
            remove_prompts (bool): Whether to remove prompts from generation or not.
            check (bool): Whether to apply further checks on generated sequence.
            callback (Callback, optional): Callback function to report generation progress.
        Returns:
            torch.Tensor: Generated tokens.
        """
        assert not self.training, "generation shouldn't be used in training mode."
        first_param = next(iter(self.parameters()))
        device = first_param.device

        # Checking all input shapes are consistent.
        # Only the first available source (num_samples, prompt, conditions) is recorded.
        possible_num_samples = []
        if num_samples is not None:
            possible_num_samples.append(num_samples)
        elif prompt is not None:
            possible_num_samples.append(prompt.shape[0])
        elif conditions:
            possible_num_samples.append(len(conditions))
        else:
            possible_num_samples.append(1)
        # `all(...)` is required here: asserting on the (non-empty) list itself is always true.
        assert all(x == possible_num_samples[0] for x in possible_num_samples), "Inconsistent inputs shapes"
        num_samples = possible_num_samples[0]

        # below we create set of conditions: one conditional and one unconditional
        # to do that we merge the regular condition together with the null condition
        # we then do 1 forward pass instead of 2.
        # the reason for that is two-fold:
        # 1. it is about x2 faster than doing 2 forward passes
        # 2. avoid the streaming API treating the 2 passes as part of different time steps
        # We also support doing two different passes, in particular to ensure that
        # the padding structure is exactly the same between train and test.
        # With a batch size of 1, this can be slower though.
        cfg_conditions: CFGConditions
        two_step_cfg = self.two_step_cfg if two_step_cfg is None else two_step_cfg
        if conditions:
            null_conditions = ClassifierFreeGuidanceDropout(p=1.0)(conditions)
            if two_step_cfg:
                cfg_conditions = (
                    self.condition_provider(self.condition_provider.tokenize(conditions)),
                    self.condition_provider(self.condition_provider.tokenize(null_conditions)),
                )
            else:
                # A new list is built here, the `conditions` argument is not mutated.
                conditions = conditions + null_conditions
                tokenized = self.condition_provider.tokenize(conditions)
                cfg_conditions = self.condition_provider(tokenized)
        else:
            cfg_conditions = {}

        if prompt is None:
            assert num_samples > 0
            prompt = torch.zeros((num_samples, self.num_codebooks, 0), dtype=torch.long, device=device)

        B, K, T = prompt.shape
        start_offset = T
        assert start_offset < max_gen_len

        pattern = self.pattern_provider.get_pattern(max_gen_len)
        # this token is used as default value for codes that are not generated yet
        unknown_token = -1

        # we generate codes up to the max_gen_len that will be mapped to the pattern sequence
        gen_codes = torch.full((B, K, max_gen_len), unknown_token, dtype=torch.long, device=device)
        # filling the gen_codes with the prompt if needed
        gen_codes[..., :start_offset] = prompt
        # create the gen_sequence with proper interleaving from the pattern: [B, K, S]
        gen_sequence, indexes, mask = pattern.build_pattern_sequence(gen_codes, self.special_token_id)
        # retrieve the start_offset in the sequence:
        # it is the first sequence step that contains the `start_offset` timestep
        start_offset_sequence = pattern.get_first_step_with_timesteps(start_offset)
        assert start_offset_sequence is not None

        with self.streaming():
            unconditional_state = self.get_streaming_state()
            prev_offset = 0
            gen_sequence_len = gen_sequence.shape[-1]  # gen_sequence shape is [B, K, S]
            for offset in range(start_offset_sequence, gen_sequence_len):
                # get current sequence (note that the streaming API is providing the caching over previous offsets)
                curr_sequence = gen_sequence[..., prev_offset:offset]
                curr_mask = mask[None, ..., prev_offset:offset].expand(B, -1, -1)
                if check:
                    # check coherence between mask and sequence
                    assert (curr_sequence == torch.where(curr_mask, curr_sequence, self.special_token_id)).all()
                    # should never happen as gen_sequence is filled progressively
                    assert not (curr_sequence == unknown_token).any()
                # sample next token from the model, next token shape is [B, K, 1]
                # (the resolved `two_step_cfg` is forwarded so that the sampler matches
                # the way `cfg_conditions` was built above)
                next_token = self._sample_next_token(
                    curr_sequence, cfg_conditions, unconditional_state, use_sampling, temp, top_k, top_p,
                    cfg_coef=cfg_coef, two_step_cfg=two_step_cfg)
                # ensure the tokens that should be masked are properly set to special_token_id
                # as the model never output special_token_id
                valid_mask = mask[..., offset:offset+1].expand(B, -1, -1)
                next_token[~valid_mask] = self.special_token_id
                # ensure we don't overwrite prompt tokens, we only write over unknown tokens
                # (then mask tokens should be left as is as well, which is correct)
                gen_sequence[..., offset:offset+1] = torch.where(
                    gen_sequence[..., offset:offset+1] == unknown_token,
                    next_token, gen_sequence[..., offset:offset+1]
                )
                prev_offset = offset
                if callback is not None:
                    callback(1 + offset - start_offset_sequence, gen_sequence_len - start_offset_sequence)
        unconditional_state.clear()

        # ensure sequence has been entirely filled
        assert not (gen_sequence == unknown_token).any()
        # ensure gen_sequence pattern and mask are matching
        # which means the gen_sequence is valid according to the pattern
        assert (
            gen_sequence == torch.where(mask[None, ...].expand(B, -1, -1), gen_sequence, self.special_token_id)
        ).all()
        # get back the codes, trimming the prompt if needed and cutting potentially incomplete timesteps
        out_codes, out_indexes, out_mask = pattern.revert_pattern_sequence(gen_sequence, special_token=unknown_token)

        # sanity checks over the returned codes and corresponding masks
        assert (out_codes[..., :max_gen_len] != unknown_token).all()
        assert (out_mask[..., :max_gen_len] == 1).all()

        out_start_offset = start_offset if remove_prompts else 0
        out_codes = out_codes[..., out_start_offset:max_gen_len]

        # ensure the returned codes are all valid
        assert (out_codes >= 0).all() and (out_codes <= self.card).all()
        return out_codes
================================================
FILE: audiocraft/models/loaders.py
================================================
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
"""
Utility functions to load from the checkpoints.
Each checkpoint is a torch.saved dict with the following keys:
- 'xp.cfg': the hydra config as dumped during training. This should be used
to rebuild the object using the audiocraft.models.builders functions,
- 'best_state': a readily loadable best state for the model, including
the conditioner. The model obtained from `xp.cfg` should be compatible
with this state dict. In the case of a LM, the encodec model would not be
bundled along but instead provided separately.
Those functions also support loading from a remote location with the Torch Hub API.
They also support overriding some parameters, in particular the device and dtype
of the returned model.
"""
from pathlib import Path
from huggingface_hub import hf_hub_download
import typing as tp
import os
from omegaconf import OmegaConf, DictConfig
import torch
from . import builders
from .encodec import CompressionModel
def get_audiocraft_cache_dir() -> tp.Optional[str]:
    """Return the checkpoint cache directory configured in the environment.

    Returns:
        str or None: Value of the ``AUDIOCRAFT_CACHE_DIR`` environment variable,
            or None when the variable is not set.
    """
    try:
        return os.environ['AUDIOCRAFT_CACHE_DIR']
    except KeyError:
        return None
def _get_state_dict(
file_or_url_or_id: tp.Union[Path, str],
filename: tp.Optional[str] = None,
device='cpu',
cache_dir: tp.Optional[str] = None,
):
if cache_dir is None:
cache_dir = get_audiocraft_cache_dir()
# Return the state dict either from a file or url
file_or_url_or_id = str(file_or_url_or_id)
assert isinstance(file_or_url_or_id, str)
if os.path.isfile(file_or_url_or_id):
return torch.load(file_or_url_or_id, map_location=device)
if os.path.isdir(file_or_url_or_id):
file = f"{file_or_url_or_id}/{filename}"
return torch.load(file, map_location=device)
elif file_or_url_or_id.startswith('https://'):
return torch.hub.load_state_dict_from_url(file_or_url_or_id, map_location=device, check_hash=True)
else:
assert filename is not None, "filename needs to be defined if using HF checkpoints"
file = hf_hub_download(repo_id=file_or_url_or_id, filename=filename, cache_dir=cache_dir)
return torch.load(file, map_location=device)
def load_compression_model_ckpt(file_or_url_or_id: tp.Union[Path, str], cache_dir: tp.Optional[str] = None):
    """Load the raw compression model checkpoint package.

    Args:
        file_or_url_or_id (Path or str): Checkpoint location, see `_get_state_dict`.
        cache_dir (str, optional): Cache directory forwarded to `_get_state_dict`.
    Returns:
        The checkpoint package; "compression_state_dict.bin" is the file name used
        when the location is a directory or a repository id.
    """
    checkpoint_name = "compression_state_dict.bin"
    return _get_state_dict(file_or_url_or_id, filename=checkpoint_name, cache_dir=cache_dir)
def load_compression_model(file_or_url_or_id: tp.Union[Path, str], device='cpu', cache_dir: tp.Optional[str] = None):
    """Instantiate a compression model from a checkpoint and put it in eval mode.

    Args:
        file_or_url_or_id (Path or str): Checkpoint location, see `_get_state_dict`.
        device: Device written into the config (as a string) before building the model.
        cache_dir (str, optional): Cache directory forwarded to the checkpoint loader.
    Returns:
        The compression model. If the package has a 'pretrained' entry, the model is
        obtained through `CompressionModel.get_pretrained`; otherwise it is rebuilt from
        'xp.cfg' and loaded with 'best_state'.
    """
    checkpoint = load_compression_model_ckpt(file_or_url_or_id, cache_dir=cache_dir)

    # Some packages only point to another pretrained model instead of bundling weights.
    if 'pretrained' in checkpoint:
        return CompressionModel.get_pretrained(checkpoint['pretrained'], device=device)

    config = OmegaConf.create(checkpoint['xp.cfg'])
    config.device = str(device)
    compression_model = builders.get_compression_model(config)
    compression_model.load_state_dict(checkpoint['best_state'])
    compression_model.eval()
    return compression_model
def load_lm_model_ckpt(file_or_url_or_id: tp.Union[Path, str], cache_dir: tp.Optional[str] = None):
    """Load the raw language model checkpoint package.

    Args:
        file_or_url_or_id (Path or str): Checkpoint location, see `_get_state_dict`.
        cache_dir (str, optional): Cache directory forwarded to `_get_state_dict`.
    Returns:
        The checkpoint package; "state_dict.bin" is the file name used when the
        location is a directory or a repository id.
    """
    checkpoint_name = "state_dict.bin"
    return _get_state_dict(file_or_url_or_id, filename=checkpoint_name, cache_dir=cache_dir)
def _delete_param(cfg: DictConfig, full_name: str):
    """Remove a dotted entry from a config if it exists.

    Args:
        cfg (DictConfig): Config to modify in place.
        full_name (str): Dot-separated path of the entry, e.g. 'a.b.c'.

    Nothing happens when any intermediate key is missing. Otherwise the node holding
    the final key ends up with struct mode enabled, whether or not the key was present.
    """
    *parent_keys, leaf_key = full_name.split('.')

    node = cfg
    for key in parent_keys:
        if key not in node:
            return
        node = node[key]

    # Struct mode is lifted for the deletion and re-enabled afterwards.
    OmegaConf.set_struct(node, False)
    if leaf_key in node:
        del node[leaf_key]
    OmegaConf.set_struct(node, True)
def load_lm_model(file_or_url_or_id: tp.Union[Path, str], device='cpu', cache_dir: tp.Optional[str] = None):
    """Instantiate a language model from a checkpoint and put it in eval mode.

    Args:
        file_or_url_or_id (Path or str): Checkpoint location, see `_get_state_dict`.
        device: Device written into the config (as a string) before building the model.
        cache_dir (str, optional): Cache directory forwarded to the checkpoint loader.
    Returns:
        The language model built from 'xp.cfg', loaded with 'best_state', with the
        config attached as `model.cfg`.
    """
    checkpoint = load_lm_model_ckpt(file_or_url_or_id, cache_dir=cache_dir)

    cfg = OmegaConf.create(checkpoint['xp.cfg'])
    cfg.device = str(device)
    # float32 on CPU, float16 on any other device.
    cfg.dtype = 'float32' if cfg.device == 'cpu' else 'float16'

    # Entries dropped from the stored config before the model is rebuilt.
    for dropped_param in (
        'conditioners.self_wav.chroma_stem.cache_path',
        'conditioners.args.merge_text_conditions_p',
        'conditioners.args.drop_desc_p',
    ):
        _delete_param(cfg, dropped_param)

    lm = builders.get_lm_model(cfg)
    lm.load_state_dict(checkpoint['best_state'])
    lm.eval()
    lm.cfg = cfg
    return lm
def load_mbd_ckpt(file_or_url_or_id: tp.Union[Path, str], cache_dir: tp.Optional[str] = None):
    """Load the raw multi-band diffusion checkpoint package.

    Args:
        file_or_url_or_id (Path or str): Checkpoint location, see `_get_state_dict`.
        cache_dir (str, optional): Cache directory forwarded to `_get_state_dict`.
    Returns:
        The checkpoint package; "all_in_one.pt" is the file name used when the
        location is a directory or a repository id.
    """
    checkpoint_name = "all_in_one.pt"
    return _get_state_dict(file_or_url_or_id, filename=checkpoint_name, cache_dir=cache_dir)
def load_diffusion_models(file_or_url_or_id: tp.Union[Path, str], device='cpu', cache_dir: tp.Optional[str] = None):
    """Build every per-band diffusion model and sample processor stored in a checkpoint.

    Args:
        file_or_url_or_id (Path or str): Checkpoint location, see `_get_state_dict`.
        device: Device the models and processors are moved to.
        cache_dir (str, optional): Cache directory forwarded to the checkpoint loader.
    Returns:
        tuple: Three lists of equal length `(models, processors, cfgs)`, one entry per
            band, ordered by band index.
    """
    checkpoint = load_mbd_ckpt(file_or_url_or_id, cache_dir=cache_dir)
    sample_rate = checkpoint['sample_rate']

    models, processors, cfgs = [], [], []
    # The package stores one entry per band under integer keys 0..n_bands-1.
    for band_index in range(checkpoint['n_bands']):
        band = checkpoint[band_index]
        band_cfg = band['cfg']

        model = builders.get_diffusion_model(band_cfg)
        model.load_state_dict(band['model_state'])
        model.to(device)

        processor = builders.get_processor(cfg=band_cfg.processor, sample_rate=sample_rate)
        processor.load_state_dict(band['processor_state'])
        processor.to(device)

        models.append(model)
        processors.append(processor)
        cfgs.append(band_cfg)

    return models, processors, cfgs
================================================
FILE: audiocraft/models/multibanddiffusion.py
================================================
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
"""
Multi Band Diffusion models as described in
"From Discrete Tokens to High-Fidelity Audio Using Multi-Band Diffusion"
(https://arxiv.org/abs/2308.02560).
"""
import typing as tp
import torch
import julius
from .unet import DiffusionUnet
from ..modules.diffusion_schedule import NoiseSchedule
from .encodec import CompressionModel
from ..solvers.compression import CompressionSolver
from .loaders import load_compression_model, load_diffusion_models
class DiffusionProcess:
    """Sampling for a diffusion Model.

    Args:
        model (DiffusionUnet): Diffusion U-Net model.
        noise_schedule (NoiseSchedule): Noise schedule for diffusion process.
    """
    def __init__(self, model: DiffusionUnet, noise_schedule: NoiseSchedule) -> None:
        """Store the model and the schedule that drives the sampling."""
        self.model = model
        self.schedule = noise_schedule

    def generate(self, condition: torch.Tensor, initial_noise: torch.Tensor,
                 step_list: tp.Optional[tp.List[int]] = None):
        """Perform one diffusion process to generate one of the bands.

        Args:
            condition (torch.Tensor): The embeddings from the compression model.
            initial_noise (torch.Tensor): The initial noise to start the process.
            step_list (list of int, optional): Steps forwarded to the schedule's
                `generate_subsampled`.
        Returns:
            Whatever the schedule's `generate_subsampled` returns for this model.
        """
        sampler = self.schedule.generate_subsampled
        return sampler(
            model=self.model,
            initial=initial_noise,
            step_list=step_list,
            condition=condition,
        )
class MultiBandDiffusion:
    """Sample from multiple diffusion models.

    Args:
        DPs (list of DiffusionProcess): Diffusion processes.
        codec_model (CompressionModel): Underlying compression model used to obtain discrete tokens.
    """
    def __init__(self, DPs: tp.List[DiffusionProcess], codec_model: CompressionModel) -> None:
        self.DPs = DPs
        self.codec_model = codec_model
        # The working device is the device of the codec model's first parameter.
        self.device = next(self.codec_model.parameters()).device

    @property
    def sample_rate(self) -> int:
        """Sample rate of the underlying compression model."""
        return self.codec_model.sample_rate

    @staticmethod
    def _build_diffusion_processes(models, processors, cfgs, device) -> tp.List[DiffusionProcess]:
        """Pair each band's model with a noise schedule built from its own config and processor."""
        return [
            DiffusionProcess(
                model=model,
                noise_schedule=NoiseSchedule(**cfg.schedule, sample_processor=processor, device=device))
            for model, processor, cfg in zip(models, processors, cfgs)
        ]

    @staticmethod
    def get_mbd_musicgen(device=None):
        """Load our diffusion models trained for MusicGen.

        Args:
            device (torch.device or str, optional): Device on which the models are loaded.
                Defaults to 'cuda' when available, else 'cpu'.
        """
        if device is None:
            device = 'cuda' if torch.cuda.is_available() else 'cpu'
        path = 'https://dl.fbaipublicfiles.com/encodec/Diffusion/mbd_musicgen_32khz.th'
        name = 'facebook/musicgen-small'
        codec_model = load_compression_model(name, device=device)
        models, processors, cfgs = load_diffusion_models(path, device=device)
        DPs = MultiBandDiffusion._build_diffusion_processes(models, processors, cfgs, device)
        return MultiBandDiffusion(DPs=DPs, codec_model=codec_model)

    @staticmethod
    def get_mbd_24khz(bw: float = 3.0, pretrained: bool = True,
                      device: tp.Optional[tp.Union[torch.device, str]] = None,
                      n_q: tp.Optional[int] = None):
        """Get the pretrained Models for MultibandDiffusion.

        Args:
            bw (float): Bandwidth of the compression model, one of 1.5, 3.0 or 6.0.
            pretrained (bool): Whether to use / download if necessary the models.
                NOTE(review): this flag is not read by the implementation.
            device (torch.device or str, optional): Device on which the models are loaded.
            n_q (int, optional): Number of quantizers to use within the compression model.
                When given, it must match the value implied by `bw`.
        """
        if device is None:
            device = 'cuda' if torch.cuda.is_available() else 'cpu'
        # Each supported bandwidth maps to a fixed number of codebooks.
        codebooks_per_bw = {1.5: 2, 3.0: 4, 6.0: 8}
        assert bw in codebooks_per_bw, f"bandwidth {bw} not available"
        if n_q is not None:
            assert n_q in [2, 4, 8]
            assert codebooks_per_bw[bw] == n_q, \
                f"bandwidth and number of codebooks missmatch to use n_q = {n_q} bw should be {n_q * (1.5 / 2)}"
        n_q = codebooks_per_bw[bw]
        codec_model = CompressionSolver.model_from_checkpoint(
            '//pretrained/facebook/encodec_24khz', device=device)
        codec_model.set_num_codebooks(n_q)
        codec_model = codec_model.to(device)
        path = f'https://dl.fbaipublicfiles.com/encodec/Diffusion/mbd_comp_{n_q}.pt'
        models, processors, cfgs = load_diffusion_models(path, device=device)
        DPs = MultiBandDiffusion._build_diffusion_processes(models, processors, cfgs, device)
        return MultiBandDiffusion(DPs=DPs, codec_model=codec_model)

    @torch.no_grad()
    def get_condition(self, wav: torch.Tensor, sample_rate: int) -> torch.Tensor:
        """Get the conditioning (i.e. latent representations of the compression model) from a waveform.

        Args:
            wav (torch.Tensor): The audio that we want to extract the conditioning from.
            sample_rate (int): Sample rate of the audio; the audio is resampled when it
                differs from the codec sample rate.
        """
        if sample_rate != self.sample_rate:
            wav = julius.resample_frac(wav, sample_rate, self.sample_rate)
        codes, scale = self.codec_model.encode(wav)
        assert scale is None, "Scaled compression models not supported."
        emb = self.get_emb(codes)
        return emb

    @torch.no_grad()
    def get_emb(self, codes: torch.Tensor):
        """Get latent representation from the discrete codes.

        Args:
            codes (torch.Tensor): Discrete tokens.
        """
        emb = self.codec_model.decode_latent(codes)
        return emb

    def generate(self, emb: torch.Tensor, size: tp.Optional[torch.Size] = None,
                 step_list: tp.Optional[tp.List[int]] = None):
        """Generate waveform audio from the latent embeddings of the compression model.

        Args:
            emb (torch.Tensor): Conditioning embeddings.
            size (torch.Size, optional): Size of the output. If None, it is computed as
                [batch, codec channels, frames * (sample_rate / frame_rate)].
            step_list (list of int, optional): List of Markov chain steps, forwarded to
                each diffusion process.
        Returns:
            torch.Tensor: Sum of the outputs of all diffusion processes.
        """
        if size is None:
            upsampling = int(self.codec_model.sample_rate / self.codec_model.frame_rate)
            size = torch.Size([emb.size(0), self.codec_model.channels, emb.size(-1) * upsampling])
        assert size[0] == emb.size(0)
        out = torch.zeros(size, device=self.device)
        # Every band starts from its own fresh noise; the band outputs are summed.
        for DP in self.DPs:
            out += DP.generate(condition=emb, step_list=step_list, initial_noise=torch.randn_like(out))
        return out

    def re_eq(self, wav: torch.Tensor, ref: torch.Tensor, n_bands: int = 32, strictness: float = 1):
        """Match the EQ to the encodec output by matching the standard deviation of some frequency bands.

        Args:
            wav (torch.Tensor): Audio to equalize.
            ref (torch.Tensor): Reference audio from which we match the spectrogram.
            n_bands (int): Number of bands of the EQ.
            strictness (float): How strict the matching is. 0 is no matching, 1 is exact matching.
        """
        split = julius.SplitBands(n_bands=n_bands, sample_rate=self.codec_model.sample_rate).to(wav.device)
        bands = split(wav)
        bands_ref = split(ref)
        out = torch.zeros_like(ref)
        for i in range(n_bands):
            # NOTE(review): a band of `wav` with zero standard deviation makes this ratio
            # non-finite; left unchanged to keep the numerical output identical.
            out += bands[i] * (bands_ref[i].std() / bands[i].std()) ** strictness
        return out

    def regenerate(self, wav: torch.Tensor, sample_rate: int):
        """Regenerate a waveform through compression and diffusion regeneration.

        Args:
            wav (torch.Tensor): Original 'ground truth' audio.
            sample_rate (int): Sample rate of the input (and output) wav.
        """
        codec_rate = self.codec_model.sample_rate
        needs_resampling = sample_rate != codec_rate
        if needs_resampling:
            wav = julius.resample_frac(wav, sample_rate, codec_rate)
        emb = self.get_condition(wav, sample_rate=codec_rate)
        size = wav.size()
        out = self.generate(emb, size=size)
        if needs_resampling:
            # Bring the result back to the caller's sample rate.
            out = julius.resample_frac(out, codec_rate, sample_rate)
        return out

    def tokens_to_wav(self, tokens: torch.Tensor, n_bands: int = 32):
        """Generate waveform audio with diffusion from the discrete codes.

        Args:
            tokens (torch.Tensor): Discrete codes.
            n_bands (int): Bands for the EQ matching.
        """
        wav_encodec = self.codec_model.decode(tokens)
        condition = self.get_emb(tokens)
        wav_diffusion = self.generate(emb=condition, size=wav_encodec.size())
        # The codec decoding is only used as the EQ reference for the diffusion output.
        return self.re_eq(wav=wav_diffusion, ref=wav_encodec, n_bands=n_bands)
================================================
FILE: audiocraft/models/musicgen.py
================================================
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
"""
Main model for using MusicGen. This will combine all the required components
and provide easy access to the generation API.
"""
import typing as tp
import warnings
import torch
from .encodec import CompressionModel
from .lm import LMModel
from .builders import get_debug_compression_model, get_debug_lm_model
from .loaders import load_compression_model, load_lm_model
from ..data.audio_utils import convert_audio
from ..modules.conditioners import ConditioningAttributes, WavCondition
from ..utils.autocast import TorchAutocast
# A melody batch given as a list: one optional waveform tensor per sample.
MelodyList = tp.List[tp.Optional[torch.Tensor]]
# Melody conditioning is either a single batched tensor or a `MelodyList`.
MelodyType = tp.Union[torch.Tensor, MelodyList]


# backward compatible names mapping: short model name -> checkpoint repository id
_HF_MODEL_CHECKPOINTS_MAP = dict(
    small="GrandaddyShmax/musicgen-small",
    medium="GrandaddyShmax/musicgen-medium",
    large="GrandaddyShmax/musicgen-large",
    melody="GrandaddyShmax/musicgen-melody",
)
class MusicGen:
    """MusicGen main model with convenient generation API.

    Args:
        name (str): name of the model.
        compression_model (CompressionModel): Compression model
            used to map audio to invertible discrete representations.
        lm (LMModel): Language model over discrete representations.
        max_duration (float, optional): maximum duration the model can produce,
            otherwise, inferred from the training params.
    """
    def __init__(self, name: str, compression_model: CompressionModel, lm: LMModel,
                 max_duration: tp.Optional[float] = None):
        self.name = name
        self.compression_model = compression_model
        self.lm = lm
        if max_duration is None:
            if hasattr(lm, 'cfg'):
                max_duration = lm.cfg.dataset.segment_duration  # type: ignore
            else:
                raise ValueError("You must provide max_duration when building directly MusicGen")
        assert max_duration is not None
        self.max_duration: float = max_duration
        self.device = next(iter(lm.parameters())).device
        self.generation_params: dict = {}
        self.set_generation_params(duration=15)  # 15 seconds by default
        self._progress_callback: tp.Optional[tp.Callable[[int, int], None]] = None
        self.autocast = self._build_autocast()

    def _build_autocast(self):
        """Return the autocast context for `self.device`: disabled on CPU, float16 elsewhere."""
        if self.device.type == 'cpu':
            return TorchAutocast(enabled=False)
        return TorchAutocast(
            enabled=True, device_type=self.device.type, dtype=torch.float16)

    @property
    def frame_rate(self) -> float:
        """Roughly the number of AR steps per seconds."""
        return self.compression_model.frame_rate

    @property
    def sample_rate(self) -> int:
        """Sample rate of the generated audio."""
        return self.compression_model.sample_rate

    @property
    def audio_channels(self) -> int:
        """Audio channels of the generated audio."""
        return self.compression_model.channels

    @staticmethod
    def get_pretrained(name: str = 'GrandaddyShmax/musicgen-melody', device=None):
        """Return a pretrained model.

        Args:
            name (str): Checkpoint location understood by the loaders (local file or
                directory, URL or Hugging Face repository id), one of the short legacy
                names 'small', 'medium', 'large', 'melody' (resolved through
                `_HF_MODEL_CHECKPOINTS_MAP`, with a warning), or 'debug' for the tiny
                models used in unit tests.
            device: Device on which the models are loaded. Defaults to 'cuda' when
                a CUDA device is available, else 'cpu'.
        """
        if device is None:
            device = 'cuda' if torch.cuda.device_count() else 'cpu'

        if name == 'debug':
            # used only for unit tests
            compression_model = get_debug_compression_model(device)
            lm = get_debug_lm_model(device)
            return MusicGen(name, compression_model, lm, max_duration=30)

        if name in _HF_MODEL_CHECKPOINTS_MAP:
            full_name = _HF_MODEL_CHECKPOINTS_MAP[name]
            warnings.warn(
                "MusicGen pretrained model relying on deprecated checkpoint mapping. "
                f"Please use full pre-trained id instead: {full_name}")
            name = full_name

        lm = load_lm_model(name, device=device)
        compression_model = load_compression_model(name, device=device)
        if 'self_wav' in lm.condition_provider.conditioners:
            lm.condition_provider.conditioners['self_wav'].match_len_on_eval = True

        return MusicGen(name, compression_model, lm)

    def set_generation_params(self, use_sampling: bool = True, top_k: int = 250,
                              top_p: float = 0.0, temperature: float = 1.0,
                              duration: float = 30.0, cfg_coef: float = 3.0,
                              two_step_cfg: bool = False, extend_stride: float = 18):
        """Set the generation parameters for MusicGen.

        Args:
            use_sampling (bool, optional): Use sampling if True, else do argmax decoding. Defaults to True.
            top_k (int, optional): top_k used for sampling. Defaults to 250.
            top_p (float, optional): top_p used for sampling, when set to 0 top_k is used. Defaults to 0.0.
            temperature (float, optional): Softmax temperature parameter. Defaults to 1.0.
            duration (float, optional): Duration of the generated waveform. Defaults to 30.0.
            cfg_coef (float, optional): Coefficient used for classifier free guidance. Defaults to 3.0.
            two_step_cfg (bool, optional): If True, performs 2 forward for Classifier Free Guidance,
                instead of batching together the two. This has some impact on how things
                are padded but seems to have little impact in practice.
            extend_stride: when doing extended generation (i.e. more than `max_duration` seconds),
                by how much should we extend the audio each time. Larger values will mean less
                context is preserved, and shorter value will require extra computations.
                Must be strictly smaller than `max_duration`.
        """
        assert extend_stride < self.max_duration, "Cannot stride by more than max generation duration."
        self.extend_stride = extend_stride
        self.duration = duration
        self.generation_params = {
            'use_sampling': use_sampling,
            'temp': temperature,
            'top_k': top_k,
            'top_p': top_p,
            'cfg_coef': cfg_coef,
            'two_step_cfg': two_step_cfg,
        }

    def set_custom_progress_callback(self, progress_callback: tp.Optional[tp.Callable[[int, int], None]] = None):
        """Override the default progress callback."""
        self._progress_callback = progress_callback

    def _decode_result(self, tokens: torch.Tensor, return_tokens: bool) \
            -> tp.Union[torch.Tensor, tp.Tuple[torch.Tensor, torch.Tensor]]:
        """Decode tokens to audio and optionally return the tokens alongside."""
        audio = self.generate_audio(tokens)
        if return_tokens:
            return audio, tokens
        return audio

    def generate_unconditional(self, num_samples: int, progress: bool = False,
                               return_tokens: bool = False) \
            -> tp.Union[torch.Tensor, tp.Tuple[torch.Tensor, torch.Tensor]]:
        """Generate samples in an unconditional manner.

        Args:
            num_samples (int): Number of samples to be generated.
            progress (bool, optional): Flag to display progress of the generation process. Defaults to False.
            return_tokens (bool, optional): If True, return `(audio, tokens)` instead of the audio only.
        """
        descriptions: tp.List[tp.Optional[str]] = [None] * num_samples
        attributes, prompt_tokens = self._prepare_tokens_and_attributes(descriptions, None)
        tokens = self._generate_tokens(attributes, prompt_tokens, progress)
        return self._decode_result(tokens, return_tokens)

    def generate(self, descriptions: tp.List[str], progress: bool = False, return_tokens: bool = False) \
            -> tp.Union[torch.Tensor, tp.Tuple[torch.Tensor, torch.Tensor]]:
        """Generate samples conditioned on text.

        Args:
            descriptions (list of str): A list of strings used as text conditioning.
                An entry may also be a list of strings, in which case successive entries
                are used for successive chunks of an extended generation.
            progress (bool, optional): Flag to display progress of the generation process. Defaults to False.
            return_tokens (bool, optional): If True, return `(audio, tokens)` instead of the audio only.
        """
        attributes, prompt_tokens = self._prepare_tokens_and_attributes(descriptions, None)
        assert prompt_tokens is None
        tokens = self._generate_tokens(attributes, prompt_tokens, progress)
        return self._decode_result(tokens, return_tokens)

    def generate_with_chroma(self, descriptions: tp.List[str], melody_wavs: MelodyType,
                             melody_sample_rate: int, progress: bool = False,
                             return_tokens: bool = False) \
            -> tp.Union[torch.Tensor, tp.Tuple[torch.Tensor, torch.Tensor]]:
        """Generate samples conditioned on text and melody.

        Args:
            descriptions (list of str): A list of strings used as text conditioning.
            melody_wavs: (torch.Tensor or list of Tensor): A batch of waveforms used as
                melody conditioning. Should have shape [B, C, T] with B matching the description length,
                C=1 or 2. It can be [C, T] if there is a single description. It can also be
                a list of [C, T] tensors.
            melody_sample_rate: (int): Sample rate of the melody waveforms.
            progress (bool, optional): Flag to display progress of the generation process. Defaults to False.
            return_tokens (bool, optional): If True, return `(audio, tokens)` instead of the audio only.
        """
        if isinstance(melody_wavs, torch.Tensor):
            if melody_wavs.dim() == 2:
                melody_wavs = melody_wavs[None]
            if melody_wavs.dim() != 3:
                raise ValueError("Melody wavs should have a shape [B, C, T].")
            melody_wavs = list(melody_wavs)
        else:
            for melody in melody_wavs:
                if melody is not None:
                    assert melody.dim() == 2, "One melody in the list has the wrong number of dims."

        # Bring every melody to the model's sample rate and channel layout.
        melody_wavs = [
            convert_audio(wav, melody_sample_rate, self.sample_rate, self.audio_channels)
            if wav is not None else None
            for wav in melody_wavs]
        attributes, prompt_tokens = self._prepare_tokens_and_attributes(descriptions=descriptions, prompt=None,
                                                                        melody_wavs=melody_wavs)
        assert prompt_tokens is None
        tokens = self._generate_tokens(attributes, prompt_tokens, progress)
        return self._decode_result(tokens, return_tokens)

    def generate_continuation(self, prompt: torch.Tensor, prompt_sample_rate: int,
                              descriptions: tp.Optional[tp.List[tp.Optional[str]]] = None,
                              progress: bool = False, return_tokens: bool = False) \
            -> tp.Union[torch.Tensor, tp.Tuple[torch.Tensor, torch.Tensor]]:
        """Generate samples conditioned on audio prompts.

        Args:
            prompt (torch.Tensor): A batch of waveforms used for continuation.
                Prompt should be [B, C, T], or [C, T] if only one sample is generated.
            prompt_sample_rate (int): Sampling rate of the given audio waveforms.
            descriptions (list of str, optional): A list of strings used as text conditioning. Defaults to None.
            progress (bool, optional): Flag to display progress of the generation process. Defaults to False.
            return_tokens (bool, optional): If True, return `(audio, tokens)` instead of the audio only.
        """
        if prompt.dim() == 2:
            prompt = prompt[None]
        if prompt.dim() != 3:
            raise ValueError("prompt should have 3 dimensions: [B, C, T] (C = 1).")
        prompt = convert_audio(prompt, prompt_sample_rate, self.sample_rate, self.audio_channels)
        if descriptions is None:
            descriptions = [None] * len(prompt)
        attributes, prompt_tokens = self._prepare_tokens_and_attributes(descriptions, prompt)
        assert prompt_tokens is not None
        tokens = self._generate_tokens(attributes, prompt_tokens, progress)
        return self._decode_result(tokens, return_tokens)

    def _null_wav_condition(self) -> WavCondition:
        """Return the placeholder wav condition (length 0) used when no melody is given."""
        return WavCondition(
            torch.zeros((1, 1, 1), device=self.device),
            torch.tensor([0], device=self.device),
            sample_rate=[self.sample_rate],
            path=[None])

    @torch.no_grad()
    def _prepare_tokens_and_attributes(
            self,
            descriptions: tp.Sequence[tp.Optional[str]],
            prompt: tp.Optional[torch.Tensor],
            melody_wavs: tp.Optional[MelodyList] = None,
    ) -> tp.Tuple[tp.List[ConditioningAttributes], tp.Optional[torch.Tensor]]:
        """Prepare model inputs.

        Args:
            descriptions (list of str): A list of strings used as text conditioning.
            prompt (torch.Tensor): A batch of waveforms used for continuation.
            melody_wavs (torch.Tensor, optional): A batch of waveforms
                used as melody conditioning. Defaults to None.
        Returns:
            tuple: The conditioning attributes (one per description) and the prompt
                tokens obtained from the compression model, or None without prompt.
        """
        attributes = [
            ConditioningAttributes(text={'description': description})
            for description in descriptions]

        if melody_wavs is None:
            for attr in attributes:
                attr.wav['self_wav'] = self._null_wav_condition()
        else:
            if 'self_wav' not in self.lm.condition_provider.conditioners:
                raise RuntimeError("This model doesn't support melody conditioning. "
                                   "Use the `melody` model.")
            assert len(melody_wavs) == len(descriptions), \
                f"number of melody wavs must match number of descriptions! " \
                f"got melody len={len(melody_wavs)}, and descriptions len={len(descriptions)}"
            for attr, melody in zip(attributes, melody_wavs):
                if melody is None:
                    attr.wav['self_wav'] = self._null_wav_condition()
                else:
                    attr.wav['self_wav'] = WavCondition(
                        melody[None].to(device=self.device),
                        torch.tensor([melody.shape[-1]], device=self.device),
                        sample_rate=[self.sample_rate],
                        path=[None],
                    )

        if prompt is None:
            return attributes, None

        assert len(descriptions) == len(prompt), "Prompt and nb. descriptions doesn't match"
        prompt = prompt.to(self.device)
        prompt_tokens, scale = self.compression_model.encode(prompt)
        assert scale is None
        return attributes, prompt_tokens

    @staticmethod
    def _as_prompt_list(description) -> tp.List[tp.Optional[str]]:
        """Normalize a description (None, a string, or a sequence of strings) to a list of prompts."""
        if description is None:
            return [None]
        if isinstance(description, str):
            # A plain string must not be indexed character by character.
            return [description]
        return list(description)

    @staticmethod
    def _pick_prompt(prompts: tp.List[tp.Optional[str]], chunk_index: int) -> tp.Optional[str]:
        """Return the prompt for a chunk, repeating the last prompt once the list is exhausted."""
        if not prompts:
            return None
        return prompts[min(chunk_index, len(prompts) - 1)]

    def _generate_tokens(self, attributes: tp.List[ConditioningAttributes],
                         prompt_tokens: tp.Optional[torch.Tensor], progress: bool = False) -> torch.Tensor:
        """Generate discrete audio tokens given audio prompt and/or conditions.

        The text description of an attribute may be None, a string, or a list of
        strings. With a list, chunk number `k` of an extended generation is conditioned
        on entry `k` (the last entry is reused once the list is exhausted); a
        single-pass generation uses the first entry. The attributes'
        `text['description']` is overwritten with the selected prompt.

        Args:
            attributes (list of ConditioningAttributes): Conditions used for generation (text/melody).
            prompt_tokens (torch.Tensor, optional): Audio prompt used for continuation.
            progress (bool, optional): Flag to display progress of the generation process. Defaults to False.
        Returns:
            torch.Tensor: Generated tokens, with the last dimension being time as
                defined by the generation params.
        """
        # Snapshot the prompts first: the descriptions are overwritten for every chunk.
        prompt_lists = [self._as_prompt_list(attr.text['description']) for attr in attributes]

        def _set_chunk_descriptions(chunk_index: int) -> None:
            for attr, prompts in zip(attributes, prompt_lists):
                attr.text['description'] = self._pick_prompt(prompts, chunk_index)

        total_gen_len = int(self.duration * self.frame_rate)
        max_prompt_len = int(min(self.duration, self.max_duration) * self.frame_rate)
        current_gen_offset: int = 0

        def _progress_callback(generated_tokens: int, tokens_to_generate: int):
            generated_tokens += current_gen_offset
            if current_gen_offset > 0:
                # Cast to int: the durations may be floats, and both the `d` format
                # below and the callback signature expect integers.
                generated_tokens += int((self.max_duration - self.extend_stride) * self.frame_rate)
            if self._progress_callback is not None:
                # Note that total_gen_len might be quite wrong depending on the
                # codebook pattern used, but with delay it is almost accurate.
                self._progress_callback(generated_tokens, total_gen_len)
            else:
                print(f'{generated_tokens: 6d} / {total_gen_len: 6d}', end='\r')

        if prompt_tokens is not None:
            assert max_prompt_len >= prompt_tokens.shape[-1], \
                "Prompt is longer than audio to generate"

        callback = None
        if progress:
            callback = _progress_callback

        if self.duration <= self.max_duration:
            # generate by sampling from LM, simple case.
            _set_chunk_descriptions(0)
            with self.autocast:
                gen_tokens = self.lm.generate(
                    prompt_tokens, attributes,
                    callback=callback, max_gen_len=total_gen_len, **self.generation_params)
            return gen_tokens

        # now this gets a bit messier, we need to handle prompts,
        # melody conditioning etc.
        ref_wavs = [attr.wav['self_wav'] for attr in attributes]
        all_tokens = []
        if prompt_tokens is None:
            prompt_length = 0
        else:
            all_tokens.append(prompt_tokens)
            prompt_length = prompt_tokens.shape[-1]

        stride_tokens = int(self.frame_rate * self.extend_stride)
        chunk_index = 0

        while current_gen_offset + prompt_length < total_gen_len:
            time_offset = current_gen_offset / self.frame_rate
            chunk_duration = min(self.duration - time_offset, self.max_duration)
            max_gen_len = int(chunk_duration * self.frame_rate)
            for attr, ref_wav in zip(attributes, ref_wavs):
                wav_length = ref_wav.length.item()
                if wav_length == 0:
                    continue
                # We will extend the wav periodically if it not long enough.
                # we have to do it here rather than in conditioners.py as otherwise
                # we wouldn't have the full wav.
                initial_position = int(time_offset * self.sample_rate)
                wav_target_length = int(self.max_duration * self.sample_rate)
                positions = torch.arange(initial_position,
                                         initial_position + wav_target_length, device=self.device)
                attr.wav['self_wav'] = WavCondition(
                    ref_wav[0][..., positions % wav_length],
                    torch.full_like(ref_wav[1], wav_target_length),
                    [self.sample_rate] * ref_wav[0].size(0),
                    [None], [0.])
            _set_chunk_descriptions(chunk_index)
            with self.autocast:
                gen_tokens = self.lm.generate(
                    prompt_tokens, attributes,
                    callback=callback, max_gen_len=max_gen_len, **self.generation_params)
            chunk_index += 1
            if prompt_tokens is None:
                all_tokens.append(gen_tokens)
            else:
                # Only keep what was generated beyond the prompt of this chunk.
                all_tokens.append(gen_tokens[:, :, prompt_tokens.shape[-1]:])
            # The tail of this chunk becomes the prompt of the next one.
            prompt_tokens = gen_tokens[:, :, stride_tokens:]
            prompt_length = prompt_tokens.shape[-1]
            current_gen_offset += stride_tokens

        gen_tokens = torch.cat(all_tokens, dim=-1)
        return gen_tokens

    def generate_audio(self, gen_tokens: torch.Tensor):
        """Generate Audio from tokens.

        Args:
            gen_tokens (torch.Tensor): Tokens with exactly 3 dimensions.
        Returns:
            The output of the compression model's `decode`, computed without gradients.
        """
        assert gen_tokens.dim() == 3
        with torch.no_grad():
            gen_audio = self.compression_model.decode(gen_tokens, None)
        return gen_audio

    def to(self, device: str):
        """Move both models to `device` and return self.

        `self.device` and the autocast context are refreshed as well, so that tensors
        created afterwards land on the same device as the models.
        """
        self.compression_model.to(device)
        self.lm.to(device)
        self.device = next(iter(self.lm.parameters())).device
        self.autocast = self._build_autocast()
        return self
================================================
FILE: audiocraft/models/unet.py
================================================
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
"""
Pytorch Unet Module used for diffusion.
"""
from dataclasses import dataclass
import typing as tp
import torch
from torch import nn
from torch.nn import functional as F
from audiocraft.modules.transformer import StreamingTransformer, create_sin_embedding
@dataclass
class Output:
    """Container for the result of a diffusion model forward pass."""
    # Tensor produced by the model.
    sample: torch.Tensor
def get_model(cfg, channels: int, side: int, num_steps: int):
    """Build the diffusion model selected by `cfg.model`.

    Args:
        cfg: Config exposing `model` and, for 'unet', the `diffusion_unet` keyword arguments.
        channels (int): Number of input channels, passed as `chin`.
        side (int): Not used by this function.
        num_steps (int): Number of diffusion steps.
    Raises:
        RuntimeError: If `cfg.model` is anything other than 'unet'.
    """
    if cfg.model != 'unet':
        raise RuntimeError('Not Implemented')
    return DiffusionUnet(chin=channels, num_steps=num_steps, **cfg.diffusion_unet)
class ResBlock(nn.Module):
    """Residual block made of two (GroupNorm, activation, Conv1d, Dropout1d) stages.

    Args:
        channels (int): Number of input and output channels.
        kernel (int): Kernel size of both convolutions.
        norm_groups (int): Number of groups of both group norms.
        dilation (int): Dilation of both convolutions.
        activation (type of nn.Module): Activation class, instantiated without arguments.
        dropout (float): Dropout probability.
    """
    def __init__(self, channels: int, kernel: int = 3, norm_groups: int = 4,
                 dilation: int = 1, activation: tp.Type[nn.Module] = nn.ReLU,
                 dropout: float = 0.):
        super().__init__()
        # Both convolutions use stride 1, with padding derived from kernel and dilation.
        conv_stride = 1
        conv_padding = dilation * (kernel - conv_stride) // 2

        self.norm1 = nn.GroupNorm(norm_groups, channels)
        self.conv1 = nn.Conv1d(channels, channels, kernel, conv_stride, conv_padding, dilation=dilation)
        self.activation1 = activation()
        self.dropout1 = nn.Dropout1d(dropout)

        self.norm2 = nn.GroupNorm(norm_groups, channels)
        self.conv2 = nn.Conv1d(channels, channels, kernel, conv_stride, conv_padding, dilation=dilation)
        self.activation2 = activation()
        self.dropout2 = nn.Dropout1d(dropout)

    def forward(self, x):
        """Return `x` plus the output of the two stacked stages applied to `x`."""
        stages = (
            (self.norm1, self.activation1, self.conv1, self.dropout1),
            (self.norm2, self.activation2, self.conv2, self.dropout2),
        )
        hidden = x
        for norm, activation, conv, dropout in stages:
            hidden = dropout(conv(activation(norm(hidden))))
        return x + hidden
class DecoderLayer(nn.Module):
    """Decoder stage: residual blocks, group norm, activation, then a transposed convolution.

    Args:
        chin (int): Number of input channels.
        chout (int): Number of output channels.
        kernel (int): Kernel size of the transposed convolution.
        stride (int): Stride of the transposed convolution.
        norm_groups (int): Number of groups for the group norms.
        res_blocks (int): Number of residual blocks; block `idx` uses dilation `2**idx`.
        activation (type of nn.Module): Activation class applied before the transposed convolution.
        dropout (float): Dropout probability of the residual blocks.
    """
    def __init__(self, chin: int, chout: int, kernel: int = 4, stride: int = 2,
                 norm_groups: int = 4, res_blocks: int = 1, activation: tp.Type[nn.Module] = nn.ReLU,
                 dropout: float = 0.):
        super().__init__()
        convtr_padding = (kernel - stride) // 2
        residual_stack = [
            ResBlock(chin, norm_groups=norm_groups, dilation=2**idx, dropout=dropout)
            for idx in range(res_blocks)
        ]
        self.res_blocks = nn.Sequential(*residual_stack)
        self.norm = nn.GroupNorm(norm_groups, chin)
        self.convtr = nn.ConvTranspose1d(chin, chout, kernel, stride, convtr_padding, bias=False)
        self.activation = activation()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the residual blocks, the norm, the activation and the transposed convolution."""
        refined = self.res_blocks(x)
        activated = self.activation(self.norm(refined))
        return self.convtr(activated)
class EncoderLayer(nn.Module):
    """Encoder stage: strided convolution, group norm, activation, then residual blocks.

    Args:
        chin (int): Number of input channels.
        chout (int): Number of output channels.
        kernel (int): Kernel size of the convolution.
        stride (int): Stride of the convolution.
        norm_groups (int): Number of groups for the group norms.
        res_blocks (int): Number of residual blocks; block `idx` uses dilation `2**idx`.
        activation (type of nn.Module): Activation class applied after the norm.
        dropout (float): Dropout probability of the residual blocks.
    """
    def __init__(self, chin: int, chout: int, kernel: int = 4, stride: int = 2,
                 norm_groups: int = 4, res_blocks: int = 1, activation: tp.Type[nn.Module] = nn.ReLU,
                 dropout: float = 0.):
        super().__init__()
        conv_padding = (kernel - stride) // 2
        self.conv = nn.Conv1d(chin, chout, kernel, stride, conv_padding, bias=False)
        self.norm = nn.GroupNorm(norm_groups, chout)
        self.activation = activation()
        residual_stack = [
            ResBlock(chout, norm_groups=norm_groups, dilation=2**idx, dropout=dropout)
            for idx in range(res_blocks)
        ]
        self.res_blocks = nn.Sequential(*residual_stack)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Right-pad the last dimension to a multiple of the stride, then run the stage.

        The input must have exactly three dimensions.
        """
        _, _, length = x.shape
        stride, = self.conv.stride
        # Number of trailing zeros needed so that the length is divisible by the stride.
        missing = (stride - (length % stride)) % stride
        padded = F.pad(x, (0, missing))
        encoded = self.activation(self.norm(self.conv(padded)))
        return self.res_blocks(encoded)
class BLSTM(nn.Module):
    """BiLSTM with same hidden units as input dim.

    Args:
        dim (int): Input size, hidden size and output size.
        layers (int): Number of stacked LSTM layers.
    """
    def __init__(self, dim, layers=2):
        super().__init__()
        self.lstm = nn.LSTM(dim, dim, num_layers=layers, bidirectional=True)
        # Projects the concatenated forward/backward states (2 * dim) back to dim.
        self.linear = nn.Linear(2 * dim, dim)

    def forward(self, x):
        """Run the BiLSTM over the last dimension of a 3-dimensional input.

        The last dimension is moved to the front for the LSTM and moved back
        afterwards, so the output has the same shape as the input.
        """
        sequence_first = x.permute(2, 0, 1)
        lstm_out, _ = self.lstm(sequence_first)
        projected = self.linear(lstm_out)
        return projected.permute(1, 2, 0)
class DiffusionUnet(nn.Module):
    """1D U-Net used for diffusion, with an optional BiLSTM or transformer bottleneck.

    Args:
        chin (int): Number of input (and output) channels.
        hidden (int): Number of channels after the first encoder.
        depth (int): Number of encoder/decoder pairs.
        growth (float): Channel multiplier applied from one depth to the next.
        max_channels (int): Upper bound on the number of channels.
        num_steps (int): Number of diffusion steps, i.e. size of the step embedding tables.
        emb_all_layers (bool): If True, a step embedding is added after every encoder,
            otherwise only after the first one.
        cross_attention (bool): If True (and `transformer` is True), the condition is fed
            to the transformer through cross attention instead of being added to the bottleneck.
        bilstm (bool): Use a BiLSTM at the bottleneck when no transformer is used.
        transformer (bool): Use a transformer at the bottleneck.
        codec_dim (int, optional): Dimension of the conditioning; enables conditioning when set.
        **kwargs: Forwarded to every `EncoderLayer` and `DecoderLayer`.
    """
    def __init__(self, chin: int = 3, hidden: int = 24, depth: int = 3, growth: float = 2.,
                 max_channels: int = 10_000, num_steps: int = 1000, emb_all_layers=False,
                 cross_attention: bool = False, bilstm: bool = False, transformer: bool = False,
                 codec_dim: tp.Optional[int] = None, **kwargs):
        super().__init__()
        self.encoders = nn.ModuleList()
        self.decoders = nn.ModuleList()
        self.embeddings: tp.Optional[nn.ModuleList] = None
        self.embedding = nn.Embedding(num_steps, hidden)
        if emb_all_layers:
            self.embeddings = nn.ModuleList()
        self.condition_embedding: tp.Optional[nn.Module] = None
        for d in range(depth):
            encoder = EncoderLayer(chin, hidden, **kwargs)
            decoder = DecoderLayer(hidden, chin, **kwargs)
            self.encoders.append(encoder)
            # Decoders are stored in reverse order so that they mirror the encoders.
            self.decoders.insert(0, decoder)
            if emb_all_layers and d > 0:
                assert self.embeddings is not None
                self.embeddings.append(nn.Embedding(num_steps, hidden))
            chin = hidden
            hidden = min(int(chin * growth), max_channels)
        # From here on, `chin` is the number of channels at the bottleneck.
        self.bilstm: tp.Optional[nn.Module]
        if bilstm:
            self.bilstm = BLSTM(chin)
        else:
            self.bilstm = None
        self.use_transformer = transformer
        self.cross_attention = False
        if transformer:
            self.cross_attention = cross_attention
            self.transformer = StreamingTransformer(chin, 8, 6, bias_ff=False, bias_attn=False,
                                                    cross_attention=cross_attention)

        self.use_codec = False
        if codec_dim is not None:
            self.conv_codec = nn.Conv1d(codec_dim, chin, 1)
            self.use_codec = True

    def forward(self, x: torch.Tensor, step: tp.Union[int, torch.Tensor],
                condition: tp.Optional[torch.Tensor] = None):
        """Run the U-Net on `x` for the given diffusion step.

        Args:
            x (torch.Tensor): Input with three dimensions; the batch is the first one.
            step (int or torch.Tensor): Diffusion step. An int is broadcast over the batch;
                a tensor is used as is to index the step embeddings.
            condition (torch.Tensor, optional): Conditioning, required when the model
                was built with `codec_dim`.
        Returns:
            Output: wraps a tensor cropped to the last-dimension length of `x`.
        """
        skips = []
        bs = x.size(0)
        z = x
        if isinstance(step, torch.Tensor):
            step_tensor = step
        else:
            step_tensor = torch.tensor([step], device=x.device, dtype=torch.long).expand(bs)

        for idx, encoder in enumerate(self.encoders):
            z = encoder(z)
            if idx == 0:
                z = z + self.embedding(step_tensor).view(bs, -1, 1).expand_as(z)
            elif self.embeddings is not None:
                z = z + self.embeddings[idx - 1](step_tensor).view(bs, -1, 1).expand_as(z)

            skips.append(z)

        # Stays None unless the condition is routed through cross attention below; defined
        # up front so that a transformer bottleneck without conditioning does not fail.
        cross_attention_src = None
        if self.use_codec:  # insert condition in the bottleneck
            assert condition is not None, "Model defined for conditionnal generation"
            condition_emb = self.conv_codec(condition)  # reshape to the bottleneck dim
            assert condition_emb.size(-1) <= 2 * z.size(-1), \
                f"You are downsampling the conditionning with factor >=2 : {condition_emb.size(-1)=} and {z.size(-1)=}"
            if not self.cross_attention:
                condition_emb = torch.nn.functional.interpolate(condition_emb, z.size(-1))
                assert z.size() == condition_emb.size()
                # In-place on purpose left as is: when there is at least one encoder, `z`
                # is the same tensor object as the last skip, so the skip is updated too.
                z += condition_emb
            else:
                cross_attention_src = condition_emb.permute(0, 2, 1)  # B, T, C
                B, T, C = cross_attention_src.shape
                positions = torch.arange(T, device=x.device).view(1, -1, 1)
                pos_emb = create_sin_embedding(positions, C, max_period=10_000, dtype=cross_attention_src.dtype)
                cross_attention_src = cross_attention_src + pos_emb
        if self.use_transformer:
            z = self.transformer(z.permute(0, 2, 1), cross_attention_src=cross_attention_src).permute(0, 2, 1)
        else:
            if self.bilstm is None:
                # Without any bottleneck module only the skip connections reach the decoders.
                z = torch.zeros_like(z)
            else:
                z = self.bilstm(z)

        for decoder in self.decoders:
            s = skips.pop(-1)
            # Crop to the skip length before adding it.
            z = z[:, :, :s.shape[2]]
            z = z + s
            z = decoder(z)

        z = z[:, :, :x.shape[2]]
        return Output(z)
================================================
FILE: audiocraft/modules/__init__.py
================================================
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
"""Modules used for building the models."""
# flake8: noqa
from .conv import (
NormConv1d,
NormConv2d,
NormConvTranspose1d,
NormConvTranspose2d,
StreamableConv1d,
StreamableConvTranspose1d,
pad_for_conv1d,
pad1d,
unpad1d,
)
from .lstm import StreamableLSTM
from .seanet import SEANetEncoder, SEANetDecoder
from .transformer import StreamingTransformer
================================================
FILE: audiocraft/modules/activations.py
================================================
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import torch
import torch.nn as nn
from torch import Tensor
from typing import Union, Callable
class CustomGLU(nn.Module):
    r"""Gated Linear Unit with a configurable gate activation.

    The input is split in two equal halves :math:`a` and :math:`b` along ``dim`` and the
    module returns :math:`a * f(b)`, with :math:`f` the provided activation
    (e.g. sigmoid, swish, etc.).

    Args:
        activation (nn.Module): Activation applied to the gating half.
        dim (int): Dimension along which the input is split. Default: -1

    Shape:
        - Input: :math:`(\ast_1, N, \ast_2)` where `*` means any number of additional
          dimensions
        - Output: :math:`(\ast_1, M, \ast_2)` where :math:`M=N/2`

    Examples::
        >>> m = CustomGLU(nn.Sigmoid())
        >>> input = torch.randn(4, 2)
        >>> output = m(input)
    """

    def __init__(self, activation: nn.Module, dim: int = -1):
        super().__init__()
        self.dim = dim
        self.activation = activation

    def forward(self, x: Tensor):
        """Return ``a * activation(b)`` for the two halves of ``x`` along ``self.dim``."""
        split_size = x.shape[self.dim]
        assert split_size % 2 == 0  # M = N / 2
        value, gate = x.chunk(2, dim=self.dim)
        return value * self.activation(gate)
class SwiGLU(CustomGLU):
    """Gated Linear Unit whose gate goes through SiLU.

    Computes :math:`a * SiLU(b)`, :math:`a` and :math:`b` being the first and
    second halves of the input along ``dim``.

    Args:
        dim (int): Dimension along which the input is split. Default: -1
    """

    def __init__(self, dim: int = -1):
        super().__init__(activation=nn.SiLU(), dim=dim)
class GeGLU(CustomGLU):
    """Gated Linear Unit whose gate goes through GELU.

    Computes :math:`a * GELU(b)`, :math:`a` and :math:`b` being the first and
    second halves of the input along ``dim``.

    Args:
        dim (int): Dimension along which the input is split. Default: -1
    """

    def __init__(self, dim: int = -1):
        super().__init__(activation=nn.GELU(), dim=dim)
class ReGLU(CustomGLU):
    """Gated Linear Unit whose gate goes through ReLU.

    Computes :math:`a * ReLU(b)`, :math:`a` and :math:`b` being the first and
    second halves of the input along ``dim``.

    Args:
        dim (int): Dimension along which the input is split. Default: -1
    """

    def __init__(self, dim: int = -1):
        super().__init__(activation=nn.ReLU(), dim=dim)
def get_activation_fn(
    activation: Union[str, Callable[[Tensor], Tensor]]
) -> Union[str, Callable[[Tensor], Tensor]]:
    """Resolve the gated activation names to module instances.

    The strings ``"reglu"``, ``"geglu"`` and ``"swiglu"`` are turned into a new instance of the
    matching module. Anything else (another string or a callable) is handed back untouched.

    Args:
        activation (str, or Callable[[Tensor], Tensor]): Activation to check
    """
    if not isinstance(activation, str):
        return activation
    if activation == "reglu":
        return ReGLU()
    if activation == "geglu":
        return GeGLU()
    if activation == "swiglu":
        return SwiGLU()
    # Unrecognized name: left to the caller to interpret.
    return activation
================================================
FILE: audiocraft/modules/chroma.py
================================================
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import typing as tp
from einops import rearrange
from librosa import filters
import torch
from torch import nn
import torch.nn.functional as F
import torchaudio
class ChromaExtractor(nn.Module):
    """Chroma extraction and quantization.

    A power spectrogram of the waveform is projected on chroma filter banks, normalized
    along the chroma axis, and optionally turned into a one-hot encoding of the strongest bin.

    Args:
        sample_rate (int): Sample rate for the chroma extraction.
        n_chroma (int): Number of chroma bins for the chroma extraction.
        radix2_exp (int): Exponent giving the default window length ``2 ** radix2_exp``
            (e.g. 12 -> 2^12), used when ``winlen`` is not set.
        nfft (int, optional): Number of FFT. Falls back to the window length.
        winlen (int, optional): Window length.
        winhop (int, optional): Window hop size. Falls back to a quarter of the window length.
        argmax (bool, optional): Whether to use argmax. Defaults to False.
        norm (float, optional): Norm for chroma normalization. Defaults to inf.
    """

    def __init__(self, sample_rate: int, n_chroma: int = 12, radix2_exp: int = 12, nfft: tp.Optional[int] = None,
                 winlen: tp.Optional[int] = None, winhop: tp.Optional[int] = None, argmax: bool = False,
                 norm: float = torch.inf):
        super().__init__()
        self.winlen = winlen or 2 ** radix2_exp
        self.nfft = nfft or self.winlen
        self.winhop = winhop or (self.winlen // 4)
        self.sample_rate = sample_rate
        self.n_chroma = n_chroma
        self.norm = norm
        self.argmax = argmax
        chroma_filters = filters.chroma(sr=sample_rate, n_fft=self.nfft, tuning=0, n_chroma=self.n_chroma)
        # Non persistent: the filter banks are rebuilt from the arguments, not stored in checkpoints.
        self.register_buffer('fbanks', torch.from_numpy(chroma_filters), persistent=False)
        self.spec = torchaudio.transforms.Spectrogram(
            n_fft=self.nfft,
            win_length=self.winlen,
            hop_length=self.winhop,
            power=2,
            center=True,
            pad=0,
            normalized=True,
        )

    def forward(self, wav: torch.Tensor) -> torch.Tensor:
        """Compute the normalized chroma of ``wav``, with the chroma bins on the last axis."""
        missing = self.nfft - wav.shape[-1]
        # in case we are getting a wav that was dropped out (nullified)
        # from the conditioner, make sure wav length is no less that nfft
        if missing > 0:
            left = missing // 2
            right = missing - left  # the extra sample of an odd padding goes on the right
            wav = F.pad(wav, (left, right), 'constant', 0)
            assert wav.shape[-1] == self.nfft, f"expected len {self.nfft} but got {wav.shape[-1]}"

        spectrogram = self.spec(wav).squeeze(1)
        chroma = torch.einsum('cf,...ft->...ct', self.fbanks, spectrogram)
        chroma = F.normalize(chroma, p=self.norm, dim=-2, eps=1e-6)
        chroma = rearrange(chroma, 'b d t -> b t d')

        if not self.argmax:
            return chroma
        # Keep a single active bin per frame: the one holding the maximum value.
        strongest = chroma.argmax(-1, keepdim=True)
        chroma[:] = 0
        chroma.scatter_(dim=-1, index=strongest, value=1)
        return chroma
================================================
FILE: audiocraft/modules/codebooks_patterns.py
================================================
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
from collections import namedtuple
from dataclasses import dataclass
from functools import lru_cache
import logging
import typing as tp
from abc import ABC, abstractmethod
import torch
# Coordinate of a single code in the original multi-codebook sequence:
# `t` is the timestep and `q` the codebook index.
LayoutCoord = namedtuple('LayoutCoord', ['t', 'q'])
# A pattern layout: for each sequence step, the list of coordinates placed at that step.
PatternLayout = tp.List[tp.List[LayoutCoord]]
# Logger named after this module.
logger = logging.getLogger(__name__)
@dataclass
class Pattern:
    """Base implementation of a pattern over a sequence with multiple codebooks.

    The codebook pattern consists in a layout, defining for each sequence step
    the list of coordinates of each codebook timestep in the resulting interleaved sequence.
    The first item of the pattern is always an empty list in order to properly insert a special token
    to start with. For convenience, we also keep track of ``n_q`` the number of codebooks used for the pattern
    and ``timesteps`` the number of timesteps corresponding to the original sequence.

    The pattern provides convenient methods to build and revert interleaved sequences from it:
    ``build_pattern_sequence`` maps a given a dense input tensor of multi-codebook sequence from [B, K, T]
    to the interleaved sequence of shape [B, K, S] applying the pattern, with B being the batch size,
    K being the number of codebooks, T the number of original timesteps and S the number of sequence steps
    for the output sequence. The unfilled positions are replaced with a special token and the built sequence
    is returned along with a mask indicating valid tokens.
    ``revert_pattern_sequence`` maps back an interleaved sequence of shape [B, K, S] to the original alignment
    of codebooks across timesteps to an output tensor of shape [B, K, T], using again a special token and a mask
    to fill and specify invalid positions if needed.
    See the dedicated methods for more details.
    """
    # Pattern layout, for each sequence step, we have a list of coordinates
    # corresponding to the original codebook timestep and position.
    # The first list is always an empty list in order to properly insert
    # a special token to start with.
    layout: PatternLayout
    timesteps: int
    n_q: int

    def __post_init__(self):
        assert len(self.layout) > 0
        assert self.layout[0] == []
        self._validate_layout()
        # Per-instance caches: the index builders only depend on their (hashable) arguments.
        self._build_reverted_sequence_scatter_indexes = lru_cache(100)(self._build_reverted_sequence_scatter_indexes)
        self._build_pattern_sequence_scatter_indexes = lru_cache(100)(self._build_pattern_sequence_scatter_indexes)
        logger.info("New pattern, time steps: %d, sequence steps: %d", self.timesteps, len(self.layout))

    def _validate_layout(self):
        """Runs checks on the layout to ensure a valid pattern is defined.
        A pattern is considered invalid if:
            - Multiple timesteps for a same codebook are defined in the same sequence step
            - The timesteps for a given codebook are not in ascending order as we advance in the sequence
              (this would mean that we have future timesteps before past timesteps).

        Raises:
            AssertionError: If one of the conditions above is violated.
        """
        q_timesteps = {q: 0 for q in range(self.n_q)}
        for s, seq_coords in enumerate(self.layout):
            if len(seq_coords) > 0:
                qs = set()
                for coord in seq_coords:
                    qs.add(coord.q)
                    last_q_timestep = q_timesteps[coord.q]
                    assert coord.t >= last_q_timestep, \
                        f"Past timesteps are found in the sequence for codebook = {coord.q} at step {s}"
                    q_timesteps[coord.q] = coord.t
                # each sequence step contains at max 1 coordinate per codebook
                assert len(qs) == len(seq_coords), \
                    f"Multiple entries for a same codebook are found at step {s}"

    @property
    def num_sequence_steps(self):
        """Number of sequence steps in the layout, not counting the initial empty step."""
        return len(self.layout) - 1

    @property
    def max_delay(self):
        """Number of timesteps found in the layout beyond ``self.timesteps``."""
        max_t_in_seq_coords = 0
        for seq_coords in self.layout[1:]:
            for coords in seq_coords:
                max_t_in_seq_coords = max(max_t_in_seq_coords, coords.t + 1)
        return max_t_in_seq_coords - self.timesteps

    @property
    def valid_layout(self):
        """Layout truncated by ``max_delay`` trailing steps."""
        valid_step = len(self.layout) - self.max_delay
        return self.layout[:valid_step]

    def get_sequence_coords_with_timestep(self, t: int, q: tp.Optional[int] = None):
        """Get codebook coordinates in the layout that corresponds to the specified timestep t
        and optionally to the codebook q. Coordinates are returned as a tuple with the sequence step
        and the actual codebook coordinates.

        Args:
            t (int): Timestep to look for.
            q (int, optional): If provided, only coordinates of that codebook are returned.
        Returns:
            list: Tuples ``(sequence step, coordinate)``, in layout order.
        """
        assert t <= self.timesteps, "provided timesteps is greater than the pattern's number of timesteps"
        if q is not None:
            assert q <= self.n_q, "provided number of codebooks is greater than the pattern's number of codebooks"
        coords = []
        for s, seq_codes in enumerate(self.layout):
            for code in seq_codes:
                if code.t == t and (q is None or code.q == q):
                    coords.append((s, code))
        return coords

    def get_steps_with_timestep(self, t: int, q: tp.Optional[int] = None) -> tp.List[int]:
        """Sequence steps holding timestep ``t`` (optionally restricted to codebook ``q``)."""
        return [step for step, _ in self.get_sequence_coords_with_timestep(t, q)]

    def get_first_step_with_timesteps(self, t: int, q: tp.Optional[int] = None) -> tp.Optional[int]:
        """First sequence step holding timestep ``t``, or None when there is no such step."""
        steps_with_timesteps = self.get_steps_with_timestep(t, q)
        return steps_with_timesteps[0] if len(steps_with_timesteps) > 0 else None

    def _build_pattern_sequence_scatter_indexes(self, timesteps: int, n_q: int, keep_only_valid_steps: bool,
                                                device: tp.Union[torch.device, str] = 'cpu'):
        """Build scatter indexes corresponding to the pattern, up to the provided sequence_steps.

        Args:
            timesteps (int): Maximum number of timesteps steps to consider.
            n_q (int): Number of codebooks, must match the pattern's ``n_q``.
            keep_only_valid_steps (bool): Restrict the pattern layout to match only valid steps.
            device (torch.device or str): Device for created tensors.
        Returns:
            indexes (torch.Tensor): Indexes corresponding to the sequence, of shape [K, S].
            mask (torch.Tensor): Mask corresponding to indexes that matches valid indexes, of shape [K, S].
        """
        assert n_q == self.n_q, f"invalid number of codebooks for the sequence and the pattern: {n_q} != {self.n_q}"
        assert timesteps <= self.timesteps, "invalid number of timesteps used to build the sequence from the pattern"
        # use the proper layout based on whether we limit ourselves to valid steps only or not,
        # note that using the valid_layout will result in a truncated sequence up to the valid steps
        ref_layout = self.valid_layout if keep_only_valid_steps else self.layout
        # single item indexing being super slow with pytorch vs. numpy, so we use numpy here
        indexes = torch.zeros(n_q, len(ref_layout), dtype=torch.long).numpy()
        mask = torch.zeros(n_q, len(ref_layout), dtype=torch.bool).numpy()
        # fill indexes with last sequence step value that will correspond to our special token
        # the last value is n_q * timesteps as we have flattened z and append special token as the last token
        # which will correspond to the index: n_q * timesteps
        indexes[:] = n_q * timesteps
        # iterate over the pattern and fill scattered indexes and mask
        for s, sequence_coords in enumerate(ref_layout):
            for coords in sequence_coords:
                if coords.t < timesteps:
                    indexes[coords.q, s] = coords.t + coords.q * timesteps
                    mask[coords.q, s] = 1
        indexes = torch.from_numpy(indexes).to(device)
        mask = torch.from_numpy(mask).to(device)
        return indexes, mask

    def build_pattern_sequence(self, z: torch.Tensor, special_token: int, keep_only_valid_steps: bool = False):
        """Build sequence corresponding to the pattern from the input tensor z.
        The sequence is built using up to sequence_steps if specified, and non-pattern
        coordinates are filled with the special token.

        Args:
            z (torch.Tensor): Input tensor of multi-codebooks sequence, of shape [B, K, T].
                It does not need to be contiguous.
            special_token (int): Special token used to fill non-pattern coordinates in the new sequence.
            keep_only_valid_steps (bool): Build a sequence from the pattern up to valid (= fully defined) steps.
                Steps that are beyond valid steps will be replaced by the special_token in that case.
        Returns:
            values (torch.Tensor): Interleaved sequence matching the pattern, of shape [B, K, S] with S
                corresponding either to the sequence_steps if provided, otherwise to the length of the pattern.
            indexes (torch.Tensor): Indexes corresponding to the interleaved sequence, of shape [K, S].
            mask (torch.Tensor): Mask corresponding to indexes that matches valid indexes of shape [K, S].
        """
        B, K, T = z.shape
        indexes, mask = self._build_pattern_sequence_scatter_indexes(
            T, K, keep_only_valid_steps=keep_only_valid_steps, device=str(z.device)
        )
        # reshape (not view) so that non-contiguous inputs are accepted, as in revert_pattern_logits
        z = z.reshape(B, -1)
        # we append the special token as the last index of our flattened z tensor
        z = torch.cat([z, torch.zeros_like(z[:, :1]) + special_token], dim=1)
        values = z[:, indexes.view(-1)]
        values = values.view(B, K, indexes.shape[-1])
        return values, indexes, mask

    def _build_reverted_sequence_scatter_indexes(self, sequence_steps: int, n_q: int,
                                                 keep_only_valid_steps: bool = False,
                                                 is_model_output: bool = False,
                                                 device: tp.Union[torch.device, str] = 'cpu'):
        """Builds scatter indexes required to retrieve the original multi-codebook sequence
        from interleaving pattern.

        Args:
            sequence_steps (int): Sequence steps.
            n_q (int): Number of codebooks.
            keep_only_valid_steps (bool): Build a sequence from the pattern up to valid (= fully defined) steps.
                Steps that are beyond valid steps will be replaced by the special_token in that case.
            is_model_output (bool): Whether to keep the sequence item corresponding to initial special token or not.
            device (torch.device or str): Device for created tensors.
        Returns:
            indexes (torch.Tensor): Indexes for reconstructing the output, of shape [K, T].
            mask (torch.Tensor): Mask corresponding to indexes that matches valid indexes of shape [K, T].
        """
        ref_layout = self.valid_layout if keep_only_valid_steps else self.layout
        # TODO(jade): Do we want to further truncate to only valid timesteps here as well?
        timesteps = self.timesteps
        assert n_q == self.n_q, f"invalid number of codebooks for the sequence and the pattern: {n_q} != {self.n_q}"
        assert sequence_steps <= len(ref_layout), \
            f"sequence to revert is longer than the defined pattern: {sequence_steps} > {len(ref_layout)}"

        # ensure we take the appropriate indexes to keep the model output from the first special token as well
        if is_model_output:
            ref_layout = ref_layout[1:]

        # single item indexing being super slow with pytorch vs. numpy, so we use numpy here
        indexes = torch.zeros(n_q, timesteps, dtype=torch.long).numpy()
        mask = torch.zeros(n_q, timesteps, dtype=torch.bool).numpy()
        # fill indexes with last sequence step value that will correspond to our special token
        indexes[:] = n_q * sequence_steps
        for s, sequence_codes in enumerate(ref_layout):
            if s < sequence_steps:
                for code in sequence_codes:
                    if code.t < timesteps:
                        indexes[code.q, code.t] = s + code.q * sequence_steps
                        mask[code.q, code.t] = 1
        indexes = torch.from_numpy(indexes).to(device)
        mask = torch.from_numpy(mask).to(device)
        return indexes, mask

    def revert_pattern_sequence(self, s: torch.Tensor, special_token: int, keep_only_valid_steps: bool = False):
        """Revert a sequence built from the pattern back to the original multi-codebook sequence without interleaving.
        The sequence is reverted using up to timesteps if specified, and non-pattern coordinates
        are filled with the special token.

        Args:
            s (torch.Tensor): Interleaved sequence tensor obtained from the pattern, of shape [B, K, S].
                It does not need to be contiguous.
            special_token (int or float): Special token used to fill non-pattern coordinates in the new sequence.
            keep_only_valid_steps (bool): Only use the valid (= fully defined) steps of the pattern.
        Returns:
            values (torch.Tensor): Interleaved sequence matching the pattern, of shape [B, K, T] with T
                corresponding either to the timesteps if provided, or the total timesteps in pattern otherwise.
            indexes (torch.Tensor): Indexes corresponding to the interleaved sequence, of shape [K, T].
            mask (torch.Tensor): Mask corresponding to indexes that matches valid indexes of shape [K, T].
        """
        B, K, S = s.shape
        indexes, mask = self._build_reverted_sequence_scatter_indexes(
            S, K, keep_only_valid_steps, is_model_output=False, device=str(s.device)
        )
        # reshape (not view) so that non-contiguous inputs are accepted, as in revert_pattern_logits
        s = s.reshape(B, -1)
        # we append the special token as the last index of our flattened z tensor
        s = torch.cat([s, torch.zeros_like(s[:, :1]) + special_token], dim=1)
        values = s[:, indexes.view(-1)]
        values = values.view(B, K, indexes.shape[-1])
        return values, indexes, mask

    def revert_pattern_logits(self, logits: torch.Tensor, special_token: float, keep_only_valid_steps: bool = False):
        """Revert model logits obtained on a sequence built from the pattern
        back to a tensor matching the original sequence.

        This method is similar to ``revert_pattern_sequence`` with the following specificities:
        1. It is designed to work with the extra cardinality dimension
        2. We return the logits for the first sequence item that matches the special_token and
        which matching target in the original sequence is the first item of the sequence,
        while we skip the last logits as there is no matching target

        Args:
            logits (torch.Tensor): Logits of shape [B, card, K, S].
            special_token (float): Value used to fill positions without a matching logit.
            keep_only_valid_steps (bool): Only use the valid (= fully defined) steps of the pattern.
        Returns:
            values (torch.Tensor): Reverted logits of shape [B, card, K, T].
            indexes (torch.Tensor): Indexes used to build ``values``, of shape [K, T].
            mask (torch.Tensor): Mask of the positions filled from actual logits, of shape [K, T].
        """
        B, card, K, S = logits.shape
        indexes, mask = self._build_reverted_sequence_scatter_indexes(
            S, K, keep_only_valid_steps, is_model_output=True, device=logits.device
        )
        logits = logits.reshape(B, card, -1)
        # we append the special token as the last index of our flattened z tensor
        logits = torch.cat([logits, torch.zeros_like(logits[:, :, :1]) + special_token], dim=-1)  # [B, card, K x S + 1]
        values = logits[:, :, indexes.view(-1)]
        values = values.view(B, card, K, indexes.shape[-1])
        return values, indexes, mask
class CodebooksPatternProvider(ABC):
    """Abstraction around providing pattern for interleaving codebooks.

    The CodebooksPatternProvider abstraction allows to implement various strategies to
    define interleaving pattern of sequences composed of multiple codebooks. For a given
    number of codebooks `n_q`, the pattern provider can generate a specified pattern
    corresponding to a sequence of `T` timesteps with `n_q` parallel codebooks. This pattern
    can be used to construct a new sequence from the original codes respecting the specified
    pattern. The pattern is defined as a list of list of code coordinates, code coordinate
    being a tuple with the original timestep and codebook to build the new sequence.
    Note that all patterns must start with an empty list that is then used to insert a first
    sequence step of special tokens in the newly generated sequence.

    Args:
        n_q (int): number of codebooks.
        cached (bool): if True, patterns for a given length are cached (up to 100 distinct
            lengths per provider instance). In general that should be true for efficiency
            reason to avoid synchronization points. If False, ``get_pattern`` is left
            uncached and builds a new pattern on every call.
    """
    def __init__(self, n_q: int, cached: bool = True):
        assert n_q > 0
        self.n_q = n_q
        if cached:
            # Shadow the method with a per-instance memoized version.
            self.get_pattern = lru_cache(100)(self.get_pattern)  # type: ignore

    @abstractmethod
    def get_pattern(self, timesteps: int) -> Pattern:
        """Builds pattern with specific interleaving between codebooks.

        Args:
            timesteps (int): Total number of timesteps.
        """
        raise NotImplementedError()
class DelayedPatternProvider(CodebooksPatternProvider):
    """Provider for delayed pattern across delayed codebooks.
    Codebooks are delayed in the sequence and sequence steps will contain codebooks
    from different timesteps.

    Example:
        Taking timesteps=4 and n_q=3, delays=None, the multi-codebook sequence:
        [[1, 2, 3, 4],
        [1, 2, 3, 4],
        [1, 2, 3, 4]]
        The sequence built from the full returned pattern is:
        [[S, 1, 2, 3, 4, S, S],
        [S, S, 1, 2, 3, 4, S],
        [S, S, S, 1, 2, 3, 4]]
        (with S being a special token)

    Args:
        n_q (int): Number of codebooks.
        delays (list of int, optional): Delay for each of the codebooks, in ascending order.
            If delays not defined, each codebook is delayed by 1 compared to the previous one.
        flatten_first (int): Flatten the first N timesteps.
        empty_initial (int): Prepend with N empty list of coordinates.
    """
    def __init__(self, n_q: int, delays: tp.Optional[tp.List[int]] = None,
                 flatten_first: int = 0, empty_initial: int = 0):
        super().__init__(n_q)
        self.delays = list(range(n_q)) if delays is None else delays
        self.flatten_first = flatten_first
        self.empty_initial = empty_initial
        assert len(self.delays) == self.n_q
        assert sorted(self.delays) == self.delays

    def get_pattern(self, timesteps: int) -> Pattern:
        """Build the delayed pattern for ``timesteps`` timesteps."""
        largest_delay = max(self.delays)
        # The leading empty step hosts the initial special token, then come the optional extra ones.
        layout: PatternLayout = [[]]
        layout.extend([] for _ in range(self.empty_initial))
        # Flattened prefix: one sequence step per (timestep, codebook) pair.
        for t in range(min(timesteps, self.flatten_first)):
            layout.extend([LayoutCoord(t, q)] for q in range(self.n_q))
        # Delayed part: codebook q shows timestep `t - delay` once past the flattened prefix.
        for t in range(self.flatten_first, timesteps + largest_delay):
            layout.append([
                LayoutCoord(t - delay, q)
                for q, delay in enumerate(self.delays)
                if t - delay >= self.flatten_first
            ])
        return Pattern(layout, n_q=self.n_q, timesteps=timesteps)
class ParallelPatternProvider(DelayedPatternProvider):
    """Provider for parallel pattern across codebooks.
    This is the delayed pattern configured with a delay of 0 for each of the ``n_q`` codebooks.

    Args:
        n_q (int): Number of codebooks.
    """
    def __init__(self, n_q: int):
        no_delay = [0] * n_q
        super().__init__(n_q, delays=no_delay)
class UnrolledPatternProvider(CodebooksPatternProvider):
    """Provider for unrolling codebooks pattern.
    This pattern provider enables to represent the codebook flattened completely or only to some extend
    while also specifying a given delay between the flattened codebooks representation, allowing to
    unroll the codebooks in the sequence.

    Each timestep is unrolled into a number of inner steps. ``flattening`` gives, for each codebook,
    the inner step at which it is emitted: codebooks sharing the same inner step are emitted in
    parallel, the others one after the other. ``delays`` further shifts the position of the steps
    emitted for a codebook.

    Example:
        With the default full flattening, n_q = 2 and timesteps = 2, the layout is:
        [[], [(t=0, q=0)], [(t=0, q=1)], [(t=1, q=0)], [(t=1, q=1)]]

    Args:
        n_q (int): Number of codebooks.
        flattening (list of int, optional): Flattening schema over the codebooks, in ascending order.
            If not defined, the codebooks will be flattened to 1 codebook per step, meaning that
            the sequence will have n_q extra steps for each timestep.
        delays (list of int, optional): Delay for each of the codebooks, in ascending order.
            If not defined, no delay is added and therefore will default to [0] * ``n_q``.
            Note that two codebooks that will be flattened to the same inner step
            should have the same delay, otherwise the pattern is considered as invalid.
    """
    FlattenedCodebook = namedtuple('FlattenedCodebook', ['codebooks', 'delay'])

    def __init__(self, n_q: int, flattening: tp.Optional[tp.List[int]] = None,
                 delays: tp.Optional[tp.List[int]] = None):
        super().__init__(n_q)
        flattening = list(range(n_q)) if flattening is None else flattening
        delays = [0] * n_q if delays is None else delays
        assert len(flattening) == n_q
        assert len(delays) == n_q
        assert sorted(flattening) == flattening
        assert sorted(delays) == delays
        self._flattened_codebooks = self._build_flattened_codebooks(delays, flattening)
        self.max_delay = max(delays)

    def _build_flattened_codebooks(self, delays: tp.List[int], flattening: tp.List[int]):
        """Group the codebook indices by inner step.

        Returns a dictionary mapping each inner step to a ``FlattenedCodebook`` holding the codebook
        indices emitted at that step, along with their common delay so that no extra mapping is needed.
        """
        by_inner_step: dict = {}
        for q, (inner_step, delay) in enumerate(zip(flattening, delays)):
            group = by_inner_step.get(inner_step)
            if group is None:
                by_inner_step[inner_step] = UnrolledPatternProvider.FlattenedCodebook(codebooks=[q], delay=delay)
                continue
            assert group.delay == delay, (
                "Delay and flattening between codebooks is inconsistent: ",
                "two codebooks flattened to the same position should have the same delay."
            )
            group.codebooks.append(q)
        return by_inner_step

    @property
    def _num_inner_steps(self):
        """Number of inner steps to unroll between timesteps in order to flatten the codebooks.
        """
        return max(self._flattened_codebooks) + 1

    def num_virtual_steps(self, timesteps: int) -> int:
        """Number of inner steps over ``timesteps`` timesteps, plus one."""
        return timesteps * self._num_inner_steps + 1

    def get_pattern(self, timesteps: int) -> Pattern:
        """Builds pattern for delay across codebooks.

        Args:
            timesteps (int): Total number of timesteps.
        """
        inner_steps = self._num_inner_steps
        horizon = timesteps + self.max_delay
        # Entries are (position, coordinates) pairs so that they can be reordered
        # once the delay of each group of codebooks has been applied.
        positioned: list = [(-1, [])]
        for t in range(horizon):
            for inner in range(inner_steps):
                group = self._flattened_codebooks.get(inner)
                if group is None:
                    # there is no codebook in this virtual step so we emit an empty list
                    positioned.append((t, []))
                    continue
                delayed_position = t + group.delay
                if delayed_position < horizon:
                    positioned.append((delayed_position, [LayoutCoord(t, q) for q in group.codebooks]))
        # Plain tuple ordering: by position first, entries sharing a position are
        # then ordered by comparing their coordinate lists.
        layout = [coords for _, coords in sorted(positioned)]
        return Pattern(layout, n_q=self.n_q, timesteps=timesteps)
class VALLEPattern(CodebooksPatternProvider):
    """Almost VALL-E style pattern.
    The first codebook is laid out on its own, one timestep per sequence step, and is followed
    by the remaining codebooks emitted together. We further allow some delays for the codebooks
    other than the first one.

    Args:
        n_q (int): Number of codebooks.
        delays (list of int, optional): Delay for each of the codebooks other than the first one
            (``n_q - 1`` values, in ascending order). If delays not defined, no delay is applied.
    """
    def __init__(self, n_q: int, delays: tp.Optional[tp.List[int]] = None):
        super().__init__(n_q)
        if delays is None:
            delays = [0] * (n_q - 1)
        self.delays = delays
        assert len(self.delays) == self.n_q - 1
        assert sorted(self.delays) == self.delays

    def get_pattern(self, timesteps: int) -> Pattern:
        """Build the pattern for ``timesteps`` timesteps.

        Args:
            timesteps (int): Total number of timesteps.
        """
        out: PatternLayout = [[]]
        # First codebook alone, over all the timesteps.
        for t in range(timesteps):
            out.append([LayoutCoord(t, 0)])
        # With a single codebook there is no delayed part to emit, and `max` of the
        # empty delays list would raise a ValueError.
        if self.delays:
            max_delay = max(self.delays)
            for t in range(timesteps + max_delay):
                v = []
                for q, delay in enumerate(self.delays):
                    t_for_q = t - delay
                    if t_for_q >= 0:
                        # `delays` skips the first codebook, hence the q + 1.
                        v.append(LayoutCoord(t_for_q, q + 1))
                out.append(v)
        return Pattern(out, n_q=self.n_q, timesteps=timesteps)
class MusicLMPattern(CodebooksPatternProvider):
    """Almost MusicLM style pattern. This is equivalent to full flattening
    but in a different order: codebooks are processed group after group, and each group
    is fully flattened over all the timesteps before moving on to the next one.

    Args:
        n_q (int): Number of codebooks.
        group_by (int): Number of codebooks to group together. When ``n_q`` is not a multiple
            of ``group_by``, the last group holds the remaining codebooks.
    """
    def __init__(self, n_q: int, group_by: int = 2):
        super().__init__(n_q)
        self.group_by = group_by

    def get_pattern(self, timesteps: int) -> Pattern:
        """Build the pattern for ``timesteps`` timesteps.

        Args:
            timesteps (int): Total number of timesteps.
        """
        out: PatternLayout = [[]]
        for offset in range(0, self.n_q, self.group_by):
            # Clamp the group so that no coordinate refers to a codebook index >= n_q.
            group_end = min(offset + self.group_by, self.n_q)
            for t in range(timesteps):
                for q in range(offset, group_end):
                    out.append([LayoutCoord(t, q)])
        return Pattern(out, n_q=self.n_q, timesteps=timesteps)
================================================
FILE: audiocraft/modules/conditioners.py
================================================
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
from collections import defaultdict
from copy import deepcopy
from dataclasses import dataclass, field
from itertools import chain
import logging
import math
from pathlib import Path
import random
import re
import typing as tp
import warnings
import einops
from num2words import num2words
import spacy
from transformers import RobertaTokenizer, T5EncoderModel, T5Tokenizer # type: ignore
import torch
from torch import nn
import torch.nn.functional as F
from torch.nn.utils.rnn import pad_sequence
from .chroma import ChromaExtractor
from .streaming import StreamingModule
from .transformer import create_sin_embedding
from ..data.audio import audio_read
from ..data.audio_dataset import SegmentInfo
from ..data.audio_utils import convert_audio
from ..environment import AudioCraftEnvironment
from ..quantization import ResidualVectorQuantizer
from ..utils.autocast import TorchAutocast
from ..utils.cache import EmbeddingCache
from ..utils.utils import collate, hash_trick, length_to_mask, load_clap_state_dict, warn_once
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)
# A text condition is either a string or None (when the sample has no such attribute).
TextCondition = tp.Optional[str]  # a text condition can be a string or None (if doesn't exist)
# Output of a conditioner: the condition tensor paired with its mask tensor.
ConditionType = tp.Tuple[torch.Tensor, torch.Tensor]  # condition, mask
class WavCondition(tp.NamedTuple):
    """Waveform condition: a batch of audio together with per-item metadata.

    NOTE: the `[]` defaults of `path` and `seek_time` are single list objects
    shared by every instance created without those fields; do not mutate them
    in place.
    """
    # Waveform tensor; first dimension is the batch (exact rank is not fixed here — TODO confirm with callers).
    wav: torch.Tensor
    # Length tensor associated with the waveforms.
    length: torch.Tensor
    # One sample rate per item.
    sample_rate: tp.List[int]
    # Optional source path per item (None when unknown).
    path: tp.List[tp.Optional[str]] = []
    # Optional seek time per item (None when unknown).
    seek_time: tp.List[tp.Optional[float]] = []
class JointEmbedCondition(tp.NamedTuple):
    """Joint audio/text condition: a batch of audio, the matching texts and metadata.

    NOTE: the `[]` defaults of `path` and `seek_time` are single list objects
    shared by every instance created without those fields; do not mutate them
    in place.
    """
    # Waveform tensor; first dimension is the batch.
    wav: torch.Tensor
    # One optional text per item (None when the item has no text).
    text: tp.List[tp.Optional[str]]
    # Length tensor associated with the waveforms.
    length: torch.Tensor
    # One sample rate per item.
    sample_rate: tp.List[int]
    # Optional source path per item (None when unknown).
    path: tp.List[tp.Optional[str]] = []
    # Optional seek time per item (None when unknown).
    seek_time: tp.List[tp.Optional[float]] = []
@dataclass
class ConditioningAttributes:
    """Container grouping the conditions of one sample by kind.

    Each field maps an attribute name to its condition. Fields can also be
    accessed by name through `attrs["text"]`, `attrs["wav"]`, etc.
    """
    text: tp.Dict[str, tp.Optional[str]] = field(default_factory=dict)
    wav: tp.Dict[str, "WavCondition"] = field(default_factory=dict)
    joint_embed: tp.Dict[str, "JointEmbedCondition"] = field(default_factory=dict)

    def __getitem__(self, item):
        """Return the field named `item` (e.g. "text", "wav", "joint_embed")."""
        return getattr(self, item)

    @property
    def text_attributes(self):
        """Names of the text conditions."""
        return self.text.keys()

    @property
    def wav_attributes(self):
        """Names of the waveform conditions."""
        return self.wav.keys()

    @property
    def joint_embed_attributes(self):
        """Names of the joint embedding conditions."""
        return self.joint_embed.keys()

    @property
    def attributes(self):
        """Mapping from condition kind to the names of its conditions."""
        return {
            "text": self.text_attributes,
            "wav": self.wav_attributes,
            "joint_embed": self.joint_embed_attributes,
        }

    def to_flat_dict(self):
        """Flatten all conditions into one dict keyed by `"<kind>.<attribute>"`."""
        return {
            **{f"text.{k}": v for k, v in self.text.items()},
            **{f"wav.{k}": v for k, v in self.wav.items()},
            **{f"joint_embed.{k}": v for k, v in self.joint_embed.items()}
        }

    @classmethod
    def from_flat_dict(cls, x):
        """Rebuild an instance from a dict produced by `to_flat_dict`.

        Args:
            x (dict): Mapping from `"<kind>.<attribute>"` keys to conditions.
        Returns:
            ConditioningAttributes: Instance holding the given conditions.
        """
        out = cls()
        for k, v in x.items():
            # Split on the first dot only: the kind never contains a dot, but the
            # attribute name may, and must survive a to_flat_dict round trip.
            kind, att = k.split(".", 1)
            out[kind][att] = v
        return out
class SegmentWithAttributes(SegmentInfo):
    """Parent of every segment dataclass that can be turned into conditions.

    Subclasses are expected to override `to_condition_attributes` so that their
    own fields get converted into a ConditioningAttributes instance.
    """
    def to_condition_attributes(self) -> ConditioningAttributes:
        """Convert this segment's attributes; must be provided by subclasses."""
        raise NotImplementedError
def nullify_condition(condition: ConditionType, dim: int = 1):
    """Turn a condition into its null counterpart.

    The condition is reduced to a single zero-valued slice along `dim`, in the
    same spirit as what WhiteSpaceTokenizer and NoopTokenizer produce for
    missing inputs, and the mask is replaced by a zero mask of shape [B, 1].

    Args:
        condition (ConditionType): A tuple of condition and mask (tuple[torch.Tensor, torch.Tensor])
        dim (int): The dimension that gets truncated (should be the time dimension).
            WARNING!: dim should not be the batch dimension!
    Returns:
        ConditionType: A tuple of null condition and mask
    """
    assert dim != 0, "dim cannot be the batch dimension!"
    is_tensor_pair = (
        isinstance(condition, tuple)
        and isinstance(condition[0], torch.Tensor)
        and isinstance(condition[1], torch.Tensor)
    )
    assert is_tensor_pair, "'nullify_condition' got an unexpected input type!"
    cond, _ = condition
    batch_size = cond.shape[0]
    tail = cond.dim() - 1
    # Bring `dim` to the last position, keep one zeroed element, then move it back.
    null_cond = (0. * cond.transpose(dim, tail)[..., :1]).transpose(dim, tail)
    null_mask = torch.zeros((batch_size, 1), device=null_cond.device).int()
    assert cond.dim() == null_cond.dim()
    return null_cond, null_mask
def nullify_wav(cond: WavCondition) -> WavCondition:
    """Build the null version of a WavCondition.

    The waveform becomes a null tensor (via `nullify_condition` on the last
    dimension), every length is set to 0, and path / seek time are replaced by
    None placeholders. The sample rates are carried over unchanged.

    Args:
        cond (WavCondition): Wav condition with wav, tensor of shape [B, T].
    Returns:
        WavCondition: Nullified wav condition.
    """
    wav = cond.wav
    batch_size = wav.shape[0]
    null_wav, _ = nullify_condition((wav, torch.zeros_like(wav)), dim=wav.dim() - 1)
    zero_lengths = torch.tensor([0] * batch_size, device=wav.device)
    return WavCondition(
        wav=null_wav,
        length=zero_lengths,
        sample_rate=cond.sample_rate,
        path=[None] * batch_size,
        seek_time=[None] * batch_size,
    )
def nullify_joint_embed(embed: JointEmbedCondition) -> JointEmbedCondition:
    """Nullify the joint embedding condition by replacing it by a null tensor, forcing its length to 0,
    and replacing metadata by dummy attributes.

    Args:
        embed (JointEmbedCondition): Joint embedding condition with wav and text, wav tensor of shape [B, C, T].
    Returns:
        JointEmbedCondition: Nullified joint embedding condition.
    """
    batch_size = embed.wav.shape[0]
    null_wav, _ = nullify_condition((embed.wav, torch.zeros_like(embed.wav)), dim=embed.wav.dim() - 1)
    return JointEmbedCondition(
        wav=null_wav, text=[None] * len(embed.text),
        # One zero length per batch item (shape [B]), consistent with `nullify_wav`;
        # a single-element tensor would not match the batch for B > 1.
        length=torch.zeros(batch_size, dtype=torch.long, device=embed.wav.device),
        sample_rate=embed.sample_rate,
        path=[None] * batch_size,
        seek_time=[0] * batch_size,
    )
class Tokenizer:
    """Common interface of the tokenizers
    (leaves room for more advanced tokenizers later on).
    """
    def __call__(self, texts: tp.List[tp.Optional[str]]) -> tp.Tuple[torch.Tensor, torch.Tensor]:
        """Tokenize a list of optional strings; to be provided by subclasses."""
        raise NotImplementedError
class WhiteSpaceTokenizer(Tokenizer):
    """Word-level tokenizer meant for natural language descriptions.
    For example:
    ["he didn't, know he's going home.", 'shorter sentence'] =>
    [[78, 62, 31,  4, 78, 25, 19, 34],
    [59, 77,  0,  0,  0,  0,  0,  0]]
    """
    PUNCTUATION = "?:!.,;"

    def __init__(self, n_bins: int, pad_idx: int = 0, language: str = "en_core_web_sm",
                 lemma: bool = True, stopwords: bool = True) -> None:
        self.n_bins = n_bins
        self.pad_idx = pad_idx
        self.lemma = lemma
        self.stopwords = stopwords
        try:
            self.nlp = spacy.load(language)
        except IOError:
            # The language model is not available locally: download it, then retry.
            spacy.cli.download(language)  # type: ignore
            self.nlp = spacy.load(language)

    @tp.no_type_check
    def __call__(self, texts: tp.List[tp.Optional[str]],
                 return_text: bool = False) -> tp.Tuple[torch.Tensor, torch.Tensor]:
        """Convert a list of strings into a padded tensor of word indices.

        Args:
            texts (list[str]): List of strings.
            return_text (bool, optional): Whether to return text as additional tuple item. Defaults to False.
        Returns:
            tuple[torch.Tensor, torch.Tensor]:
                - Indices of words in the LUT.
                - And a mask indicating where the padding tokens are
        """
        token_rows, token_counts = [], []
        normalized = deepcopy(texts)
        word_attr = "lemma_" if self.lemma else "text"
        for idx, raw in enumerate(normalized):
            if raw is None:
                # missing attribute for this sample: a lone pad token with length 0
                token_rows.append(torch.Tensor([self.pad_idx]))
                token_counts.append(0)
                continue
            # spell out digits as words, then run the spacy pipeline
            spelled = re.sub(r"(\d+)", lambda m: num2words(int(m.group(0))), raw)  # type: ignore
            doc = self.nlp(spelled)  # type: ignore
            # drop stop words (when enabled) and punctuation
            kept = [
                tok for tok in doc
                if not (self.stopwords and tok.is_stop) and tok.text not in self.PUNCTUATION
            ]
            # use either the lemma or the raw token text
            words = [getattr(tok, word_attr) for tok in kept]
            normalized[idx] = " ".join(words)
            token_counts.append(len(words))
            token_rows.append(torch.Tensor([hash_trick(w, self.n_bins) for w in words]))
        mask = length_to_mask(torch.IntTensor(token_counts)).int()
        padded_output = pad_sequence(token_rows, padding_value=self.pad_idx).int().t()
        if not return_text:
            return padded_output, mask
        return padded_output, mask, normalized  # type: ignore
class NoopTokenizer(Tokenizer):
    """Tokenizer for global conditioners such as: artist, genre, key, etc.
    Unlike WhiteSpaceTokenizer, NoopTokenizer never splits its input, so
    "Jeff Buckley" gets a single index of its own, whereas WhiteSpaceTokenizer
    would split it into ["Jeff", "Buckley"] and return one index per word.
    For example:
    ["Queen", "ABBA", "Jeff Buckley"] => [43, 55, 101]
    ["Metal", "Rock", "Classical"] => [0, 223, 51]
    """
    def __init__(self, n_bins: int, pad_idx: int = 0):
        self.n_bins = n_bins
        self.pad_idx = pad_idx

    def __call__(self, texts: tp.List[tp.Optional[str]]) -> tp.Tuple[torch.Tensor, torch.Tensor]:
        """Hash each string to one index; missing entries get the pad index and length 0."""
        # (index, length) per entry, computed in a single pass over the inputs.
        pairs = [
            (self.pad_idx, 0) if text is None else (hash_trick(text, self.n_bins), 1)
            for text in texts
        ]
        indices = [index for index, _ in pairs]
        lengths = [length for _, length in pairs]
        tokens = torch.LongTensor(indices).unsqueeze(1)
        mask = length_to_mask(torch.IntTensor(lengths)).int()
        return tokens, mask
class BaseConditioner(nn.Module):
    """Common base of every conditioner module.

    The output dim may differ from the hidden dim, for two reasons:
    1) keep our LUTs small when the vocab is large;
    2) make all condition dims consistent.

    Args:
        dim (int): Hidden dim of the model.
        output_dim (int): Output dim of the conditioner.
    """
    def __init__(self, dim: int, output_dim: int):
        super().__init__()
        projection = nn.Linear(dim, output_dim)
        self.dim, self.output_dim = dim, output_dim
        self.output_proj = projection

    def tokenize(self, *args, **kwargs) -> tp.Any:
        """Run whatever part of the processing would create a synchronization
        point, e.g. BPE tokenization with transfer to the GPU.
        The value returned here is kept and handed back later to forward().
        """
        raise NotImplementedError

    def forward(self, inputs: tp.Any) -> ConditionType:
        """Embed the conditioning input (e.g, genre, description or a waveform)
        as a dense vector and return it as a ConditionType.

        Returns:
            ConditionType:
                - A tensor of size [B, T, D] where B is the batch size, T is the length of the
                  output embedding and D is the dimension of the embedding.
                - And a mask indicating where the padding tokens.
        """
        raise NotImplementedError
class TextConditioner(BaseConditioner):
    # Marker base class for text conditioners; it adds nothing to BaseConditioner.
    ...
class LUTConditioner(TextConditioner):
    """TextConditioner backed by an embedding lookup table.

    Args:
        n_bins (int): Number of bins.
        dim (int): Hidden dim of the model (text-encoder/LUT).
        output_dim (int): Output dim of the conditioner.
        tokenizer (str): Name of the tokenizer.
        pad_idx (int, optional): Index for padding token. Defaults to 0.
    """
    def __init__(self, n_bins: int, dim: int, output_dim: int, tokenizer: str, pad_idx: int = 0):
        super().__init__(dim, output_dim)
        self.embed = nn.Embedding(n_bins, dim)
        self.tokenizer: Tokenizer
        if tokenizer not in ('whitespace', 'noop'):
            raise ValueError(f"unrecognized tokenizer `{tokenizer}`.")
        if tokenizer == 'whitespace':
            self.tokenizer = WhiteSpaceTokenizer(n_bins, pad_idx=pad_idx)
        else:
            self.tokenizer = NoopTokenizer(n_bins, pad_idx=pad_idx)

    def tokenize(self, x: tp.List[tp.Optional[str]]) -> tp.Tuple[torch.Tensor, torch.Tensor]:
        """Tokenize the texts and move tokens and mask to the embedding's device."""
        target = self.embed.weight.device
        tokens, mask = self.tokenizer(x)
        return tokens.to(target), mask.to(target)

    def forward(self, inputs: tp.Tuple[torch.Tensor, torch.Tensor]) -> ConditionType:
        """Embed the tokens, project them to the output dim and zero masked positions."""
        tokens, mask = inputs
        projected = self.output_proj(self.embed(tokens))
        return projected * mask.unsqueeze(-1), mask
class T5Conditioner(TextConditioner):
    """T5-based TextConditioner.

    Args:
        name (str): Name of the T5 model.
        output_dim (int): Output dim of the conditioner.
        finetune (bool): Whether to fine-tune T5 at train time.
        device (str): Device for T5 Conditioner.
        autocast_dtype (tp.Optional[str], optional): Autocast dtype.
        word_dropout (float, optional): Word dropout probability.
        normalize_text (bool, optional): Whether to apply text normalization.
    """
    MODELS = ["t5-small", "t5-base", "t5-large", "t5-3b", "t5-11b",
              "google/flan-t5-small", "google/flan-t5-base", "google/flan-t5-large",
              "google/flan-t5-xl", "google/flan-t5-xxl"]
    # Hidden size of each supported model. Every entry of MODELS must have a key
    # here, since the constructor looks the dimension up right after validating the name.
    MODELS_DIMS = {
        "t5-small": 512,
        "t5-base": 768,
        "t5-large": 1024,
        "t5-3b": 1024,
        "t5-11b": 1024,
        "google/flan-t5-small": 512,
        "google/flan-t5-base": 768,
        "google/flan-t5-large": 1024,
        "google/flan-t5-xl": 2048,
        "google/flan-t5-xxl": 4096,
        # Legacy keys that match no entry of MODELS; kept so that external code
        # reading this table keeps working.
        "google/flan-t5-3b": 1024,
        "google/flan-t5-11b": 1024,
    }

    def __init__(self, name: str, output_dim: int, finetune: bool, device: str,
                 autocast_dtype: tp.Optional[str] = 'float32', word_dropout: float = 0.,
                 normalize_text: bool = False):
        assert name in self.MODELS, f"Unrecognized t5 model name (should in {self.MODELS})"
        super().__init__(self.MODELS_DIMS[name], output_dim)
        self.device = device
        self.name = name
        self.finetune = finetune
        self.word_dropout = word_dropout
        if autocast_dtype is None or self.device == 'cpu':
            self.autocast = TorchAutocast(enabled=False)
            if self.device != 'cpu':
                logger.warning("T5 has no autocast, this might lead to NaN")
        else:
            dtype = getattr(torch, autocast_dtype)
            assert isinstance(dtype, torch.dtype)
            logger.info(f"T5 will be evaluated with autocast as {autocast_dtype}")
            self.autocast = TorchAutocast(enabled=True, device_type=self.device, dtype=dtype)
        # Let's disable logging temporarily because T5 will vomit some errors otherwise.
        # thanks https://gist.github.com/simon-weber/7853144
        previous_level = logging.root.manager.disable
        logging.disable(logging.ERROR)
        with warnings.catch_warnings():
            warnings.simplefilter("ignore")
            try:
                self.t5_tokenizer = T5Tokenizer.from_pretrained(name)
                t5 = T5EncoderModel.from_pretrained(name).train(mode=finetune)
            finally:
                # always restore the previous logging level, even if loading failed
                logging.disable(previous_level)
        if finetune:
            self.t5 = t5
        else:
            # this makes sure that the t5 models is not part
            # of the saved checkpoint
            self.__dict__['t5'] = t5.to(device)

        self.normalize_text = normalize_text
        if normalize_text:
            self.text_normalizer = WhiteSpaceTokenizer(1, lemma=True, stopwords=True)

    def tokenize(self, x: tp.List[tp.Optional[str]]) -> tp.Dict[str, torch.Tensor]:
        """Tokenize the texts with the T5 tokenizer and move the result to the device.

        Missing texts are replaced by empty strings, and the attention mask of
        every empty entry is zeroed out.

        Args:
            x (list[str or None]): One optional text per item.
        Returns:
            dict[str, torch.Tensor]: Tokenizer output, including 'attention_mask'.
        """
        # if current sample doesn't have a certain attribute, replace with empty string
        entries: tp.List[str] = [xi if xi is not None else "" for xi in x]
        if self.normalize_text:
            _, _, entries = self.text_normalizer(entries, return_text=True)
        if self.word_dropout > 0. and self.training:
            # randomly drop words, each with probability `word_dropout`
            new_entries = []
            for entry in entries:
                words = [word for word in entry.split(" ") if random.random() >= self.word_dropout]
                new_entries.append(" ".join(words))
            entries = new_entries

        empty_idx = torch.LongTensor([i for i, xi in enumerate(entries) if xi == ""])

        inputs = self.t5_tokenizer(entries, return_tensors='pt', padding=True).to(self.device)
        mask = inputs['attention_mask']
        mask[empty_idx, :] = 0  # zero-out index where the input is non-existant
        return inputs

    def forward(self, inputs: tp.Dict[str, torch.Tensor]) -> ConditionType:
        """Encode the tokenized texts with T5 and project them to the output dim.

        Args:
            inputs (dict[str, torch.Tensor]): Output of `tokenize`.
        Returns:
            ConditionType: Projected embeddings, zeroed where the mask is 0, and the mask.
        """
        mask = inputs['attention_mask']
        # gradients flow through T5 only when fine-tuning
        with torch.set_grad_enabled(self.finetune), self.autocast:
            embeds = self.t5(**inputs).last_hidden_state
        embeds = self.output_proj(embeds.to(self.output_proj.weight))
        embeds = (embeds * mask.unsqueeze(-1))
        return embeds, mask
class WaveformConditioner(BaseConditioner):
    """Base class for all conditioners that take a waveform as input.
    Classes that inherit must implement `_get_wav_embedding` that outputs
    a continuous tensor, and `_downsampling_factor` that returns the down-sampling
    factor of the embedding model.

    Args:
        dim (int): The internal representation dimension.
        output_dim (int): Output dimension.
        device (tp.Union[torch.device, str]): Device.
    """
    def __init__(self, dim: int, output_dim: int, device: tp.Union[torch.device, str]):
        super().__init__(dim, output_dim)
        self.device = device

    def tokenize(self, x: WavCondition) -> WavCondition:
        """Move the waveform and its lengths to the conditioner's device."""
        wav, length, sample_rate, path, seek_time = x
        assert length is not None
        return WavCondition(wav.to(self.device), length.to(self.device), sample_rate, path, seek_time)

    def _get_wav_embedding(self, x: WavCondition) -> torch.Tensor:
        """Gets as input a WavCondition and returns a dense embedding."""
        raise NotImplementedError()

    def _downsampling_factor(self):
        """Returns the downsampling factor of the embedding model."""
        raise NotImplementedError()

    def forward(self, x: WavCondition) -> ConditionType:
        """Extract condition embedding and mask from a waveform and its metadata.

        Args:
            x (WavCondition): Waveform condition containing raw waveform and metadata.
        Returns:
            ConditionType: a dense vector representing the conditioning along with its mask
        """
        wav, lengths, *_ = x
        # the embedding model itself is never trained through this path
        with torch.no_grad():
            embeds = self._get_wav_embedding(x)
        embeds = embeds.to(self.output_proj.weight)
        embeds = self.output_proj(embeds)

        if lengths is not None:
            # convert lengths from samples to embedding frames
            lengths = lengths / self._downsampling_factor()
            mask = length_to_mask(lengths, max_len=embeds.shape[1]).int()  # type: ignore
        else:
            # No lengths: keep every frame. The mask must be [B, T] (not [B, T, D])
            # so that `mask.unsqueeze(2)` below broadcasts over the embedding dim.
            mask = torch.ones_like(embeds[..., 0])
        embeds = (embeds * mask.unsqueeze(2).to(self.device))

        return embeds, mask
class ChromaStemConditioner(WaveformConditioner):
    """Chroma conditioner based on stems.
    The ChromaStemConditioner uses DEMUCS to first filter out drums and bass, as
    the drums and bass often dominate the chroma leading to the chroma features
    not containing information about the melody.

    Args:
        output_dim (int): Output dimension for the conditioner.
        sample_rate (int): Sample rate for the chroma extractor.
        n_chroma (int): Number of chroma bins for the chroma extractor.
        radix2_exp (int): Size of stft window for the chroma extractor (power of 2, e.g. 12 -> 2^12).
        duration (float): duration used during training. This is later used for correct padding
            in case we are using chroma as prefix.
        match_len_on_eval (bool, optional): if True then all chromas are truncated or repeated
            to the training chroma length. Defaults to True.
        eval_wavs (str, optional): path to a dataset manifest with waveform, this waveforms are used as
            conditions during eval (for cases where we don't want to leak test conditions like MusicCaps).
            Defaults to None.
        n_eval_wavs (int, optional): limits the number of waveforms used for conditioning. Defaults to 0.
        cache_path (str or Path, optional): if given, an EmbeddingCache rooted at `<cache_path>/wav`
            is created for the chroma embeddings. Defaults to None.
        device (tp.Union[torch.device, str], optional): Device for the conditioner.
        **kwargs: Additional parameters for the chroma extractor.
    """
    def __init__(self, output_dim: int, sample_rate: int, n_chroma: int, radix2_exp: int,
                 duration: float, match_len_on_eval: bool = True, eval_wavs: tp.Optional[str] = None,
                 n_eval_wavs: int = 0, cache_path: tp.Optional[tp.Union[str, Path]] = None,
                 device: tp.Union[torch.device, str] = 'cpu', **kwargs):
        from demucs import pretrained
        super().__init__(dim=n_chroma, output_dim=output_dim, device=device)
        self.autocast = TorchAutocast(enabled=device != 'cpu', device_type=self.device, dtype=torch.float32)
        self.sample_rate = sample_rate
        self.match_len_on_eval = match_len_on_eval
        self.duration = duration
        # stored through __dict__ so that nn.Module does not register demucs as a submodule
        self.__dict__['demucs'] = pretrained.get_model('htdemucs').to(device)
        stem_sources: list = self.demucs.sources  # type: ignore
        # indexes of the 'vocals' and 'other' stems, the ones kept for the chroma
        self.stem_indices = torch.LongTensor([stem_sources.index('vocals'), stem_sources.index('other')]).to(device)
        self.chroma = ChromaExtractor(sample_rate=sample_rate, n_chroma=n_chroma,
                                      radix2_exp=radix2_exp, **kwargs).to(device)
        self.chroma_len = self._get_chroma_len()
        self.eval_wavs: tp.Optional[torch.Tensor] = self._load_eval_wavs(eval_wavs, n_eval_wavs)
        self.cache = None
        if cache_path is not None:
            self.cache = EmbeddingCache(Path(cache_path) / 'wav', self.device,
                                        compute_embed_fn=self._get_full_chroma_for_cache,
                                        extract_embed_fn=self._extract_chroma_chunk)

    def _downsampling_factor(self) -> int:
        """Return the hop of the chroma extractor (`self.chroma.winhop`)."""
        return self.chroma.winhop

    def _load_eval_wavs(self, path: tp.Optional[str], num_samples: int) -> tp.Optional[torch.Tensor]:
        """Load pre-defined waveforms from a json.
        These waveforms will be used for chroma extraction during evaluation.
        This is done to make the evaluation on MusicCaps fair (we shouldn't see the chromas of MusicCaps).

        Returns None when `path` is None, and raises ValueError when the dataset is empty.
        """
        if path is None:
            return None

        logger.info(f"Loading evaluation wavs from {path}")
        from audiocraft.data.audio_dataset import AudioDataset
        dataset: AudioDataset = AudioDataset.from_meta(
            path, segment_duration=self.duration, min_audio_duration=self.duration,
            sample_rate=self.sample_rate, channels=1)

        if len(dataset) > 0:
            # NOTE(review): indexes range(num_samples) without clamping to len(dataset) — confirm
            # that callers always pass 0 < num_samples <= len(dataset).
            eval_wavs = dataset.collater([dataset[i] for i in range(num_samples)]).to(self.device)
            logger.info(f"Using {len(eval_wavs)} evaluation wavs for chroma-stem conditioner")
            return eval_wavs
        else:
            raise ValueError("Could not find evaluation wavs, check lengths of wavs")

    def reset_eval_wavs(self, eval_wavs: tp.Optional[torch.Tensor]) -> None:
        """Replace the evaluation wavs (None disables them)."""
        self.eval_wavs = eval_wavs

    def has_eval_wavs(self) -> bool:
        """Tell whether evaluation wavs are currently set."""
        return self.eval_wavs is not None

    def _sample_eval_wavs(self, num_samples: int) -> torch.Tensor:
        """Sample wavs from a predefined list."""
        assert self.eval_wavs is not None, "Cannot sample eval wavs as no eval wavs provided."
        total_eval_wavs = len(self.eval_wavs)
        out = self.eval_wavs
        if num_samples > total_eval_wavs:
            # not enough wavs: tile them along the first dimension until there are at least num_samples
            out = self.eval_wavs.repeat(num_samples // total_eval_wavs + 1, 1, 1)
        # random order, then keep the first num_samples
        return out[torch.randperm(len(out))][:num_samples]

    def _get_chroma_len(self) -> int:
        """Get length of chroma during training."""
        # run the extractor on a silent waveform lasting `duration` seconds and read the frame count
        dummy_wav = torch.zeros((1, int(self.sample_rate * self.duration)), device=self.device)
        dummy_chr = self.chroma(dummy_wav)
        return dummy_chr.shape[1]

    @torch.no_grad()
    def _get_stemmed_wav(self, wav: torch.Tensor, sample_rate: int) -> torch.Tensor:
        """Get parts of the wav that holds the melody, extracting the main stems from the wav."""
        from demucs.apply import apply_model
        from demucs.audio import convert_audio
        with self.autocast:
            # convert to the sample rate and channel count of the demucs model
            wav = convert_audio(
                wav, sample_rate, self.demucs.samplerate, self.demucs.audio_channels)  # type: ignore
            stems = apply_model(self.demucs, wav, device=self.device)
            stems = stems[:, self.stem_indices]  # extract relevant stems for melody conditioning
            mix_wav = stems.sum(1)  # merge extracted stems to single waveform
            # back to the conditioner's sample rate, single channel
            mix_wav = convert_audio(mix_wav, self.demucs.samplerate, self.sample_rate, 1)  # type: ignore
            return mix_wav

    @torch.no_grad()
    def _extract_chroma(self, wav: torch.Tensor) -> torch.Tensor:
        """Extract chroma features from the waveform."""
        with self.autocast:
            return self.chroma(wav)

    @torch.no_grad()
    def _compute_wav_embedding(self, wav: torch.Tensor, sample_rate: int) -> torch.Tensor:
        """Compute wav embedding, applying stem and chroma extraction."""
        # avoid 0-size tensors when we are working with null conds
        if wav.shape[-1] == 1:
            return self._extract_chroma(wav)
        stems = self._get_stemmed_wav(wav, sample_rate)
        chroma = self._extract_chroma(stems)
        return chroma

    @torch.no_grad()
    def _get_full_chroma_for_cache(self, path: tp.Union[str, Path], x: WavCondition, idx: int) -> torch.Tensor:
        """Extract chroma from the whole audio waveform at the given path."""
        # `x` and `idx` are not used here; only the file at `path` is read
        wav, sr = audio_read(path)
        wav = wav[None].to(self.device)
        wav = convert_audio(wav, sr, self.sample_rate, to_channels=1)
        chroma = self._compute_wav_embedding(wav, self.sample_rate)[0]
        return chroma

    def _extract_chroma_chunk(self, full_chroma: torch.Tensor, x: WavCondition, idx: int) -> torch.Tensor:
        """Extract a chunk of chroma from the full chroma derived from the full waveform."""
        wav_length = x.wav.shape[-1]
        seek_time = x.seek_time[idx]
        assert seek_time is not None, (
            "WavCondition seek_time is required "
            "when extracting chroma chunks from pre-computed chroma.")
        full_chroma = full_chroma.float()
        # number of chroma frames per second
        frame_rate = self.sample_rate / self._downsampling_factor()
        # number of frames covering the condition waveform, and first frame at seek_time
        target_length = int(frame_rate * wav_length / self.sample_rate)
        index = int(frame_rate * seek_time)
        out = full_chroma[index: index + target_length]
        # pad at the end of the frame dimension when the slice came out shorter than target_length
        out = F.pad(out[None], (0, 0, 0, target_length - out.shape[0]))[0]
        return out.to(self.device)

    @torch.no_grad()
    def _get_wav_embedding(self, x: WavCondition) -> torch.Tensor:
        """Get the wav embedding from the WavCondition.
        The conditioner will either extract the embedding on-the-fly computing it from the condition wav directly
        or will rely on the embedding cache to load the pre-computed embedding if relevant.
        """
        sampled_wav: tp.Optional[torch.Tensor] = None
        if not self.training and self.eval_wavs is not None:
            warn_once(logger, "Using precomputed evaluation wavs!")
            sampled_wav = self._sample_eval_wavs(len(x.wav))

        no_undefined_paths = all(p is not None for p in x.path)
        no_nullified_cond = x.wav.shape[-1] > 1
        # priority: sampled evaluation wavs, then the cache, then on-the-fly computation
        if sampled_wav is not None:
            chroma = self._compute_wav_embedding(sampled_wav, self.sample_rate)
        elif self.cache is not None and no_undefined_paths and no_nullified_cond:
            paths = [Path(p) for p in x.path if p is not None]
            chroma = self.cache.get_embed_from_cache(paths, x)
        else:
            assert all(sr == x.sample_rate[0] for sr in x.sample_rate), "All sample rates in batch should be equal."
            chroma = self._compute_wav_embedding(x.wav, x.sample_rate[0])

        # NOTE(review): despite its name, this flag is applied without checking `self.training`.
        if self.match_len_on_eval:
            B, T, C = chroma.shape
            if T > self.chroma_len:
                chroma = chroma[:, :self.chroma_len]
                logger.debug(f"Chroma was truncated to match length! ({T} -> {chroma.shape[1]})")
            elif T < self.chroma_len:
                # tile the chroma along time, then cut to the training length
                n_repeat = int(math.ceil(self.chroma_len / T))
                chroma = chroma.repeat(1, n_repeat, 1)
                chroma = chroma[:, :self.chroma_len]
                logger.debug(f"Chroma was repeated to match length! ({T} -> {chroma.shape[1]})")

        return chroma

    def tokenize(self, x: WavCondition) -> WavCondition:
        """Apply WavConditioner tokenization and populate cache if needed."""
        x = super().tokenize(x)
        no_undefined_paths = all(p is not None for p in x.path)
        # the cache is only populated when every item of the batch has a path
        if self.cache is not None and no_undefined_paths:
            paths = [Path(p) for p in x.path if p is not None]
            self.cache.populate_embed_cache(paths, x)
        return x
class JointEmbeddingConditioner(BaseConditioner):
    """Conditioner over a joint embedding, usable with audio or text inputs.

    Args:
        dim (int): Dimension.
        output_dim (int): Output dimension.
        device (str): Device.
        attribute (str): Attribute used by the conditioner.
        autocast_dtype (str): Autocast for the conditioner.
        quantize (bool): Whether to quantize the CLAP embedding.
        n_q (int): Number of residual quantizers (used if quantize is true).
        bins (int): Quantizers' codebooks size (used if quantize is true).
        kwargs: Additional parameters for residual vector quantizer.
    """
    def __init__(self, dim: int, output_dim: int, device: str, attribute: str,
                 autocast_dtype: tp.Optional[str] = 'float32', quantize: bool = True,
                 n_q: int = 12, bins: int = 1024, **kwargs):
        super().__init__(dim=dim, output_dim=output_dim)
        self.device = device
        self.attribute = attribute
        autocast_disabled = autocast_dtype is None or device == 'cpu'
        if not autocast_disabled:
            dtype = getattr(torch, autocast_dtype)
            assert isinstance(dtype, torch.dtype)
            logger.info(f"JointEmbeddingConditioner will be evaluated with autocast as {autocast_dtype}.")
            self.autocast = TorchAutocast(enabled=True, device_type=self.device, dtype=dtype)
        else:
            self.autocast = TorchAutocast(enabled=False)
            logger.warning("JointEmbeddingConditioner has no autocast, this might lead to NaN.")
        # optional residual vector quantizer discretizing the conditioned embedding
        self.quantizer: tp.Optional[ResidualVectorQuantizer] = (
            ResidualVectorQuantizer(dim, n_q=n_q, bins=bins, **kwargs) if quantize else None
        )

    def _get_embed(self, x: JointEmbedCondition) -> tp.Tuple[torch.Tensor, torch.Tensor]:
        """Compute the joint embedding in latent space; to be provided by subclasses.

        Returns:
            tuple[torch.Tensor, torch.Tensor]: Tensor for the latent embedding
                and corresponding empty indexes.
        """
        raise NotImplementedError

    def forward(self, x: JointEmbedCondition) -> ConditionType:
        """Embed the condition, optionally quantize it, project it and mask empty items."""
        with self.autocast:
            latent, empty_idx = self._get_embed(x)
            if self.quantizer is None:
                out_embed = latent
            else:
                quantized = self.quantizer(latent.view(-1, self.dim, 1), frame_rate=1)
                out_embed = quantized.x.view(-1, self.dim)
            out_embed = self.output_proj(out_embed).view(-1, 1, self.output_dim)
            mask = torch.ones(*out_embed.shape[:2], device=out_embed.device)
            # items without any input get a zero mask, hence a zero embedding
            mask[empty_idx, :] = 0
            return out_embed * mask.unsqueeze(-1), mask

    def tokenize(self, x: JointEmbedCondition) -> JointEmbedCondition:
        """Return the condition untouched."""
        return x
class CLAPEmbeddingConditioner(JointEmbeddingConditioner):
"""Joint Embedding conditioner based on pre-trained CLAP model.
This CLAP-based conditioner supports a caching mechanism
over the computed embeddings for faster training.
Args:
dim (int): Dimension.
output_dim (int): Output dimension.
device (str): Device.
attribute (str): Attribute used by the conditioner.
quantize (bool): Whether to quantize the CLAP embedding.
n_q (int): Number of residual quantizers (used if quantize is true).
bins (int): Quantizers' codebooks size (used if quantize is true).
checkpoint (str): Path to CLAP checkpoint.
model_arch (str): CLAP model architecture.
enable_fusion (bool): Enable fusion for CLAP model.
sample_rate (int): Sample rate used by CLAP model.
max_audio_length (float): Maximum audio length for CLAP model.
audio_stride (float): Stride to use for getting a CLAP embedding on the full sequence.
normalize (bool): Whether to normalize the CLAP embedding.
text_p (float): Probability of using text representation instead of audio at train time.
batch_size (Optional[int]): Batch size for CLAP embedding computation.
autocast_dtype (str): Autocast for the conditioner.
cache_path (Optional[str]): Path for pre-computed embeddings caching.
kwargs: Additional parameters for residual vector quantizer.
"""
def __init__(self, dim: int, output_dim: int, device: str, attribute: str,
             quantize: bool, n_q: int, bins: int, checkpoint: tp.Union[str, Path], model_arch: str,
             enable_fusion: bool, sample_rate: int, max_audio_length: float, audio_stride: float,
             normalize: bool, text_p: float, batch_size: tp.Optional[int] = None,
             autocast_dtype: tp.Optional[str] = 'float32', cache_path: tp.Optional[str] = None, **kwargs):
    """Load the CLAP model and tokenizer, then set up the conditioner and its optional caches.

    See the class docstring for the meaning of each argument.

    Raises:
        ImportError: If the `laion_clap` package is not installed.
    """
    try:
        import laion_clap  # type: ignore
    except ImportError as err:
        raise ImportError(
            "Please install CLAP to use the CLAPEmbeddingConditioner: 'pip install laion_clap'"
        ) from err
    checkpoint = AudioCraftEnvironment.resolve_reference_path(checkpoint)
    clap_tokenize = RobertaTokenizer.from_pretrained('roberta-base')
    clap_model = laion_clap.CLAP_Module(enable_fusion=enable_fusion, amodel=model_arch)
    load_clap_state_dict(clap_model, checkpoint)
    clap_model.eval()
    clap_model.to(device)
    super().__init__(dim=dim, output_dim=output_dim, device=device, attribute=attribute,
                     autocast_dtype=autocast_dtype, quantize=quantize, n_q=n_q, bins=bins,
                     **kwargs)
    self.checkpoint = checkpoint
    self.enable_fusion = enable_fusion
    self.model_arch = model_arch
    self.clap: laion_clap.CLAP_Module
    self.clap_tokenize: RobertaTokenizer
    self.clap_sample_rate = sample_rate
    # window size and stride are converted to a number of samples at the CLAP sample rate
    self.clap_max_frames = int(self.clap_sample_rate * max_audio_length)
    self.clap_stride = int(self.clap_sample_rate * audio_stride)
    self.batch_size = batch_size or 1
    self.normalize = normalize
    self.text_p = text_p
    # stored through __dict__ so that nn.Module does not register the CLAP model as a submodule
    self.__dict__['clap_tokenize'] = clap_tokenize
    self.__dict__['clap'] = clap_model
    self.wav_cache, self.text_cache = None, None
    if cache_path is not None:
        self.wav_cache = EmbeddingCache(Path(cache_path) / 'wav', self.device,
                                        compute_embed_fn=self._get_wav_embedding_for_cache,
                                        extract_embed_fn=self._extract_wav_embedding_chunk)
        self.text_cache = EmbeddingCache(Path(cache_path) / 'text', self.device,
                                         compute_embed_fn=self._get_text_embedding_for_cache)
def _tokenizer(self, texts: tp.Union[str, tp.List[str]]) -> dict:
    """Tokenize `texts` with the CLAP tokenizer, padded/truncated to 77 tokens."""
    # same settings as the defaults used by the CLAP module
    tokenizer_kwargs = dict(padding="max_length", truncation=True, max_length=77, return_tensors="pt")
    return self.clap_tokenize(texts, **tokenizer_kwargs)
def _compute_text_embedding(self, text: tp.List[str]) -> torch.Tensor:
    """Run the CLAP text encoder on a batch of texts.

    Args:
        text (list[str]): List of text for the batch, with B items.
    Returns:
        torch.Tensor: CLAP embedding derived from text, of shape [B, 1, D], with D the CLAP embedding dimension.
    """
    with torch.no_grad():
        embed = self.clap.get_text_embedding(text, tokenizer=self._tokenizer, use_tensor=True)
        batch, emb_dim = embed.size(0), embed.size(-1)
        # insert a singleton middle dimension
        return embed.view(batch, 1, emb_dim)
def _get_text_embedding_for_cache(self, path: tp.Union[Path, str],
                                  x: JointEmbedCondition, idx: int) -> torch.Tensor:
    """Compute the text embedding of item `idx`, for use by the cache."""
    raw = x.text[idx]
    # a missing text is embedded as the empty string
    return self._compute_text_embedding(["" if raw is None else raw])[0]
def _preprocess_wav(self, wav: torch.Tensor, length: torch.Tensor, sample_rates: tp.List[int]) -> torch.Tensor:
    """Bring a batch of audio to the layout consumed by the CLAP audio encoder.

    When sample rates are given, every item is converted to `self.clap_sample_rate` and a
    single channel through `convert_audio`. The channel axis is then averaged away.

    Args:
        wav (torch.Tensor): Audio wav, of shape [B, C, T].
        length (torch.Tensor): Actual length of each item, of shape [B]. Not used here.
        sample_rates (list[int]): Sample rate of each item, or None to skip the conversion.
    Returns:
        torch.Tensor: Audio wav of shape [B, T].
    """
    assert wav.dim() == 3, "Expecting wav to be [B, C, T]"
    if sample_rates is not None:
        converted = [
            convert_audio(audio, from_rate=sample_rates[i], to_rate=self.clap_sample_rate, to_channels=1)
            for i, audio in enumerate(wav)
        ]
        wav = torch.stack(converted, dim=0)
    # Collapse the channel axis: [B, C, T] -> [B, T].
    return wav.mean(dim=1)
def _compute_wav_embedding(self, wav: torch.Tensor, length: torch.Tensor,
                           sample_rates: tp.List[int], reduce_mean: bool = False) -> torch.Tensor:
    """Compute the CLAP audio embedding of a batch of waveforms.

    The preprocessed audio is cut into windows of `clap_max_frames` samples taken every
    `clap_stride` samples; audio shorter than one window is used as a single chunk.
    Chunks are fed to the CLAP model `batch_size` at a time, without tracking gradients.

    Args:
        wav (torch.Tensor): Audio wav, of shape [B, C, T].
        length (torch.Tensor): Actual length of the audio for each item in the batch, of shape [B].
        sample_rates (list[int]): Sample rates for each sample in the batch.
        reduce_mean (bool): Whether to average the embeddings over the chunks.
    Returns:
        torch.Tensor: Audio embedding of shape [B, F, D], F being the number of chunks
            (F=1 when `reduce_mean` is True), D the dimension.
    """
    with torch.no_grad():
        mono = self._preprocess_wav(wav, length, sample_rates)
        B, T = mono.shape
        if T < self.clap_max_frames:
            chunks = mono.view(-1, 1, T)  # [B, F, T] with F=1
        else:
            chunks = mono.unfold(-1, self.clap_max_frames, self.clap_stride)  # [B, F, T]
        flat_chunks = einops.rearrange(chunks, 'b f t -> (b f) t')
        # Run the model on slices of at most `batch_size` chunks.
        partial_embeds = [
            self.clap.get_audio_embedding_from_data(flat_chunks[start:start + self.batch_size, ...],
                                                    use_tensor=True)
            for start in range(0, flat_chunks.size(0), self.batch_size)
        ]
        embed = einops.rearrange(torch.cat(partial_embeds, dim=0), '(b f) d -> b f d', b=B)
    if not reduce_mean:
        return embed  # [B, F, D]
    return embed.mean(dim=1, keepdim=True)  # [B, 1, D]
def _get_wav_embedding_for_cache(self, path: tp.Union[str, Path],
                                 x: JointEmbedCondition, idx: int) -> torch.Tensor:
    """Compute the per-chunk audio embedding of a whole file, for use by the embedding cache.

    Args:
        path (str or Path): Path to the full audio file.
        x (JointEmbedCondition): Not used by this function.
        idx (int): Not used by this function.
    Returns:
        torch.Tensor: Single-item tensor of shape [F, D], F being the number of chunks, D the dimension.
    """
    audio, sample_rate = audio_read(path)  # [C, T]
    batched_audio = audio.unsqueeze(0).to(self.device)  # [1, C, T]
    audio_length = torch.LongTensor([batched_audio.shape[-1]]).to(self.device)
    # Keep every chunk (no averaging).
    chunk_embeds = self._compute_wav_embedding(batched_audio, audio_length, [sample_rate],
                                               reduce_mean=False)  # [1, F, D]
    return chunk_embeds.squeeze(0)  # [F, D]
def _extract_wav_embedding_chunk(self, full_embed: torch.Tensor, x: JointEmbedCondition, idx: int) -> torch.Tensor:
    """Extract the chunk of embedding matching the seek_time and length from the full CLAP audio embedding.

    Args:
        full_embed (torch.Tensor): CLAP embedding computed on the full wave, of shape [F, D].
        x (JointEmbedCondition): Joint embedding condition for the full batch.
        idx (int): Index considered for the given embedding to extract.
    Returns:
        torch.Tensor: Wav embedding averaged on sliding window, of shape [1, D].
    """
    sample_rate = x.sample_rate[idx]
    seek_time = x.seek_time[idx]
    seek_time = 0. if seek_time is None else seek_time
    # Stride between two consecutive embeddings, expressed in samples at the item's sample rate.
    # The scaling happens before the integer conversion so that strides that are not a whole
    # number of seconds are neither truncated nor collapsed to zero.
    clap_stride = int(self.clap_stride / self.clap_sample_rate * sample_rate)
    # The extracted window spans `clap_max_frames` CLAP samples starting at `seek_time`.
    end_seek_time = seek_time + self.clap_max_frames / self.clap_sample_rate
    start_offset = int(seek_time * sample_rate // clap_stride)
    end_offset = int(end_seek_time * sample_rate // clap_stride)
    wav_embed = full_embed[start_offset:end_offset, ...]
    wav_embed = wav_embed.mean(dim=0, keepdim=True)
    return wav_embed.to(self.device)  # [1, D]
def _get_text_embedding(self, x: JointEmbedCondition) -> torch.Tensor:
    """Get the CLAP embedding for a batch of text descriptions.

    The text cache is used when one is configured and the wav of the condition has more
    than one time step; otherwise the embedding is computed, with missing descriptions
    replaced by the empty string. The result is L2-normalized over the last dimension
    when `self.normalize` is set.
    """
    has_real_wav = x.wav.shape[-1] > 1  # we don't want to read from cache when condition dropout
    use_cache = self.text_cache is not None and has_real_wav
    if use_cache:
        assert all(p is not None for p in x.path), "Cache requires all JointEmbedCondition paths to be provided"
        cached_paths = [Path(p) for p in x.path if p is not None]
        embed = self.text_cache.get_embed_from_cache(cached_paths, x)
    else:
        descriptions = ["" if xi is None else xi for xi in x.text]
        embed = self._compute_text_embedding(descriptions)
    if not self.normalize:
        return embed
    return torch.nn.functional.normalize(embed, p=2.0, dim=-1)
def _get_wav_embedding(self, x: JointEmbedCondition) -> torch.Tensor:
    """Get the CLAP embedding for a batch of audio tensors (and corresponding sample rates).

    The wav cache is used only when one is configured, every path of the condition is
    defined and the wav has more than one time step; otherwise the chunk-averaged
    embedding is computed from the audio. The result is L2-normalized over the last
    dimension when `self.normalize` is set.
    """
    all_paths_defined = all(p is not None for p in x.path)
    has_real_wav = x.wav.shape[-1] > 1  # we don't want to read from cache when condition dropout
    use_cache = self.wav_cache is not None and all_paths_defined and has_real_wav
    if use_cache:
        cached_paths = [Path(p) for p in x.path if p is not None]
        embed = self.wav_cache.get_embed_from_cache(cached_paths, x)
    else:
        embed = self._compute_wav_embedding(x.wav, x.length, x.sample_rate, reduce_mean=True)
    if not self.normalize:
        return embed
    return torch.nn.functional.normalize(embed, p=2.0, dim=-1)
def tokenize(self, x: JointEmbedCondition) -> JointEmbedCondition:
    """Warm up the embedding caches for the given condition and return it unchanged.

    When every path of the condition is defined, each configured cache (wav first,
    then text) is asked to populate itself for those paths. Nothing happens otherwise.
    """
    # Trying to limit as much as possible sync points when the cache is warm.
    if all(p is not None for p in x.path):
        for cache in (self.wav_cache, self.text_cache):
            if cache is not None:
                cache.populate_embed_cache([Path(p) for p in x.path], x)
    return x
def _get_embed(self, x: JointEmbedCondition) -> tp.Tuple[torch.Tensor, torch.Tensor]:
    """Extract the shared latent representation from either the wav or the text using CLAP.

    In training mode the text embedding is picked with probability `self.text_p` and the
    wav embedding otherwise; outside of training the text embedding is always used.

    Returns:
        tuple[torch.Tensor, torch.Tensor]: The embedding, and the indices of the items whose
            text is None or empty (always empty when the wav embedding is used).
    """
    # The random draw happens on every call, whether or not the module is training.
    use_text_embed = random.random() < self.text_p
    if self.training and not use_text_embed:
        # we assume we always have the audio wav
        return self._get_wav_embedding(x), torch.LongTensor([])
    text_embed = self._get_text_embedding(x)
    missing_text = [i for i, xi in enumerate(x.text) if xi is None or xi == ""]
    return text_embed, torch.LongTensor(missing_text)
def dropout_condition(sample: ConditioningAttributes, condition_type: str, condition: str) -> ConditioningAttributes:
    """Nullify one attribute of a ConditioningAttributes object, in place.

    A "wav" condition is replaced by the result of `nullify_wav`, a "joint_embed" condition by
    the result of `nullify_joint_embed`, and a "text" condition is set to None.

    Args:
        sample (ConditioningAttributes): Sample to modify.
        condition_type (str): One of 'text', 'wav' or 'joint_embed'.
        condition (str): Name of the attribute to nullify within that type.
    Returns:
        ConditioningAttributes: The same `sample` object.
    Raises:
        ValueError: If the condition type is unknown or the condition is absent from the sample.
    """
    if condition_type not in ('text', 'wav', 'joint_embed'):
        raise ValueError(
            "dropout_condition got an unexpected condition type!"
            f" expected 'text', 'wav' or 'joint_embed' but got '{condition_type}'"
        )

    known_conditions = getattr(sample, condition_type)
    if condition not in known_conditions:
        raise ValueError(
            "dropout_condition received an unexpected condition!"
            f" expected wav={sample.wav.keys()} and text={sample.text.keys()}"
            f" but got '{condition}' of type '{condition_type}'!"
        )

    if condition_type == 'text':
        sample.text[condition] = None
    elif condition_type == 'wav':
        sample.wav[condition] = nullify_wav(sample.wav[condition])
    else:
        # Only 'joint_embed' is left after the validation above.
        sample.joint_embed[condition] = nullify_joint_embed(sample.joint_embed[condition])
    return sample
class DropoutModule(nn.Module):
    """Common base of the dropout modules: holds a random generator seeded at construction.

    Args:
        seed (int): Seed given to the generator stored in `self.rng`.
    """

    def __init__(self, seed: int = 1234):
        super().__init__()
        generator = torch.Generator()
        generator.manual_seed(seed)
        self.rng = generator
class AttributeDropout(DropoutModule):
    """Dropout applied independently to each attribute, with its own probability.

    Unlike ClassifierFreeGuidanceDropout, attributes are dropped separately: for example
    "artist" can be dropped while "genre" remains.

    Args:
        p (tp.Dict[str, tp.Dict[str, float]]): For each condition type, a dict mapping attribute
            names to their dropout probability. Attributes that are not listed get probability 0.
            For example:
                {
                    "text": {"genre": 0.1, "artist": 0.5},
                    "wav": {"self_wav": 0.25},
                }
        active_on_eval (bool, optional): Whether the dropout is active at eval. Default to False.
        seed (int, optional): Random seed.
    """

    def __init__(self, p: tp.Dict[str, tp.Dict[str, float]], active_on_eval: bool = False, seed: int = 1234):
        super().__init__(seed=seed)
        self.active_on_eval = active_on_eval
        # Per condition type, a mapping that returns the configured probability, otherwise 0.
        self.p = {
            condition_type: defaultdict(lambda: 0, probs)
            for condition_type, probs in p.items()
        }

    def forward(self, samples: tp.List[ConditioningAttributes]) -> tp.List[ConditioningAttributes]:
        """Randomly nullify the configured attributes on a copy of the samples.

        One random draw is made per configured attribute; when it falls below the attribute's
        probability, that attribute is nullified in every sample of the batch.

        Args:
            samples (list[ConditioningAttributes]): List of conditions.
        Returns:
            list[ConditioningAttributes]: The input list itself when dropout is inactive,
                otherwise a deep copy in which the dropped attributes were nullified.
        """
        if not (self.training or self.active_on_eval):
            return samples

        dropped_samples = deepcopy(samples)
        for condition_type, attribute_probs in self.p.items():  # for condition types [text, wav]
            for condition, prob in attribute_probs.items():  # for attributes of each type (e.g., [artist, genre])
                draw = torch.rand(1, generator=self.rng).item()
                if draw < prob:
                    for sample in dropped_samples:
                        dropout_condition(sample, condition_type, condition)
        return dropped_samples

    def __repr__(self):
        return f"AttributeDropout({dict(self.p)})"
class ClassifierFreeGuidanceDropout(DropoutModule):
    """Classifier Free Guidance dropout.

    A single random draw decides, for the whole batch, whether the conditions are dropped.
    The attributes visited are those listed under the "wav" and "text" entries of each
    sample's `attributes`.

    Args:
        p (float): Probability to apply condition dropout during training.
        seed (int): Random seed.
    """

    def __init__(self, p: float, seed: int = 1234):
        super().__init__(seed=seed)
        self.p = p

    def forward(self, samples: tp.List[ConditioningAttributes]) -> tp.List[ConditioningAttributes]:
        """Nullify the conditions of every sample with probability `p` (training mode only).

        Args:
            samples (list[ConditioningAttributes]): List of conditions.
        Returns:
            list[ConditioningAttributes]: The input list itself when nothing is dropped,
                otherwise a deep copy whose "wav" and "text" attributes were nullified.
        """
        if not self.training:
            return samples

        # decide on which attributes to drop in a batched fashion
        should_drop = torch.rand(1, generator=self.rng).item() < self.p
        if not should_drop:
            return samples

        # nullify conditions of all attributes, working on a copy of the input
        nullified = deepcopy(samples)
        for condition_type in ["wav", "text"]:
            for sample in nullified:
                for condition in sample.attributes[condition_type]:
                    dropout_condition(sample, condition_type, condition)
        return nullified

    def __repr__(self):
        return f"ClassifierFreeGuidanceDropout(p={self.p})"
class ConditioningProvider(nn.Module):
    """Prepare and provide conditions given all the supported conditioners.

    Args:
        conditioners (dict): Dictionary of conditioners.
        device (torch.device or str, optional): Device for conditioners and output condition types.
    """

    def __init__(self, conditioners: tp.Dict[str, BaseConditioner], device: tp.Union[torch.device, str] = "cpu"):
        super().__init__()
        self.device = device
        self.conditioners = nn.ModuleDict(conditioners)

    @property
    def joint_embed_conditions(self):
        """Attributes (`m.attribute`) of the conditioners that are JointEmbeddingConditioner instances."""
        return [m.attribute for m in self.conditioners.values() if isinstance(m, JointEmbeddingConditioner)]

    @property
    def has_joint_embed_conditions(self):
        """Whether at least one joint embedding conditioner is configured."""
        return len(self.joint_embed_conditions) > 0

    @property
    def text_conditions(self):
        """Names of the conditioners that are TextConditioner instances."""
        return [k for k, v in self.conditioners.items() if isinstance(v, TextConditioner)]

    @property
    def wav_conditions(self):
        """Names of the conditioners that are WaveformConditioner instances."""
        return [k for k, v in self.conditioners.items() if isinstance(v, WaveformConditioner)]

    @property
    def has_wav_condition(self):
        """Whether at least one waveform conditioner is configured."""
        return len(self.wav_conditions) > 0

    def tokenize(self, inputs: tp.List[ConditioningAttributes]) -> tp.Dict[str, tp.Any]:
        """Match attributes/wavs with existing conditioners in self, and tokenize them accordingly.

        This should be called before starting any real GPU work to avoid synchronization points.
        This will return a dict matching conditioner names to their arbitrary tokenized representations.

        Args:
            inputs (list[ConditioningAttributes]): List of ConditioningAttributes objects containing
                text and wav conditions.
        Returns:
            dict[str, Any]: For each collated attribute, the output of the matching conditioner's `tokenize`.
        """
        # The messages below are single strings (implicit concatenation, no comma between the parts)
        # so that a failing assertion prints a readable sentence rather than a tuple.
        assert all([isinstance(x, ConditioningAttributes) for x in inputs]), (
            "Got unexpected types input for conditioner! should be tp.List[ConditioningAttributes]"
            f" but types were {set([type(x) for x in inputs])}"
        )

        output = {}
        text = self._collate_text(inputs)
        wavs = self._collate_wavs(inputs)
        joint_embeds = self._collate_joint_embeds(inputs)

        assert set(text.keys() | wavs.keys() | joint_embeds.keys()).issubset(set(self.conditioners.keys())), (
            f"Got an unexpected attribute! Expected {self.conditioners.keys()}, "
            f"got {text.keys(), wavs.keys(), joint_embeds.keys()}"
        )

        for attribute, batch in chain(text.items(), wavs.items(), joint_embeds.items()):
            output[attribute] = self.conditioners[attribute].tokenize(batch)
        return output

    def forward(self, tokenized: tp.Dict[str, tp.Any]) -> tp.Dict[str, ConditionType]:
        """Compute pairs of `(embedding, mask)` using the configured conditioners and the tokenized representations.

        The output is for example:
        {
            "genre": (torch.Tensor([B, 1, D_genre]), torch.Tensor([B, 1])),
            "description": (torch.Tensor([B, T_desc, D_desc]), torch.Tensor([B, T_desc])),
            ...
        }

        Args:
            tokenized (dict): Dict of tokenized representations as returned by `tokenize()`.
        Returns:
            dict[str, ConditionType]: For each attribute, the `(condition, mask)` pair returned
                by the matching conditioner.
        """
        output = {}
        for attribute, inputs in tokenized.items():
            condition, mask = self.conditioners[attribute](inputs)
            output[attribute] = (condition, mask)
        return output

    def _collate_text(self, samples: tp.List[ConditioningAttributes]) -> tp.Dict[str, tp.List[tp.Optional[str]]]:
        """Given a list of ConditioningAttributes objects, compile a dictionary where the keys
        are the attributes and the values are the aggregated input per attribute.
        For example:
        Input:
        [
            ConditioningAttributes(text={"genre": "Rock", "description": "A rock song with a guitar solo"}, wav=...),
            ConditioningAttributes(text={"genre": "Hip-hop", "description": "A hip-hop verse"}, wav=...),
        ]
        Output:
        {
            "genre": ["Rock", "Hip-hop"],
            "description": ["A rock song with a guitar solo", "A hip-hop verse"]
        }

        Only the attributes listed in `self.text_conditions` are collected.

        Args:
            samples (list of ConditioningAttributes): List of ConditioningAttributes samples.
        Returns:
            dict[str, list[str, optional]]: A dictionary mapping an attribute name to text batch.
        """
        out: tp.Dict[str, tp.List[tp.Optional[str]]] = defaultdict(list)
        texts = [x.text for x in samples]
        for text in texts:
            for condition in self.text_conditions:
                out[condition].append(text[condition])
        return out

    def _collate_wavs(self, samples: tp.List[ConditioningAttributes]) -> tp.Dict[str, WavCondition]:
        """Generate a dict where the keys are attributes by which we fetch similar wavs,
        and the values are Tensors of wavs according to said attributes.

        *Note*: by the time the samples reach this function, each sample should have some waveform
        inside the "wav" attribute. It should be either:
        1. A real waveform
        2. A null waveform due to the sample having no similar waveforms (nullified by the dataset)
        3. A null waveform due to it being dropped in a dropout module (nullified by dropout)

        Args:
            samples (list of ConditioningAttributes): List of ConditioningAttributes samples.
        Returns:
            dict[str, WavCondition]: A dictionary mapping an attribute name to wavs.
        """
        wavs = defaultdict(list)
        lengths = defaultdict(list)
        sample_rates = defaultdict(list)
        paths = defaultdict(list)
        seek_times = defaultdict(list)
        out: tp.Dict[str, WavCondition] = {}

        for sample in samples:
            for attribute in self.wav_conditions:
                wav, length, sample_rate, path, seek_time = sample.wav[attribute]
                assert wav.dim() == 3, f"Got wav with dim={wav.dim()}, but expected 3 [1, C, T]"
                assert wav.size(0) == 1, f"Got wav [B, C, T] with shape={wav.shape}, but expected B == 1"
                # mono-channel conditioning
                wav = wav.mean(1, keepdim=True)  # [1, 1, T]
                wavs[attribute].append(wav.flatten())  # [T]
                lengths[attribute].append(length)
                sample_rates[attribute].extend(sample_rate)
                paths[attribute].extend(path)
                seek_times[attribute].extend(seek_time)

        # stack all wavs to a single tensor
        for attribute in self.wav_conditions:
            stacked_wav, _ = collate(wavs[attribute], dim=0)
            out[attribute] = WavCondition(
                stacked_wav.unsqueeze(1), torch.cat(lengths[attribute]), sample_rates[attribute],
                paths[attribute], seek_times[attribute])

        return out

    def _collate_joint_embeds(self, samples: tp.List[ConditioningAttributes]) -> tp.Dict[str, JointEmbedCondition]:
        """Generate a dict where the keys are attributes by which we compute joint embeddings,
        and the values are the batched wavs together with the corresponding text attributes.

        Args:
            samples (list[ConditioningAttributes]): List of ConditioningAttributes samples.
        Returns:
            A dictionary mapping an attribute name to joint embeddings.
        """
        texts = defaultdict(list)
        wavs = defaultdict(list)
        lengths = defaultdict(list)
        sample_rates = defaultdict(list)
        paths = defaultdict(list)
        seek_times = defaultdict(list)
        # Number of channels shared by every wav of the batch; 0 until the first wav is seen.
        channels: int = 0

        out = {}
        for sample in samples:
            for attribute in self.joint_embed_conditions:
                wav, text, length, sample_rate, path, seek_time = sample.joint_embed[attribute]
                assert wav.dim() == 3
                if channels == 0:
                    channels = wav.size(1)
                else:
                    assert channels == wav.size(1), "not all audio has same number of channels in batch"
                assert wav.size(0) == 1, "Expecting single-wav batch in the collate method"
                # Flatten so that wavs of different lengths can be padded together below.
                # NOTE(review): padding is applied to the flattened [C * T] vector; for C > 1 with
                # unequal T the padded layout may not map back to [C, T] cleanly - confirm.
                wav = einops.rearrange(wav, "b c t -> (b c t)")  # [1, C, T] => [C * T]
                wavs[attribute].append(wav)
                texts[attribute].extend(text)
                lengths[attribute].append(length)
                sample_rates[attribute].extend(sample_rate)
                paths[attribute].extend(path)
                seek_times[attribute].extend(seek_time)

        for attribute in self.joint_embed_conditions:
            stacked_texts = texts[attribute]
            stacked_paths = paths[attribute]
            stacked_seek_times = seek_times[attribute]
            stacked_wavs = pad_sequence(wavs[attribute]).to(self.device)
            stacked_wavs = einops.rearrange(stacked_wavs, "(c t) b -> b c t", c=channels)
            stacked_sample_rates = sample_rates[attribute]
            stacked_lengths = torch.cat(lengths[attribute]).to(self.device)
            # Every per-item list must line up with the batch dimension of the wavs.
            assert stacked_lengths.size(0) == stacked_wavs.size(0)
            assert len(stacked_sample_rates) == stacked_wavs.size(0)
            assert len(stacked_texts) == stacked_wavs.size(0)
            out[attribute] = JointEmbedCondition(
                text=stacked_texts, wav=stacked_wavs,
                length=stacked_lengths, sample_rate=stacked_sample_rates,
                path=stacked_paths, seek_time=stacked_seek_times)

        return out
class ConditionFuser(StreamingModule):
    """Condition fuser handles the logic to combine the different conditions
    to the actual model input.

    Args:
        fuse2cond (tp.Dict[str, tp.List[str]]): A dictionary that says how to fuse
            each condition. For example:
            {
                "prepend": ["description"],
                "sum": ["genre", "bpm"],
                "cross": ["description"],
            }
            Each condition ends up with a single fusing method: when a condition is listed
            under several methods, the last one listed is used.
        cross_attention_pos_emb (bool, optional): Use positional embeddings in cross attention.
        cross_attention_pos_emb_scale (float): Scale for positional embeddings in cross attention if used.
    """
    FUSING_METHODS = ["sum", "prepend", "cross", "input_interpolate"]

    def __init__(self, fuse2cond: tp.Dict[str, tp.List[str]], cross_attention_pos_emb: bool = False,
                 cross_attention_pos_emb_scale: float = 1.0):
        super().__init__()
        assert all(
            [k in self.FUSING_METHODS for k in fuse2cond.keys()]
        ), f"Got invalid fuse method, allowed methods: {self.FUSING_METHODS}"
        self.cross_attention_pos_emb = cross_attention_pos_emb
        self.cross_attention_pos_emb_scale = cross_attention_pos_emb_scale
        self.fuse2cond: tp.Dict[str, tp.List[str]] = fuse2cond
        # Reverse mapping: condition name -> fusing method.
        self.cond2fuse: tp.Dict[str, str] = {}
        for fuse_method, conditions in fuse2cond.items():
            for condition in conditions:
                self.cond2fuse[condition] = fuse_method

    def forward(
        self,
        input: torch.Tensor,
        conditions: tp.Dict[str, ConditionType]
    ) -> tp.Tuple[torch.Tensor, tp.Optional[torch.Tensor]]:
        """Fuse the conditions to the provided model input.

        The tensor passed as `input` is left untouched: summed conditions produce a new tensor.

        Args:
            input (torch.Tensor): Transformer input.
            conditions (dict[str, ConditionType]): Dict of conditions.
        Returns:
            tuple[torch.Tensor, torch.Tensor]: The first tensor is the transformer input
                after the conditions have been fused. The second output tensor is the tensor
                used for cross-attention or None if no cross attention inputs exist.
        """
        _, T, _ = input.shape

        # The presence of 'offsets' in the streaming state marks every step after the first one.
        if 'offsets' in self._streaming_state:
            first_step = False
            offsets = self._streaming_state['offsets']
        else:
            first_step = True
            offsets = torch.zeros(input.shape[0], dtype=torch.long, device=input.device)

        assert set(conditions.keys()).issubset(set(self.cond2fuse.keys())), \
            f"given conditions contain unknown attributes for fuser, " \
            f"expected {self.cond2fuse.keys()}, got {conditions.keys()}"
        cross_attention_output = None
        for cond_type, (cond, cond_mask) in conditions.items():
            op = self.cond2fuse[cond_type]
            if op == 'sum':
                # Out-of-place addition: does not modify the caller's tensor.
                input = input + cond
            elif op == 'input_interpolate':
                # Resample the condition along time to the length of the input before summing.
                cond = einops.rearrange(cond, "b t d -> b d t")
                cond = F.interpolate(cond, size=input.shape[1])
                input = input + einops.rearrange(cond, "b d t -> b t d")
            elif op == 'prepend':
                # Prepended conditions are only added on the first step.
                if first_step:
                    input = torch.cat([cond, input], dim=1)
            elif op == 'cross':
                # All cross-attention conditions are concatenated along the time axis.
                if cross_attention_output is not None:
                    cross_attention_output = torch.cat([cross_attention_output, cond], dim=1)
                else:
                    cross_attention_output = cond
            else:
                raise ValueError(f"unknown op ({op})")

        if self.cross_attention_pos_emb and cross_attention_output is not None:
            positions = torch.arange(
                cross_attention_output.shape[1],
                device=cross_attention_output.device
            ).view(1, -1, 1)
            pos_emb = create_sin_embedding(positions, cross_attention_output.shape[-1])
            cross_attention_output = cross_attention_output + self.cross_attention_pos_emb_scale * pos_emb

        if self._is_streaming:
            # T is the length of the input as received, before any prepending.
            self._streaming_state['offsets'] = offsets + T

        return input, cross_attention_output
================================================
FILE: audiocraft/modules/conv.py
================================================
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import math
import typing as tp
import warnings
import torch
from torch import nn
from torch.nn import functional as F
from torch.nn.utils import spectral_norm, weight_norm
# Names of the normalization schemes accepted by the convolution wrappers of this module.
CONV_NORMALIZATIONS = frozenset({'none', 'weight_norm', 'spectral_norm', 'time_group_norm'})
def apply_parametrization_norm(module: nn.Module, norm: str = 'none'):
    """Wrap `module` with the weight reparametrization selected by `norm`.

    Args:
        module (nn.Module): Module to reparametrize.
        norm (str): One of `CONV_NORMALIZATIONS`.
    Returns:
        The result of `weight_norm(module)` or `spectral_norm(module)` for the matching
        names, and `module` itself for every other accepted name.
    """
    assert norm in CONV_NORMALIZATIONS
    reparametrizations = {'weight_norm': weight_norm, 'spectral_norm': spectral_norm}
    wrap = reparametrizations.get(norm)
    if wrap is None:
        # Membership in CONV_NORMALIZATIONS was already checked, so any other choice
        # doesn't need reparametrization.
        return module
    return wrap(module)
def get_norm_module(module: nn.Module, causal: bool = False, norm: str = 'none', **norm_kwargs):
    """Return the normalization module to apply after `module`.

    For 'time_group_norm' this is a single-group `nn.GroupNorm` over the output channels of
    the convolution; for every other accepted name it is `nn.Identity`.

    Args:
        module (nn.Module): Convolution whose output will be normalized.
        causal (bool): Whether the normalization must be causal.
        norm (str): One of `CONV_NORMALIZATIONS`.
        **norm_kwargs: Extra arguments forwarded to `nn.GroupNorm`.
    Raises:
        ValueError: If `causal` is True with 'time_group_norm', which doesn't support causal evaluation.
    """
    assert norm in CONV_NORMALIZATIONS
    if norm != 'time_group_norm':
        return nn.Identity()
    if causal:
        raise ValueError("GroupNorm doesn't support causal evaluation.")
    assert isinstance(module, nn.modules.conv._ConvNd)
    return nn.GroupNorm(1, module.out_channels, **norm_kwargs)
def get_extra_padding_for_conv1d(x: torch.Tensor, kernel_size: int, stride: int,
                                 padding_total: int = 0) -> int:
    """Return how much padding must be appended to `x` so that the last convolution window is full.

    See `pad_for_conv1d`.

    Args:
        x (torch.Tensor): Input whose last dimension is the time axis.
        kernel_size (int): Kernel size of the convolution.
        stride (int): Stride of the convolution.
        padding_total (int): Total amount of fixed padding applied besides the extra padding.
    """
    length = x.shape[-1]
    # Part of the kernel that is not already covered by the fixed padding.
    uncovered = kernel_size - padding_total
    # Fractional number of frames, rounded up to the next whole frame.
    n_frames = math.ceil((length - uncovered) / stride + 1)
    return (n_frames - 1) * stride + uncovered - length
def pad_for_conv1d(x: torch.Tensor, kernel_size: int, stride: int, padding_total: int = 0):
    """Pad the end of `x` so that the last window of a convolution is full.

    This is required to be able to rebuild an output of the same length: otherwise, even
    with padding, some time steps might get removed.
    For instance, with total padding = 4, kernel size = 4, stride = 2:
        0 0 1 2 3 4 5 0 0   # (0s are padding)
        1   2   3           # (output frames of a convolution, last 0 is never used)
        0 0 1 2 3 4 5 0     # (output of tr. conv., but pos. 5 is going to get removed as padding)
            1 2 3 4         # once you removed padding, we are missing one time step !
    """
    pad_right = get_extra_padding_for_conv1d(x, kernel_size, stride, padding_total)
    # Nothing is added on the left; only the end of the last dimension is extended.
    return F.pad(x, (0, pad_right))
def pad1d(x: torch.Tensor, paddings: tp.Tuple[int, int], mode: str = 'constant', value: float = 0.):
    """Thin wrapper around `F.pad` that also allows reflect padding on small inputs.

    When reflect padding is requested on an input that is not longer than the largest padding,
    zeros are first appended on the right so that the reflection is possible, and the same
    number of trailing positions is removed from the padded result.

    Args:
        x (torch.Tensor): Input, padded along its last dimension.
        paddings (tuple[int, int]): Non-negative (left, right) padding amounts.
        mode (str): Padding mode forwarded to `F.pad`.
        value (float): Fill value forwarded to `F.pad`.
    """
    padding_left, padding_right = paddings
    assert padding_left >= 0 and padding_right >= 0, (padding_left, padding_right)
    if mode != 'reflect':
        return F.pad(x, paddings, mode, value)

    length = x.shape[-1]
    # Zeros needed for the input to become strictly longer than the largest padding.
    extra_pad = max(0, max(padding_left, padding_right) - length + 1)
    if extra_pad:
        x = F.pad(x, (0, extra_pad))
    padded = F.pad(x, paddings, mode, value)
    return padded[..., :padded.shape[-1] - extra_pad]
def unpad1d(x: torch.Tensor, paddings: tp.Tuple[int, int]):
    """Remove `paddings` = (left, right) from the last dimension of x. Only for 1d!

    The end index is computed from the length so that a right padding of zero is handled properly.
    """
    padding_left, padding_right = paddings
    assert padding_left >= 0 and padding_right >= 0, (padding_left, padding_right)
    length = x.shape[-1]
    assert (padding_left + padding_right) <= length
    return x[..., padding_left:length - padding_right]
class NormConv1d(nn.Module):
    """Conv1d followed by the selected normalization, exposing one interface
    whatever the normalization approach.

    Positional and remaining keyword arguments are forwarded to `nn.Conv1d`.
    """

    def __init__(self, *args, causal: bool = False, norm: str = 'none',
                 norm_kwargs: tp.Dict[str, tp.Any] = {}, **kwargs):
        super().__init__()
        # `norm_kwargs` is only unpacked below, never modified.
        conv = nn.Conv1d(*args, **kwargs)
        self.conv = apply_parametrization_norm(conv, norm)
        self.norm = get_norm_module(self.conv, causal, norm, **norm_kwargs)
        self.norm_type = norm

    def forward(self, x):
        """Apply the convolution, then the normalization."""
        return self.norm(self.conv(x))
class NormConv2d(nn.Module):
    """Conv2d followed by the selected normalization, exposing one interface
    whatever the normalization approach.

    Positional and remaining keyword arguments are forwarded to `nn.Conv2d`.
    The normalization is always requested as non-causal.
    """

    def __init__(self, *args, norm: str = 'none', norm_kwargs: tp.Dict[str, tp.Any] = {}, **kwargs):
        super().__init__()
        # `norm_kwargs` is only unpacked below, never modified.
        conv = nn.Conv2d(*args, **kwargs)
        self.conv = apply_parametrization_norm(conv, norm)
        self.norm = get_norm_module(self.conv, causal=False, norm=norm, **norm_kwargs)
        self.norm_type = norm

    def forward(self, x):
        """Apply the convolution, then the normalization."""
        return self.norm(self.conv(x))
class NormConvTranspose1d(nn.Module):
    """ConvTranspose1d followed by the selected normalization, exposing one interface
    whatever the normalization approach.

    Positional and remaining keyword arguments are forwarded to `nn.ConvTranspose1d`.
    """

    def __init__(self, *args, causal: bool = False, norm: str = 'none',
                 norm_kwargs: tp.Dict[str, tp.Any] = {}, **kwargs):
        super().__init__()
        # `norm_kwargs` is only unpacked below, never modified.
        convtr = nn.ConvTranspose1d(*args, **kwargs)
        self.convtr = apply_parametrization_norm(convtr, norm)
        self.norm = get_norm_module(self.convtr, causal, norm, **norm_kwargs)
        self.norm_type = norm

    def forward(self, x):
        """Apply the transposed convolution, then the normalization."""
        return self.norm(self.convtr(x))
class NormConvTranspose2d(nn.Module):
    """Wrapper around ConvTranspose2d and normalization applied to this conv
    to provide a uniform interface across normalization approaches.

    Positional and remaining keyword arguments are forwarded to `nn.ConvTranspose2d`.
    The normalization is always requested as non-causal.
    """

    def __init__(self, *args, norm: str = 'none', norm_kwargs: tp.Dict[str, tp.Any] = {}, **kwargs):
        super().__init__()
        # `norm_kwargs` is only unpacked below, never modified.
        self.convtr = apply_parametrization_norm(nn.ConvTranspose2d(*args, **kwargs), norm)
        self.norm = get_norm_module(self.convtr, causal=False, norm=norm, **norm_kwargs)
        # Record the normalization name, as NormConv1d, NormConv2d and NormConvTranspose1d do.
        self.norm_type = norm

    def forward(self, x):
        """Apply the transposed convolution, then the normalization."""
        x = self.convtr(x)
        x = self.norm(x)
        return x
class StreamableConv1d(nn.Module):
    """Conv1d that pads its own input (asymmetric or causal padding) and applies
    the selected normalization.

    Args:
        in_channels (int): Number of input channels.
        out_channels (int): Number of output channels.
        kernel_size (int): Kernel size of the convolution.
        stride (int): Stride of the convolution.
        dilation (int): Dilation of the convolution.
        groups (int): Number of groups of the convolution.
        bias (bool): Whether the convolution has a bias.
        causal (bool): When True, the fixed padding is placed entirely on the left.
        norm (str): Normalization name, forwarded to `NormConv1d`.
        norm_kwargs (dict): Extra arguments for the normalization, forwarded to `NormConv1d`.
        pad_mode (str): Padding mode used by `pad1d`.
    """

    def __init__(self, in_channels: int, out_channels: int,
                 kernel_size: int, stride: int = 1, dilation: int = 1,
                 groups: int = 1, bias: bool = True, causal: bool = False,
                 norm: str = 'none', norm_kwargs: tp.Dict[str, tp.Any] = {},
                 pad_mode: str = 'reflect'):
        super().__init__()
        # warn user on unusual setup between dilation and stride
        unusual_setup = stride > 1 and dilation > 1
        if unusual_setup:
            warnings.warn("StreamableConv1d has been initialized with stride > 1 and dilation > 1"
                          f" (kernel_size={kernel_size} stride={stride}, dilation={dilation}).")
        self.conv = NormConv1d(in_channels, out_channels, kernel_size, stride,
                               dilation=dilation, groups=groups, bias=bias, causal=causal,
                               norm=norm, norm_kwargs=norm_kwargs)
        self.causal = causal
        self.pad_mode = pad_mode

    def forward(self, x):
        """Pad `x` ([B, C, T]) along time and apply the normalized convolution."""
        _batch, _channels, _steps = x.shape  # unpacking requires a 3-dimensional input
        inner_conv = self.conv.conv
        stride = inner_conv.stride[0]
        dilation = inner_conv.dilation[0]
        # effective kernel size with dilations
        effective_kernel = (inner_conv.kernel_size[0] - 1) * dilation + 1
        padding_total = effective_kernel - stride
        extra_padding = get_extra_padding_for_conv1d(x, effective_kernel, stride, padding_total)
        if self.causal:
            # Left padding for causal
            paddings = (padding_total, extra_padding)
        else:
            # Asymmetric padding required for odd strides
            padding_right = padding_total // 2
            paddings = (padding_total - padding_right, padding_right + extra_padding)
        return self.conv(pad1d(x, paddings, mode=self.pad_mode))
class StreamableConvTranspose1d(nn.Module):
    """ConvTranspose1d that trims the fixed padding from its own output (asymmetric or
    causal trimming) and applies the selected normalization.

    Args:
        in_channels (int): Number of input channels.
        out_channels (int): Number of output channels.
        kernel_size (int): Kernel size of the transposed convolution.
        stride (int): Stride of the transposed convolution.
        causal (bool): Whether the trimming follows the causal scheme.
        norm (str): Normalization name, forwarded to `NormConvTranspose1d`.
        trim_right_ratio (float): Share of the fixed padding trimmed on the right in the causal
            case, between 0 and 1. Must be 1 for non-causal convolutions.
        norm_kwargs (dict): Extra arguments for the normalization, forwarded to `NormConvTranspose1d`.
    """

    def __init__(self, in_channels: int, out_channels: int,
                 kernel_size: int, stride: int = 1, causal: bool = False,
                 norm: str = 'none', trim_right_ratio: float = 1.,
                 norm_kwargs: tp.Dict[str, tp.Any] = {}):
        super().__init__()
        self.convtr = NormConvTranspose1d(in_channels, out_channels, kernel_size, stride,
                                          causal=causal, norm=norm, norm_kwargs=norm_kwargs)
        self.causal = causal
        self.trim_right_ratio = trim_right_ratio
        assert self.causal or self.trim_right_ratio == 1., \
            "`trim_right_ratio` != 1.0 only makes sense for causal convolutions"
        assert self.trim_right_ratio >= 0. and self.trim_right_ratio <= 1.

    def forward(self, x):
        """Apply the normalized transposed convolution and remove the fixed padding from its output."""
        inner_convtr = self.convtr.convtr
        padding_total = inner_convtr.kernel_size[0] - inner_convtr.stride[0]

        upsampled = self.convtr(x)

        # We will only trim fixed padding. Extra padding from `pad_for_conv1d` would be
        # removed at the very end, when keeping only the right length for the output,
        # as removing it here would require also passing the length at the matching layer
        # in the encoder.
        if self.causal:
            # Trim the padding on the right according to the specified ratio
            # if trim_right_ratio = 1.0, trim everything from right
            padding_right = math.ceil(padding_total * self.trim_right_ratio)
        else:
            # Asymmetric padding required for odd strides
            padding_right = padding_total // 2
        padding_left = padding_total - padding_right
        return unpad1d(upsampled, (padding_left, padding_right))
================================================
FILE: audiocraft/modules/diffusion_schedule.py
================================================
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
"""
Functions for Noise Schedule, defines diffusion process, reverse process and data processor.
"""
from collections import namedtuple
import random
import typing as tp
import julius
import torch
# Record with three fields, in this order: `noisy`, `noise` and `step`.
TrainingItem = namedtuple("TrainingItem", ["noisy", "noise", "step"])
def betas_from_alpha_bar(alpha_bar):
    """Compute the betas matching a sequence of cumulative alpha products.

    The per-step alphas are the ratios of consecutive `alpha_bar` entries, the first alpha
    being `alpha_bar[0]` itself; the betas are `1 - alphas`.

    Args:
        alpha_bar (torch.Tensor): 1-D tensor of cumulative products.
    Returns:
        torch.Tensor: Betas, with the same length, dtype and device as `alpha_bar`.
    """
    # Slicing (rather than building a new tensor from `alpha_bar[0]`) keeps the first entry
    # in the dtype and on the device of the input.
    alphas = torch.cat([alpha_bar[:1], alpha_bar[1:] / alpha_bar[:-1]])
    return 1 - alphas
class SampleProcessor(torch.nn.Module):
    """Identity sample processor: both projections return their input unchanged."""

    def project_sample(self, x: torch.Tensor):
        """Map a sample to the 'space' where the diffusion happens (here: the sample itself)."""
        projected = x
        return projected

    def return_sample(self, z: torch.Tensor):
        """Map from the diffusion space back to the actual sample space (here: `z` itself)."""
        restored = z
        return restored
class MultiBandProcessor(SampleProcessor):
    """
    MultiBand sample processor. The input audio is split across
    frequency bands evenly distributed in mel-scale.

    Each band will be rescaled to match the power distribution
    of Gaussian noise in that band, using online metrics
    computed on the first few samples.

    Args:
        n_bands (int): Number of mel-bands to split the signal over.
        sample_rate (int): Sample rate of the audio.
        num_samples (int): Number of samples to use to fit the rescaling
            for each band. The processor won't be stable
            until it has seen that many samples.
        power_std (float or list/tensor): The rescaling factor computed to match the
            power of Gaussian noise in each band is taken to
            that power, i.e. `1.` means full correction of the energy
            in each band, and values less than `1` means only partial
            correction. Can be used to balance the relative importance
            of low vs. high freq in typical audio signals.
            A list must contain exactly `n_bands` values.
    """
    def __init__(self, n_bands: int = 8, sample_rate: float = 24_000,
                 num_samples: int = 10_000, power_std: tp.Union[float, tp.List[float], torch.Tensor] = 1.):
        super().__init__()
        self.n_bands = n_bands
        self.split_bands = julius.SplitBands(sample_rate, n_bands=n_bands)
        self.num_samples = num_samples
        if isinstance(power_std, list):
            assert len(power_std) == n_bands
            power_std = torch.tensor(power_std)
        # Store the exponent only AFTER the list -> tensor conversion. Storing the
        # raw list would make `tensor ** self.power_std` fail later on.
        self.power_std = power_std
        # Running statistics, kept as buffers so they are saved with the module.
        self.register_buffer('counts', torch.zeros(1))
        self.register_buffer('sum_x', torch.zeros(n_bands))
        self.register_buffer('sum_x2', torch.zeros(n_bands))
        self.register_buffer('sum_target_x2', torch.zeros(n_bands))
        self.counts: torch.Tensor
        self.sum_x: torch.Tensor
        self.sum_x2: torch.Tensor
        self.sum_target_x2: torch.Tensor

    def _rescale_exponent(self):
        """Return `power_std`, moved next to the running statistics when it is a tensor.

        `power_std` is a plain attribute (not a buffer), so it does not follow the
        module when it is moved to another device.
        """
        if isinstance(self.power_std, torch.Tensor):
            return self.power_std.to(self.counts.device)
        return self.power_std

    @property
    def mean(self):
        """Per-band running mean of the observed samples."""
        mean = self.sum_x / self.counts
        return mean

    @property
    def std(self):
        """Per-band running standard deviation of the observed samples."""
        std = (self.sum_x2 / self.counts - self.mean**2).clamp(min=0).sqrt()
        return std

    @property
    def target_std(self):
        """Per-band running mean of the squared reference noise.

        Note that, despite the name, no square root is applied here.
        """
        target_std = self.sum_target_x2 / self.counts
        return target_std

    def project_sample(self, x: torch.Tensor):
        """Split `x` into bands, normalize each band and sum the bands back.

        While fewer than `num_samples` items have been counted, the running
        statistics are updated with the current batch before being used.

        Args:
            x (torch.Tensor): 3-dimensional input tensor.
        Returns:
            torch.Tensor: Sum over bands of the centered and rescaled bands.
        """
        assert x.dim() == 3
        bands = self.split_bands(x)
        if self.counts.item() < self.num_samples:
            # Reference: the same band split applied to Gaussian noise of the same shape.
            ref_bands = self.split_bands(torch.randn_like(x))
            self.counts += len(x)
            self.sum_x += bands.mean(dim=(2, 3)).sum(dim=1)
            self.sum_x2 += bands.pow(2).mean(dim=(2, 3)).sum(dim=1)
            self.sum_target_x2 += ref_bands.pow(2).mean(dim=(2, 3)).sum(dim=1)
        rescale = (self.target_std / self.std.clamp(min=1e-12)) ** self._rescale_exponent()  # same output size
        bands = (bands - self.mean.view(-1, 1, 1, 1)) * rescale.view(-1, 1, 1, 1)
        return bands.sum(dim=0)

    def return_sample(self, x: torch.Tensor):
        """Undo the per-band normalization of `project_sample`.

        Args:
            x (torch.Tensor): 3-dimensional tensor in the normalized space.
        Returns:
            torch.Tensor: Sum over bands of the bands scaled back and re-centered.
        """
        assert x.dim() == 3
        bands = self.split_bands(x)
        rescale = (self.std / self.target_std) ** self._rescale_exponent()
        bands = bands * rescale.view(-1, 1, 1, 1) + self.mean.view(-1, 1, 1, 1)
        return bands.sum(dim=0)
class NoiseSchedule:
    """Noise schedule for diffusion.

    Args:
        beta_t0 (float): Variance of the first diffusion step.
        beta_t1 (float): Variance of the last diffusion step.
        beta_exp (float): Power schedule exponent
        num_steps (int): Number of diffusion step.
        variance (str): choice of the sigma value for the denoising eq.
            Choices: "beta", "beta_tilde" or "none" (no noise added back).
        clip (float): clipping value for the denoising steps
        rescale (float): rescaling value to avoid vanishing signals unused by default (i.e 1)
        device: Device on which the betas are created.
        repartition (str): shape of the schedule only power schedule is supported
        alpha_sigmoid (dict): Accepted but not used by this implementation.
        n_bands (int, optional): Must be None; other values are rejected by an assertion.
        sample_processor (SampleProcessor): Module that normalize data to match better the gaussian distribution
        noise_scale (float): Scaling factor for the noise
        **kwargs: Accepted and ignored.
    """
    def __init__(self, beta_t0: float = 1e-4, beta_t1: float = 0.02, num_steps: int = 1000, variance: str = 'beta',
                 clip: float = 5., rescale: float = 1., device='cuda', beta_exp: float = 1,
                 repartition: str = "power", alpha_sigmoid: dict = {}, n_bands: tp.Optional[int] = None,
                 sample_processor: SampleProcessor = SampleProcessor(), noise_scale: float = 1.0, **kwargs):
        # NOTE: the `alpha_sigmoid` and `sample_processor` defaults are objects shared
        # between calls; they are never mutated here, so the signature is kept as is.
        self.beta_t0 = beta_t0
        self.beta_t1 = beta_t1
        self.variance = variance
        self.num_steps = num_steps
        self.clip = clip
        self.sample_processor = sample_processor
        self.rescale = rescale
        self.n_bands = n_bands
        self.noise_scale = noise_scale
        assert n_bands is None
        if repartition == "power":
            # Linear interpolation in the `1 / beta_exp` power domain, mapped back with `** beta_exp`.
            self.betas = torch.linspace(beta_t0 ** (1 / beta_exp), beta_t1 ** (1 / beta_exp), num_steps,
                                        device=device, dtype=torch.float) ** beta_exp
        else:
            raise RuntimeError('Not implemented')
        # Dedicated, seeded RNG used to draw the step in `get_training_item`.
        self.rng = random.Random(1234)

    def get_beta(self, step: tp.Union[int, torch.Tensor]):
        """Return the beta(s) for the given step(s)."""
        if self.n_bands is None:
            return self.betas[step]
        else:
            return self.betas[:, step]  # [n_bands, len(step)]

    def get_initial_noise(self, x: torch.Tensor):
        """Return Gaussian noise to start the reverse process from, shaped after `x`."""
        if self.n_bands is None:
            return torch.randn_like(x)
        return torch.randn((x.size(0), self.n_bands, x.size(2)))

    def get_alpha_bar(self, step: tp.Optional[tp.Union[int, torch.Tensor]] = None) -> torch.Tensor:
        """Return 'alpha_bar', either for a given step, or as a tensor with its value for each step.

        Args:
            step (int or torch.Tensor, optional): With None, return the value for every step.
                With an int, return a scalar tensor (an empty product, i.e. 1, for `step == -1`).
                With a tensor of steps, return the matching values reshaped to `[len(step), 1, 1]`.
        """
        if step is None:
            return (1 - self.betas).cumprod(dim=-1)  # works for single and multi bands
        if isinstance(step, int):
            return (1 - self.betas[:step + 1]).prod()
        else:
            return (1 - self.betas).cumprod(dim=0)[step].view(-1, 1, 1)

    def get_training_item(self, x: torch.Tensor, tensor_step: bool = False) -> TrainingItem:
        """Create a noisy data item for diffusion model training:

        Args:
            x (torch.Tensor): clean audio data torch.tensor(bs, 1, T)
            tensor_step (bool): If tensor_step = false, only one step t is sample,
                the whole batch is diffused to the same step and t is int.
                If tensor_step = true, t is a tensor of size (x.size(0),)
                every element of the batch is diffused to a independently sampled.
        Returns:
            TrainingItem: The noised sample, the noise used and the sampled step.
        """
        step: tp.Union[int, torch.Tensor]
        if tensor_step:
            bs = x.size(0)
            step = torch.randint(0, self.num_steps, size=(bs,), device=x.device)
        else:
            step = self.rng.randrange(self.num_steps)
        # Scalar for an int step, [batch_size, 1, 1] for a tensor of steps.
        alpha_bar = self.get_alpha_bar(step)

        x = self.sample_processor.project_sample(x)
        noise = torch.randn_like(x)
        noisy = (alpha_bar.sqrt() / self.rescale) * x + (1 - alpha_bar).sqrt() * noise * self.noise_scale
        return TrainingItem(noisy, noise, step)

    def generate(self, model: torch.nn.Module, initial: tp.Optional[torch.Tensor] = None,
                 condition: tp.Optional[torch.Tensor] = None, return_list: bool = False):
        """Full ddpm reverse process.

        Args:
            model (nn.Module): Diffusion model.
            initial (tensor): Initial Noise. It is handed to the model as-is on the first
                iteration, so a tensor is expected in practice despite the `None` default.
            condition (tensor): Input conditionning Tensor (e.g. encodec compressed representation).
            return_list (bool): Whether to return the whole process or only the sampled point.
        Returns:
            The list of iterates (moved to CPU, except the first entry which is `initial`) when
            `return_list` is set, otherwise the final point passed through
            `sample_processor.return_sample`.
        """
        alpha_bar = self.get_alpha_bar(step=self.num_steps - 1)
        current = initial
        iterates = [initial]
        for step in range(self.num_steps)[::-1]:
            with torch.no_grad():
                estimate = model(current, step, condition=condition).sample
            alpha = 1 - self.betas[step]
            previous = (current - (1 - alpha) / (1 - alpha_bar).sqrt() * estimate) / alpha.sqrt()
            previous_alpha_bar = self.get_alpha_bar(step=step - 1)
            # No noise is added back on the very last step.
            if step == 0:
                sigma2 = 0
            elif self.variance == 'beta':
                sigma2 = 1 - alpha
            elif self.variance == 'beta_tilde':
                sigma2 = (1 - previous_alpha_bar) / (1 - alpha_bar) * (1 - alpha)
            elif self.variance == 'none':
                sigma2 = 0
            else:
                raise ValueError(f'Invalid variance type {self.variance}')

            if sigma2 > 0:
                previous += sigma2**0.5 * torch.randn_like(previous) * self.noise_scale
            if self.clip:
                previous = previous.clamp(-self.clip, self.clip)
            current = previous
            alpha_bar = previous_alpha_bar
            if step == 0:
                previous *= self.rescale
            if return_list:
                iterates.append(previous.cpu())

        if return_list:
            return iterates
        else:
            return self.sample_processor.return_sample(previous)

    def generate_subsampled(self, model: torch.nn.Module, initial: torch.Tensor, step_list: tp.Optional[list] = None,
                            condition: tp.Optional[torch.Tensor] = None, return_list: bool = False):
        """Reverse process that only goes through Markov chain states in step_list.

        Args:
            model (nn.Module): Diffusion model.
            initial (tensor): Initial noise; it is multiplied by `noise_scale` before use.
            step_list (list, optional): Steps to visit; the last entry is the final state and
                is not fed to the model. Defaults to every 50th step counting down from
                `num_steps - 1`, followed by 0. Expected to be in decreasing order, as the default is.
            condition (tensor): Input conditioning tensor.
            return_list (bool): Whether to return the whole process or only the sampled point.
        """
        if step_list is None:
            # Derive the default from the actual schedule length; a hard-coded 1000 would
            # index past the end of `betas` for shorter schedules. Identical for num_steps == 1000.
            step_list = list(range(self.num_steps))[::-50] + [0]
        alpha_bar = self.get_alpha_bar(step=self.num_steps - 1)
        alpha_bars_subsampled = (1 - self.betas).cumprod(dim=0)[list(reversed(step_list))].cpu()
        betas_subsampled = betas_from_alpha_bar(alpha_bars_subsampled)
        current = initial * self.noise_scale
        iterates = [current]
        for idx, step in enumerate(step_list[:-1]):
            with torch.no_grad():
                estimate = model(current, step, condition=condition).sample * self.noise_scale
            alpha = 1 - betas_subsampled[-1 - idx]
            previous = (current - (1 - alpha) / (1 - alpha_bar).sqrt() * estimate) / alpha.sqrt()
            previous_alpha_bar = self.get_alpha_bar(step_list[idx + 1])
            # Last visited step: no noise is added back.
            if step == step_list[-2]:
                sigma2 = 0
                previous_alpha_bar = torch.tensor(1.0)
            else:
                sigma2 = (1 - previous_alpha_bar) / (1 - alpha_bar) * (1 - alpha)
            if sigma2 > 0:
                previous += sigma2**0.5 * torch.randn_like(previous) * self.noise_scale
            if self.clip:
                previous = previous.clamp(-self.clip, self.clip)
            current = previous
            alpha_bar = previous_alpha_bar
            if step == 0:
                previous *= self.rescale
            if return_list:
                iterates.append(previous.cpu())
        if return_list:
            return iterates
        else:
            return self.sample_processor.return_sample(previous)
================================================
FILE: audiocraft/modules/lstm.py
================================================
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
from torch import nn
class StreamableLSTM(nn.Module):
    """LSTM wrapper that hides the hidden state and the data layout.

    The input is expected in convolutional layout; it is permuted to the
    sequence-first layout used by `nn.LSTM` and permuted back on the way out.

    Args:
        dimension (int): Input and hidden size of the LSTM.
        num_layers (int): Number of stacked LSTM layers.
        skip (bool): Whether to add the LSTM input to its output.
    """
    def __init__(self, dimension: int, num_layers: int = 2, skip: bool = True):
        super().__init__()
        self.skip = skip
        self.lstm = nn.LSTM(dimension, dimension, num_layers)

    def forward(self, x):
        # Move the last axis first for the LSTM.
        sequence = x.permute(2, 0, 1)
        output, _ = self.lstm(sequence)
        if self.skip:
            output = output + sequence
        # Back to the layout of the input.
        return output.permute(1, 2, 0)
================================================
FILE: audiocraft/modules/rope.py
================================================
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import typing as tp
from torch import nn
import torch
class XPos(nn.Module):
    """Length-extrapolatable positional embedding (xPos) from [Sun et al 2022](https://arxiv.org/abs/2212.10554v1).
    This applies an exponential decay to the RoPE rotation matrix.

    Args:
        dim (int): Embedding dimension.
        smoothing (float): Smoothing factor applied to the decay rates.
        base_scale (int): Base decay rate, given in terms of scaling time.
        device (torch.device, optional): Device on which to initialize the module.
        dtype (torch.dtype): dtype to use to generate the embedding.
    """
    def __init__(self, dim: int, smoothing: float = 0.4, base_scale: int = 512,
                 device=None, dtype: torch.dtype = torch.float32):
        super().__init__()
        assert dim % 2 == 0
        assert dtype in [torch.float64, torch.float32]
        self.dtype = dtype
        self.base_scale = base_scale

        half_dim = dim // 2
        adim = torch.arange(half_dim, device=device, dtype=dtype)
        decay_rates = (adim / half_dim + smoothing) / (1.0 + smoothing)
        self.register_buffer("decay_rates", decay_rates)
        # Lazily built cache; a plain attribute, so it does not follow `.to(device)`.
        self.decay: tp.Optional[torch.Tensor] = None

    def get_decay(self, start: int, end: int):
        """Create complex decay tensor, cache values for fast computation.

        Args:
            start (int): First position (inclusive).
            end (int): Last position (exclusive).
        Returns:
            torch.Tensor: Complex tensor of shape [end - start, dim / 2] with a zero imaginary part.
        """
        assert isinstance(self.decay_rates, torch.Tensor)  # Satisfy type checker.
        # Rebuild the cache when it is missing, too short, or left behind on another
        # device after the module (and thus the `decay_rates` buffer) was moved.
        stale = (
            self.decay is None
            or end > self.decay.shape[0]
            or self.decay.device != self.decay_rates.device
        )
        if stale:
            idx = torch.arange(end, device=self.decay_rates.device, dtype=self.dtype)
            power = idx / self.base_scale
            scale = self.decay_rates ** power.unsqueeze(-1)
            self.decay = torch.polar(scale, torch.zeros_like(scale))
        return self.decay[start:end]  # [T, C/2]
class RotaryEmbedding(nn.Module):
    """Rotary positional embedding (RoPE) from [Su et al 2022](https://arxiv.org/abs/2104.09864).

    Args:
        dim (int): Embedding dimension (twice the number of frequencies).
        max_period (float): Maximum period of the rotation frequencies.
        xpos (bool): Use xPos, applies an exponential decay to rotation matrix.
        scale (float): Scale of positional embedding, set to 0 to deactivate.
        device (torch.device, optional): Device on which to initialize the module.
        dtype (torch.dtype): dtype to use to generate the embedding.
    """
    def __init__(self, dim: int, max_period: float = 10000.0, xpos: bool = False,
                 scale: float = 1.0, device=None, dtype: torch.dtype = torch.float32):
        super().__init__()
        assert dim % 2 == 0
        self.scale = scale
        assert dtype in [torch.float64, torch.float32]
        self.dtype = dtype

        adim = torch.arange(0, dim, 2, device=device, dtype=dtype)[: (dim // 2)]
        frequencies = 1.0 / (max_period ** (adim / dim))
        self.register_buffer("frequencies", frequencies)
        # Lazily built cache; a plain attribute, so it does not follow `.to(device)`.
        self.rotation: tp.Optional[torch.Tensor] = None

        self.xpos = XPos(dim, device=device, dtype=dtype) if xpos else None

    def get_rotation(self, start: int, end: int):
        """Create complex rotation tensor, cache values for fast computation.

        Args:
            start (int): First position (inclusive).
            end (int): Last position (exclusive).
        Returns:
            torch.Tensor: Unit-modulus complex tensor of shape [end - start, dim / 2].
        """
        assert isinstance(self.frequencies, torch.Tensor)  # Satisfy type checker.
        # Rebuild the cache when it is missing, too short, or left behind on another
        # device after the module (and thus the `frequencies` buffer) was moved.
        stale = (
            self.rotation is None
            or end > self.rotation.shape[0]
            or self.rotation.device != self.frequencies.device
        )
        if stale:
            idx = torch.arange(end, device=self.frequencies.device, dtype=self.dtype)
            angles = torch.outer(idx, self.frequencies)
            self.rotation = torch.polar(torch.ones_like(angles), angles)
        return self.rotation[start:end]

    def rotate(self, x: torch.Tensor, start: int = 0, invert_decay: bool = False):
        """Apply rope rotation to query or key tensor.

        Args:
            x (torch.Tensor): Tensor to rotate; time is read from dimension 1 and the last
                dimension is split into pairs interpreted as complex numbers.
            start (int): Position of the first timestep of `x`.
            invert_decay (bool): Use the inverse of the xPos decay (no effect without xPos).
        Returns:
            torch.Tensor: Rotated tensor, same shape and dtype as `x`.
        """
        T = x.shape[1]
        # Insert broadcast axes for the batch (dim 0) and for dim 2 of `x`.
        rotation = self.get_rotation(start, start + T).unsqueeze(0).unsqueeze(2)

        if self.xpos:
            decay = self.xpos.get_decay(start, start + T).unsqueeze(0).unsqueeze(2)
        else:
            decay = 1.0

        if invert_decay:
            decay = decay ** -1

        x_complex = torch.view_as_complex(x.to(self.dtype).reshape(*x.shape[:-1], -1, 2))
        # `scale` blends between the identity (scale = 0) and the full rotation (scale = 1).
        scaled_rotation = (rotation * decay) * self.scale + (1.0 - self.scale)
        x_out = torch.view_as_real(x_complex * scaled_rotation).flatten(-2)

        return x_out.type_as(x)

    def rotate_qk(self, query: torch.Tensor, key: torch.Tensor, start: int = 0):
        """ Apply rope rotation to both query and key tensors.
        Supports streaming mode, in which query and key are not expected to have the same shape.
        In streaming mode, key will be of length [P + C] with P the cached past timesteps, but
        query will be [C] (typically C == 1).

        Args:
            query (torch.Tensor): Query to rotate.
            key (torch.Tensor): Key to rotate.
            start (int): Start index of the sequence for time offset.
        Returns:
            tuple: The rotated query and the rotated key.
        """
        query_timesteps = query.shape[1]
        key_timesteps = key.shape[1]
        # The query covers the last timesteps of the key, so shift it accordingly.
        streaming_offset = key_timesteps - query_timesteps

        query_out = self.rotate(query, start + streaming_offset)
        key_out = self.rotate(key, start, invert_decay=True)

        return query_out, key_out
================================================
FILE: audiocraft/modules/seanet.py
================================================
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import typing as tp
import numpy as np
import torch.nn as nn
from .conv import StreamableConv1d, StreamableConvTranspose1d
from .lstm import StreamableLSTM
class SEANetResnetBlock(nn.Module):
    """Residual block from SEANet model.

    The residual branch is a stack of (activation, convolution) pairs, one pair per
    entry of `kernel_sizes`; inner convolutions work on `dim // compress` channels.

    Args:
        dim (int): Dimension of the input/output.
        kernel_sizes (list): List of kernel sizes for the convolutions.
        dilations (list): List of dilations for the convolutions.
        activation (str): Activation function.
        activation_params (dict): Parameters to provide to the activation function.
        norm (str): Normalization method.
        norm_params (dict): Parameters to provide to the underlying normalization used along with the convolution.
        causal (bool): Whether to use fully causal convolution.
        pad_mode (str): Padding mode for the convolutions.
        compress (int): Reduced dimensionality in residual branches (from Demucs v3).
        true_skip (bool): Whether to use true skip connection or a simple
            (streamable) convolution as the skip connection.
    """
    def __init__(self, dim: int, kernel_sizes: tp.List[int] = [3, 1], dilations: tp.List[int] = [1, 1],
                 activation: str = 'ELU', activation_params: dict = {'alpha': 1.0},
                 norm: str = 'none', norm_params: tp.Dict[str, tp.Any] = {}, causal: bool = False,
                 pad_mode: str = 'reflect', compress: int = 2, true_skip: bool = True):
        super().__init__()
        assert len(kernel_sizes) == len(dilations), 'Number of kernel sizes should match number of dilations'
        act = getattr(nn, activation)
        hidden = dim // compress
        n_convs = len(kernel_sizes)
        layers = []
        for idx in range(n_convs):
            # Only the outer ends of the branch see the full `dim` channels.
            in_chs = hidden if idx > 0 else dim
            out_chs = hidden if idx < n_convs - 1 else dim
            layers.append(act(**activation_params))
            layers.append(
                StreamableConv1d(in_chs, out_chs, kernel_size=kernel_sizes[idx], dilation=dilations[idx],
                                 norm=norm, norm_kwargs=norm_params,
                                 causal=causal, pad_mode=pad_mode))
        self.block = nn.Sequential(*layers)
        self.shortcut: nn.Module
        if true_skip:
            self.shortcut = nn.Identity()
        else:
            self.shortcut = StreamableConv1d(dim, dim, kernel_size=1, norm=norm, norm_kwargs=norm_params,
                                             causal=causal, pad_mode=pad_mode)

    def forward(self, x):
        skip = self.shortcut(x)
        return skip + self.block(x)
class SEANetEncoder(nn.Module):
    """SEANet encoder.

    The model is a single `nn.Sequential`: an initial convolution, then for each ratio a
    group of residual blocks followed by an activation and a strided convolution that
    doubles the channel count, then optional LSTM layers, and a final activation and
    convolution projecting to `dimension` channels.

    Args:
        channels (int): Audio channels.
        dimension (int): Intermediate representation dimension.
        n_filters (int): Base width for the model.
        n_residual_layers (int): nb of residual layers.
        ratios (Sequence[int]): kernel size and stride ratios. The encoder uses downsampling ratios instead of
            upsampling ratios, hence it will use the ratios in the reverse order to the ones specified here
            that must match the decoder order. We use the decoder order as some models may only employ the decoder.
        activation (str): Activation function.
        activation_params (dict): Parameters to provide to the activation function.
        norm (str): Normalization method.
        norm_params (dict): Parameters to provide to the underlying normalization used along with the convolution.
        kernel_size (int): Kernel size for the initial convolution.
        last_kernel_size (int): Kernel size for the last convolution.
        residual_kernel_size (int): Kernel size for the residual layers.
        dilation_base (int): How much to increase the dilation with each layer.
        causal (bool): Whether to use fully causal convolution.
        pad_mode (str): Padding mode for the convolutions.
        true_skip (bool): Whether to use true skip connection or a simple
            (streamable) convolution as the skip connection in the residual network blocks.
        compress (int): Reduced dimensionality in residual branches (from Demucs v3).
        lstm (int): Number of LSTM layers at the end of the encoder.
        disable_norm_outer_blocks (int): Number of blocks for which we don't apply norm.
            For the encoder, it corresponds to the N first blocks.
    """
    def __init__(self, channels: int = 1, dimension: int = 128, n_filters: int = 32, n_residual_layers: int = 3,
                 ratios: tp.List[int] = [8, 5, 4, 2], activation: str = 'ELU', activation_params: dict = {'alpha': 1.0},
                 norm: str = 'none', norm_params: tp.Dict[str, tp.Any] = {}, kernel_size: int = 7,
                 last_kernel_size: int = 7, residual_kernel_size: int = 3, dilation_base: int = 2, causal: bool = False,
                 pad_mode: str = 'reflect', true_skip: bool = True, compress: int = 2, lstm: int = 0,
                 disable_norm_outer_blocks: int = 0):
        super().__init__()
        self.channels = channels
        self.dimension = dimension
        self.n_filters = n_filters
        # Ratios are given in decoder order; the encoder applies them reversed.
        self.ratios = list(reversed(ratios))
        # Drop the local name so the un-reversed list cannot be used by mistake below.
        del ratios
        self.n_residual_layers = n_residual_layers
        # Overall downsampling factor: product of all the strides.
        self.hop_length = np.prod(self.ratios)
        self.n_blocks = len(self.ratios) + 2  # first and last conv + residual blocks
        self.disable_norm_outer_blocks = disable_norm_outer_blocks
        assert self.disable_norm_outer_blocks >= 0 and self.disable_norm_outer_blocks <= self.n_blocks, \
            "Number of blocks for which to disable norm is invalid." \
            "It should be lower or equal to the actual number of blocks in the network and greater or equal to 0."

        act = getattr(nn, activation)
        # Channel multiplier relative to `n_filters`; doubled after every downsampling stage.
        mult = 1
        # First block (block 1): initial convolution.
        model: tp.List[nn.Module] = [
            StreamableConv1d(channels, mult * n_filters, kernel_size,
                             norm='none' if self.disable_norm_outer_blocks >= 1 else norm,
                             norm_kwargs=norm_params, causal=causal, pad_mode=pad_mode)
        ]
        # Downsampling stages, one per ratio.
        for i, ratio in enumerate(self.ratios):
            # Stage `i` counts as block `i + 2` for the purpose of disabling the norm.
            block_norm = 'none' if self.disable_norm_outer_blocks >= i + 2 else norm
            # Add residual layers
            for j in range(n_residual_layers):
                model += [
                    SEANetResnetBlock(mult * n_filters, kernel_sizes=[residual_kernel_size, 1],
                                      dilations=[dilation_base ** j, 1],
                                      norm=block_norm, norm_params=norm_params,
                                      activation=activation, activation_params=activation_params,
                                      causal=causal, pad_mode=pad_mode, compress=compress, true_skip=true_skip)]

            # Add downsampling layers
            model += [
                act(**activation_params),
                StreamableConv1d(mult * n_filters, mult * n_filters * 2,
                                 kernel_size=ratio * 2, stride=ratio,
                                 norm=block_norm, norm_kwargs=norm_params,
                                 causal=causal, pad_mode=pad_mode),
            ]
            mult *= 2

        if lstm:
            model += [StreamableLSTM(mult * n_filters, num_layers=lstm)]

        # Last block: the norm is only disabled here when it is disabled for every block.
        model += [
            act(**activation_params),
            StreamableConv1d(mult * n_filters, dimension, last_kernel_size,
                             norm='none' if self.disable_norm_outer_blocks == self.n_blocks else norm,
                             norm_kwargs=norm_params, causal=causal, pad_mode=pad_mode)
        ]

        self.model = nn.Sequential(*model)

    def forward(self, x):
        """Run `x` through the sequential encoder."""
        return self.model(x)
class SEANetDecoder(nn.Module):
    """SEANet decoder.

    The model is a single `nn.Sequential`: an initial convolution, optional LSTM layers,
    then for each ratio an activation, a transposed convolution that halves the channel
    count and a group of residual blocks, then a final activation and convolution
    projecting to `channels`, optionally followed by a final activation.

    Args:
        channels (int): Audio channels.
        dimension (int): Intermediate representation dimension.
        n_filters (int): Base width for the model.
        n_residual_layers (int): nb of residual layers.
        ratios (Sequence[int]): kernel size and stride ratios.
        activation (str): Activation function.
        activation_params (dict): Parameters to provide to the activation function.
        final_activation (str): Final activation function after all convolutions.
        final_activation_params (dict): Parameters to provide to the activation function.
        norm (str): Normalization method.
        norm_params (dict): Parameters to provide to the underlying normalization used along with the convolution.
        kernel_size (int): Kernel size for the initial convolution.
        last_kernel_size (int): Kernel size for the last convolution.
        residual_kernel_size (int): Kernel size for the residual layers.
        dilation_base (int): How much to increase the dilation with each layer.
        causal (bool): Whether to use fully causal convolution.
        pad_mode (str): Padding mode for the convolutions.
        true_skip (bool): Whether to use true skip connection or a simple
            (streamable) convolution as the skip connection in the residual network blocks.
        compress (int): Reduced dimensionality in residual branches (from Demucs v3).
        lstm (int): Number of LSTM layers, inserted right after the initial convolution of the decoder.
        disable_norm_outer_blocks (int): Number of blocks for which we don't apply norm.
            For the decoder, it corresponds to the N last blocks.
        trim_right_ratio (float): Ratio for trimming at the right of the transposed convolution under the causal setup.
            If equal to 1.0, it means that all the trimming is done at the right.
    """
    def __init__(self, channels: int = 1, dimension: int = 128, n_filters: int = 32, n_residual_layers: int = 3,
                 ratios: tp.List[int] = [8, 5, 4, 2], activation: str = 'ELU', activation_params: dict = {'alpha': 1.0},
                 final_activation: tp.Optional[str] = None, final_activation_params: tp.Optional[dict] = None,
                 norm: str = 'none', norm_params: tp.Dict[str, tp.Any] = {}, kernel_size: int = 7,
                 last_kernel_size: int = 7, residual_kernel_size: int = 3, dilation_base: int = 2, causal: bool = False,
                 pad_mode: str = 'reflect', true_skip: bool = True, compress: int = 2, lstm: int = 0,
                 disable_norm_outer_blocks: int = 0, trim_right_ratio: float = 1.0):
        super().__init__()
        self.dimension = dimension
        self.channels = channels
        self.n_filters = n_filters
        self.ratios = ratios
        # Drop the local name so only `self.ratios` is used below.
        del ratios
        self.n_residual_layers = n_residual_layers
        # Overall upsampling factor: product of all the strides.
        self.hop_length = np.prod(self.ratios)
        self.n_blocks = len(self.ratios) + 2  # first and last conv + residual blocks
        self.disable_norm_outer_blocks = disable_norm_outer_blocks
        assert self.disable_norm_outer_blocks >= 0 and self.disable_norm_outer_blocks <= self.n_blocks, \
            "Number of blocks for which to disable norm is invalid." \
            "It should be lower or equal to the actual number of blocks in the network and greater or equal to 0."

        act = getattr(nn, activation)
        # Channel multiplier relative to `n_filters`; halved after every upsampling stage.
        mult = int(2 ** len(self.ratios))
        # First block: the norm is only disabled here when it is disabled for every block.
        model: tp.List[nn.Module] = [
            StreamableConv1d(dimension, mult * n_filters, kernel_size,
                             norm='none' if self.disable_norm_outer_blocks == self.n_blocks else norm,
                             norm_kwargs=norm_params, causal=causal, pad_mode=pad_mode)
        ]

        if lstm:
            model += [StreamableLSTM(mult * n_filters, num_layers=lstm)]

        # Upsample to raw audio scale
        for i, ratio in enumerate(self.ratios):
            # Blocks are counted from the end: the norm is dropped for the last N blocks.
            block_norm = 'none' if self.disable_norm_outer_blocks >= self.n_blocks - (i + 1) else norm
            # Add upsampling layers
            model += [
                act(**activation_params),
                StreamableConvTranspose1d(mult * n_filters, mult * n_filters // 2,
                                          kernel_size=ratio * 2, stride=ratio,
                                          norm=block_norm, norm_kwargs=norm_params,
                                          causal=causal, trim_right_ratio=trim_right_ratio),
            ]
            # Add residual layers
            for j in range(n_residual_layers):
                model += [
                    SEANetResnetBlock(mult * n_filters // 2, kernel_sizes=[residual_kernel_size, 1],
                                      dilations=[dilation_base ** j, 1],
                                      activation=activation, activation_params=activation_params,
                                      norm=block_norm, norm_params=norm_params, causal=causal,
                                      pad_mode=pad_mode, compress=compress, true_skip=true_skip)]

            mult //= 2

        # Add final layers
        model += [
            act(**activation_params),
            StreamableConv1d(n_filters, channels, last_kernel_size,
                             norm='none' if self.disable_norm_outer_blocks >= 1 else norm,
                             norm_kwargs=norm_params, causal=causal, pad_mode=pad_mode)
        ]
        # Add optional final activation to decoder (eg. tanh)
        if final_activation is not None:
            final_act = getattr(nn, final_activation)
            final_activation_params = final_activation_params or {}
            model += [
                final_act(**final_activation_params)
            ]
        self.model = nn.Sequential(*model)

    def forward(self, z):
        """Run `z` through the sequential decoder."""
        y = self.model(z)
        return y
================================================
FILE: audiocraft/modules/streaming.py
================================================
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
"""
Streaming module API that should be implemented by all Streaming components.
"""
from contextlib import contextmanager
import typing as tp
from torch import nn
import torch
# Streaming state of a component: a mapping from key name to tensor.
State = tp.Dict[str, torch.Tensor]
class StreamingModule(nn.Module):
    """Common API for streaming components.

    Each streaming component has a streaming state, which is just a dict[str, Tensor].
    By convention, the first dim of each tensor must be the batch size.
    Don't use dots in the key names, as this would clash with submodules
    (like in state_dict).

    If `self._is_streaming` is True, the component should use and remember
    the proper state inside `self._streaming_state`.

    To set a streaming component in streaming state, use

        with module.streaming():
            ...

    This will automatically reset the streaming state when exiting the context manager.
    This also automatically propagates to all streaming children module.

    Some module might also implement the `StreamingModule.flush` method, although
    this one is trickier, as all parents module must be StreamingModule and implement
    it as well for it to work properly. See `StreamingSequential` after.
    """
    def __init__(self) -> None:
        super().__init__()
        self._streaming_state: State = {}
        self._is_streaming = False

    def _apply_named_streaming(self, fn: tp.Any):
        """Call `fn(name, module)` on this module and every streaming sub-module."""
        for name, module in self.named_modules():
            if not isinstance(module, StreamingModule):
                continue
            fn(name, module)

    def _set_streaming(self, streaming: bool):
        """Set the streaming flag on this module and every streaming sub-module."""
        def _mark(name, module):
            module._is_streaming = streaming

        self._apply_named_streaming(_mark)

    @contextmanager
    def streaming(self):
        """Context manager to enter streaming mode. Reset streaming state on exit."""
        self._set_streaming(True)
        try:
            yield
        finally:
            self._set_streaming(False)
            self.reset_streaming()

    def reset_streaming(self):
        """Reset the streaming state."""
        def _clear(name: str, module: StreamingModule):
            module._streaming_state.clear()

        self._apply_named_streaming(_clear)

    def get_streaming_state(self) -> State:
        """Return the streaming state, including that of sub-modules."""
        collected: State = {}

        def _collect(name: str, module: StreamingModule):
            # Keys of sub-modules are namespaced with the dotted module name.
            prefix = name + "." if name else name
            for key, value in module._streaming_state.items():
                collected[prefix + key] = value

        self._apply_named_streaming(_collect)
        return collected

    def set_streaming_state(self, state: State):
        """Set the streaming state, including that of sub-modules."""
        # Work on a copy: entries are removed as they get assigned to a module.
        remaining = dict(state)

        def _restore(name: str, module: StreamingModule):
            prefix = name + "." if name else name
            module._streaming_state.clear()
            # complexity is not ideal here, but probably fine.
            for key in list(remaining):
                if not key.startswith(prefix):
                    continue
                local_key = key[len(prefix):]
                # A remaining dot means the entry belongs to a deeper sub-module.
                if '.' in local_key:
                    continue
                module._streaming_state[local_key] = remaining.pop(key)

        self._apply_named_streaming(_restore)
        assert len(remaining) == 0, list(remaining.keys())

    def flush(self, x: tp.Optional[torch.Tensor] = None):
        """Flush any remaining outputs that were waiting for completion.
        Typically, for convolutions, this will add the final padding
        and process the last buffer.

        This should take an optional argument `x`, which will be provided
        if a module before this one in the streaming pipeline has already
        spitted out a flushed out buffer.
        """
        return None if x is None else self(x)
class StreamingSequential(StreamingModule, nn.Sequential):
    """A streaming compatible alternative of `nn.Sequential`.
    """
    def flush(self, x: tp.Optional[torch.Tensor] = None):
        """Flush every layer in order, feeding each one the output of the previous.

        Streaming layers are always flushed; other layers are only called once
        an earlier layer has produced something.
        """
        for layer in self:
            if isinstance(layer, StreamingModule):
                x = layer.flush(x)
                continue
            if x is not None:
                x = layer(x)
        return x
================================================
FILE: audiocraft/modules/transformer.py
================================================
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
"""
Transformer model, with streaming support, xformer attention support
and easy causal attention with a potentially finite receptive field.
See `StreamingTransformer` for more information.
Unlike regular PyTorch Transformer, we make the hard choice that batches are first.
"""
import typing as tp
from einops import rearrange
import torch
import torch.nn as nn
from torch.nn import functional as F
from torch.utils.checkpoint import checkpoint as torch_checkpoint
from xformers import ops
from .rope import RotaryEmbedding
from .streaming import StreamingModule
# Attention backend currently selected, 'torch' or 'xformers'.
# Changed through `set_efficient_attention_backend`.
_efficient_attention_backend: str = 'torch'
def set_efficient_attention_backend(backend: str = 'torch'):
    """Select the backend used for efficient attention.

    Args:
        backend (str): Either 'torch' or 'xformers'.
    Raises:
        AssertionError: If `backend` is not one of the supported values.
    """
    # Using torch by default, it seems a bit faster on older P100 GPUs (~20% faster).
    global _efficient_attention_backend
    # Validate the requested value, not the one currently stored: checking the
    # global would accept any argument as long as the current value is valid.
    assert backend in ['xformers', 'torch'], f"Unsupported attention backend: {backend}"
    _efficient_attention_backend = backend
def _get_attention_time_dimension() -> int:
    """Return the index of the time dimension for the selected attention backend.

    This is 2 with the 'torch' backend and 1 otherwise.
    """
    return 2 if _efficient_attention_backend == 'torch' else 1
def _is_profiled() -> bool:
    """Return True if we are currently running with a xformers profiler activated.

    Returns False when the xformers profiler module cannot be imported.
    """
    try:
        from xformers.profiler import profiler
    except ImportError:
        return False
    active_profiler = profiler._Profiler._CURRENT_PROFILER
    return active_profiler is not None
def create_norm_fn(norm_type: str, dim: int, **kwargs) -> nn.Module:
    """Build the normalization module used by a transformer encoder layer.

    Only 'layer_norm' is supported; it maps to `nn.LayerNorm` with `eps=1e-5`.

    Args:
        norm_type (str): Normalization method.
        dim (int): Dimension of the normalized layer.
        **kwargs (dict): Additional parameters for normalization layer.
    Returns:
        nn.Module: Normalization module.
    Raises:
        ValueError: If `norm_type` is not a known normalization method.
    """
    if norm_type != 'layer_norm':
        raise ValueError(f"Unknown norm type: {norm_type}")
    return nn.LayerNorm(dim, eps=1e-5, **kwargs)
def create_sin_embedding(positions: torch.Tensor, dim: int, max_period: float = 10000,
                         dtype: torch.dtype = torch.float32) -> torch.Tensor:
    """Create sinusoidal positional embedding, with shape `[B, T, C]`.

    The first half of the last dimension holds the cosines and the second half the sines.

    Args:
        positions (torch.Tensor): LongTensor of positions.
        dim (int): Dimension of the embedding, must be even.
        max_period (float): Maximum period of the cosine/sine functions.
        dtype (torch.dtype or str): dtype to use to generate the embedding.
    Returns:
        torch.Tensor: Sinusoidal positional embedding.
    """
    # We aim for BTC format
    assert dim % 2 == 0
    half_dim = dim // 2
    positions = positions.to(dtype)
    adim = torch.arange(half_dim, device=positions.device, dtype=dtype).view(1, 1, -1)
    max_period_tensor = torch.full([], max_period, device=positions.device, dtype=dtype)  # avoid sync point
    # With `dim == 2` there is a single frequency and `half_dim - 1` is 0; dividing by it
    # would give 0 / 0 = NaN for the whole embedding. Clamp the denominator to 1 so that
    # the single frequency gets exponent 0 (period of 2 * pi). Unchanged for dim >= 4.
    denominator = max(half_dim - 1, 1)
    phase = positions / (max_period_tensor ** (adim / denominator))
    return torch.cat([torch.cos(phase), torch.sin(phase)], dim=-1)
def expand_repeated_kv(x: torch.Tensor, n_rep: int) -> torch.Tensor:
    """Repeat every key/value head `n_rep` times along the head axis.

    Equivalent to `torch.repeat_interleave` on the head axis (from xlformers).
    The head axis is axis 1 with the 'torch' backend (`[B, H, T, D]`) and
    axis 2 otherwise (`[B, T, H, D]`).

    Args:
        x (torch.Tensor): Keys or values, 4 dimensional.
        n_rep (int): Number of repetitions; `x` is returned untouched when 1.
    Returns:
        torch.Tensor: Tensor with `n_rep` times more heads.
    """
    if n_rep == 1:
        return x
    if _efficient_attention_backend == 'torch':
        batch, kv_heads, steps, per_head = x.shape
        repeated = x[:, :, None, :, :].expand(batch, kv_heads, n_rep, steps, per_head)
        return repeated.reshape(batch, kv_heads * n_rep, steps, per_head)
    batch, steps, kv_heads, per_head = x.shape
    repeated = x[:, :, :, None, :].expand(batch, steps, kv_heads, n_rep, per_head)
    return repeated.reshape(batch, steps, kv_heads * n_rep, per_head)
class LayerScale(nn.Module):
    """Layer scale from [Touvron et al 2021] (https://arxiv.org/pdf/2103.17239.pdf).

    Multiplies its input by a learnt per-channel scale, initialised to a
    constant, so that residual outputs can start close to 0.

    Args:
        channels (int): Number of channels.
        init (float): Initial scale.
        channel_last (bool): If True, expect `[*, C]` shaped tensors, otherwise, `[*, C, T]`.
        device (torch.device or str, optional): Device on which to initialize the module.
        dtype (torch.dtype, optional): dtype to use to initialize the module.
    """
    def __init__(self, channels: int, init: float = 1e-4, channel_last: bool = True,
                 device=None, dtype=None):
        super().__init__()
        self.channel_last = channel_last
        initial_scale = torch.full(
            (channels,), init, requires_grad=True, device=device, dtype=dtype)
        self.scale = nn.Parameter(initial_scale)

    def forward(self, x: torch.Tensor):
        # With channels on the second to last axis, a trailing singleton axis is
        # added so the scale broadcasts over time instead of over channels.
        scale = self.scale if self.channel_last else self.scale[:, None]
        return scale * x
class StreamingMultiheadAttention(StreamingModule):
    """Similar to `nn.MultiheadAttention` but with support for streaming, causal evaluation.

    Args:
        embed_dim (int): Dimension to project to.
        num_heads (int): Number of heads.
        dropout (float): Dropout level.
        bias (bool): Use bias in projections.
        causal (bool): Causal mask applied automatically.
        past_context (int, optional): Receptive field for the causal mask, infinite if None.
        custom (bool): Use custom MHA implementation, for testing / benchmarking.
        memory_efficient (bool): Use xformers based memory efficient attention.
        attention_as_float32 (bool): Perform the attention as float32
            (especially important with memory_efficient as autocast won't do this automatically).
        rope (`RotaryEmbedding`, optional): Rope embedding to use.
        cross_attention: Should be true when used as a cross attention.
            All keys and values must be available at once, streaming is only for the queries.
            Cannot be used with `causal` or `rope` (as it wouldn't make sense to
            interpret the time steps in the keys relative to those in the queries).
        safe_streaming (bool): Bug fix, will go away with xformers update.
        qk_layer_norm (bool): Layer normalization applied to queries and keys before dot product.
        kv_repeat (int): If > 1, will repeat keys and values multiple times (need to divide num_heads).
            This will lead to faster decoding time on A100 or other GPUs with tensorcore.
        device (torch.device, optional): Device on which to initialize.
        dtype (torch.dtype, optional): dtype to use.
    """
    def __init__(self, embed_dim: int, num_heads: int, dropout: float = 0.0, bias: bool = True,
                 causal: bool = False, past_context: tp.Optional[int] = None, custom: bool = False,
                 memory_efficient: bool = False, attention_as_float32: bool = False,
                 rope: tp.Optional[RotaryEmbedding] = None, cross_attention: bool = False,
                 safe_streaming: bool = True, qk_layer_norm: bool = False, kv_repeat: int = 1,
                 device=None, dtype=None):
        super().__init__()
        factory_kwargs = {'device': device, 'dtype': dtype}
        if past_context is not None:
            assert causal

        self.embed_dim = embed_dim
        self.causal = causal
        self.past_context = past_context
        self.memory_efficient = memory_efficient
        self.attention_as_float32 = attention_as_float32
        self.rope = rope
        self.cross_attention = cross_attention
        self.safe_streaming = safe_streaming
        self.num_heads = num_heads
        self.dropout = dropout
        self.kv_repeat = kv_repeat
        if cross_attention:
            assert not causal, "Causal cannot work with cross attention."
            assert rope is None, "Rope cannot work with cross attention."

        if memory_efficient:
            _verify_xformers_memory_efficient_compat()

        self.custom = _is_custom(custom, memory_efficient)
        if self.custom:
            # A single input projection produces the queries (`embed_dim` channels)
            # followed by the keys and values (`kv_dim` channels each, fewer than
            # `embed_dim` when `kv_repeat > 1`).
            out_dim = embed_dim
            assert num_heads % kv_repeat == 0
            assert not cross_attention or kv_repeat == 1
            num_kv = num_heads // kv_repeat
            kv_dim = (embed_dim // num_heads) * num_kv
            out_dim += 2 * kv_dim
            in_proj = nn.Linear(embed_dim, out_dim, bias=bias, **factory_kwargs)
            # We try to follow the default PyTorch MHA convention, to easily compare results.
            self.in_proj_weight = in_proj.weight
            self.in_proj_bias = in_proj.bias
            if bias:
                self.in_proj_bias.data.zero_()  # Following Pytorch convention
            self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias, **factory_kwargs)
            if bias:
                self.out_proj.bias.data.zero_()
        else:
            assert not qk_layer_norm
            assert kv_repeat == 1
            self.mha = nn.MultiheadAttention(
                embed_dim, num_heads, dropout=dropout, bias=bias, batch_first=True,
                **factory_kwargs)
        self.qk_layer_norm = qk_layer_norm
        if qk_layer_norm:
            assert self.custom
            assert kv_repeat == 1
            ln_dim = embed_dim
            self.q_layer_norm = nn.LayerNorm(ln_dim)
            self.k_layer_norm = nn.LayerNorm(ln_dim)

    def _load_from_state_dict(self, state_dict, prefix, *args, **kwargs):
        """Load parameters, accepting checkpoints saved from a regular MHA.

        When wrapping `nn.MultiheadAttention`, keys stored without the `mha.`
        prefix are renamed in place inside `state_dict` before loading.
        """
        if not self.custom:
            # Support compat with regular MHA
            keys = [n for n, _ in self.mha.named_parameters()]
            for key in keys:
                if prefix + key in state_dict:
                    state_dict[prefix + "mha." + key] = state_dict.pop(prefix + key)
        super()._load_from_state_dict(state_dict, prefix, *args, **kwargs)

    def _get_mask(self, current_steps: int, device: torch.device, dtype: torch.dtype):
        """Return the causal attention bias for `current_steps` new time steps.

        Args:
            current_steps (int): Number of query time steps in this call.
            device (torch.device): Device of the returned bias.
            dtype (torch.dtype): dtype of the returned bias.
        Returns:
            None, an xformers `LowerTriangularMask`, or a `[T, past + T]` tensor
            holding 0 for allowed positions and `-inf` for masked ones.
        Raises:
            RuntimeError: With `memory_efficient`, when several steps are given
                while past keys are stored in the streaming state.
        """
        # Return a causal mask, accounting for potentially stored past keys/values
        # We actually return a bias for the attention score, as this has the same
        # convention both in the builtin MHA in Pytorch, and Xformers functions.
        time_dim = _get_attention_time_dimension()
        if self.memory_efficient:
            from xformers.ops import LowerTriangularMask
            if current_steps == 1:
                # If we only have one step, then we do not need a mask.
                return None
            elif 'past_keys' in self._streaming_state:
                raise RuntimeError("Not supported at the moment")
            else:
                # Then we can safely use a lower triangular mask
                return LowerTriangularMask()
        if self._streaming_state:
            past_keys = self._streaming_state['past_keys']
            past_steps = past_keys.shape[time_dim]
        else:
            past_steps = 0

        queries_pos = torch.arange(
            past_steps, current_steps + past_steps, device=device).view(-1, 1)
        keys_pos = torch.arange(past_steps + current_steps, device=device).view(1, -1)
        # A key is visible when it is not in the future of the query and, if a
        # receptive field is set, at most `past_context` steps behind it.
        delta = queries_pos - keys_pos
        valid = delta >= 0
        if self.past_context is not None:
            valid &= (delta <= self.past_context)
        return torch.where(
            valid,
            torch.zeros([], device=device, dtype=dtype),
            torch.full([], float('-inf'), device=device, dtype=dtype))

    def _complete_kv(self, k, v):
        """Prepend the cached keys/values and update the streaming cache.

        Args:
            k (torch.Tensor): Keys for the current time steps.
            v (torch.Tensor): Values for the current time steps.
        Returns:
            tuple: Keys and values including the cached past. They are returned
                unchanged for cross attention.
        """
        time_dim = _get_attention_time_dimension()
        if self.cross_attention:
            # With cross attention we assume all keys and values
            # are already available, and streaming is with respect
            # to the queries only.
            return k, v
        # Complete the key/value pair using the streaming state.
        if self._streaming_state:
            pk = self._streaming_state['past_keys']
            nk = torch.cat([pk, k], dim=time_dim)
            if v is k:
                nv = nk
            else:
                pv = self._streaming_state['past_values']
                nv = torch.cat([pv, v], dim=time_dim)
        else:
            nk = k
            nv = v

        assert nk.shape[time_dim] == nv.shape[time_dim]
        offset = 0
        if self.past_context is not None:
            offset = max(0, nk.shape[time_dim] - self.past_context)
        if self._is_streaming:
            # Trim the cache along the time axis. That axis is 2 with the torch
            # backend, so a fixed `[:, offset:]` slice would drop attention
            # heads there instead of old time steps.
            keep = [slice(None)] * nk.dim()
            keep[time_dim] = slice(offset, None)
            keep_index = tuple(keep)
            self._streaming_state['past_keys'] = nk[keep_index]
            if v is not k:
                self._streaming_state['past_values'] = nv[keep_index]
            # Number of time steps dropped from the cache so far; the first
            # streaming call initialises the counter to 0.
            if 'offset' in self._streaming_state:
                self._streaming_state['offset'] += offset
            else:
                self._streaming_state['offset'] = torch.tensor(0)
        return nk, nv

    def _apply_rope(self, query: torch.Tensor, key: torch.Tensor):
        """Rotate queries and keys with the rope embedding.

        The rotation starts at the number of steps already seen while streaming
        (cached steps plus steps dropped from the cache).
        """
        # TODO: fix and verify layout.
        assert _efficient_attention_backend == 'xformers', "Rope not supported with torch attn."
        # Apply rope embeddings to query and key tensors.
        assert self.rope is not None
        if 'past_keys' in self._streaming_state:
            past_keys_offset = self._streaming_state['past_keys'].shape[1]
        else:
            past_keys_offset = 0
        if 'offset' in self._streaming_state:
            past_context_offset = int(self._streaming_state['offset'].item())
        else:
            past_context_offset = 0
        streaming_offset = past_context_offset + past_keys_offset
        return self.rope.rotate_qk(query, key, start=streaming_offset)

    def forward(self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor,
                key_padding_mask=None, need_weights=False, attn_mask=None,
                average_attn_weights=True, is_causal=False):
        """Compute the attention output.

        The signature mirrors `nn.MultiheadAttention.forward`, but `attn_mask`
        must be None and `is_causal` False: causality is set in the constructor.

        Args:
            query (torch.Tensor): Queries, batch first (`[B, T, C]`).
            key (torch.Tensor): Keys; must be the same object as `query` for
                custom self-attention (unless profiling).
            value (torch.Tensor): Values; same constraint as `key`.
            key_padding_mask: Only supported by the non-custom implementation.
            need_weights (bool): Only supported by the non-custom implementation.
            attn_mask: Must be None.
            average_attn_weights (bool): Forwarded to `nn.MultiheadAttention`.
            is_causal (bool): Must be False.
        Returns:
            tuple: The attention output and None (attention weights are never returned).
        """
        assert attn_mask is None
        assert not is_causal, ("New param added in torch 2.0.1 not supported, "
                               "use the causal args in the constructor.")

        time_dim = _get_attention_time_dimension()
        if time_dim == 2:
            layout = "b h t d"
        else:
            layout = "b t h d"
        dtype = query.dtype
        if self._is_streaming:
            assert self.causal or self.cross_attention, \
                "Streaming only available for causal or cross attention"

        if self.causal:
            # At the moment we specialize only for the self-attention case.
            assert query.shape[1] == key.shape[1], "Causal only for same length query / key / value"
            assert value.shape[1] == key.shape[1], "Causal only for same length query / key / value"
            attn_mask = self._get_mask(query.shape[1], query.device, query.dtype)

        if self.custom:
            # custom implementation
            assert need_weights is False
            assert key_padding_mask is None
            if self.cross_attention:
                # Different queries, keys, values, we have to split manually the weights
                # before applying the linear.
                dim = self.in_proj_weight.shape[0] // 3
                if self.in_proj_bias is None:
                    bias_q, bias_k, bias_v = None, None, None
                else:
                    bias_q = self.in_proj_bias[:dim]
                    bias_k = self.in_proj_bias[dim: 2 * dim]
                    bias_v = self.in_proj_bias[2 * dim:]
                q = nn.functional.linear(query, self.in_proj_weight[:dim], bias_q)
                # todo: when streaming, we could actually save k, v and check the shape actually match.
                k = nn.functional.linear(key, self.in_proj_weight[dim: 2 * dim], bias_k)
                v = nn.functional.linear(value, self.in_proj_weight[2 * dim:], bias_v)
                if self.qk_layer_norm is True:
                    q = self.q_layer_norm(q)
                    k = self.k_layer_norm(k)
                q, k, v = [rearrange(x, f"b t (h d) -> {layout}", h=self.num_heads) for x in [q, k, v]]
            else:
                if not _is_profiled():
                    # profiling breaks that property somehow.
                    assert query is key, "specialized implementation"
                    assert value is key, "specialized implementation"
                projected = nn.functional.linear(query, self.in_proj_weight, self.in_proj_bias)
                if self.kv_repeat == 1:
                    # Queries, keys and values have the same size: unpack them in one go.
                    if time_dim == 2:
                        bound_layout = "b h p t d"
                    else:
                        bound_layout = "b t p h d"
                    packed = rearrange(projected, f"b t (p h d) -> {bound_layout}", p=3, h=self.num_heads)
                    q, k, v = ops.unbind(packed, dim=2)
                else:
                    # Keys and values have fewer heads than the queries, so the
                    # projection is split by channel ranges.
                    embed_dim = self.embed_dim
                    per_head_dim = (embed_dim // self.num_heads)
                    kv_heads = self.num_heads // self.kv_repeat
                    q = projected[:, :, :embed_dim]
                    start = embed_dim
                    end = start + per_head_dim * kv_heads
                    k = projected[:, :, start: end]
                    v = projected[:, :, end:]
                    q = rearrange(q, f"b t (h d) -> {layout}", h=self.num_heads)
                    k = rearrange(k, f"b t (h d) -> {layout}", h=kv_heads)
                    v = rearrange(v, f"b t (h d) -> {layout}", h=kv_heads)

                if self.qk_layer_norm is True:
                    assert self.kv_repeat == 1
                    # The layer norms act on the full embedding, so heads are
                    # merged back before normalizing and split again afterwards.
                    q, k = [rearrange(x, f"{layout} -> b t (h d)") for x in [q, k]]
                    q = self.q_layer_norm(q)
                    k = self.k_layer_norm(k)
                    q, k = [rearrange(x, f"b t (h d) -> {layout}", h=self.num_heads) for x in [q, k]]
                if self.rope:
                    q, k = self._apply_rope(q, k)
                k, v = self._complete_kv(k, v)
                if self.kv_repeat > 1:
                    k = expand_repeated_kv(k, self.kv_repeat)
                    v = expand_repeated_kv(v, self.kv_repeat)
            if self.attention_as_float32:
                q, k, v = [x.float() for x in [q, k, v]]
            if self.memory_efficient:
                p = self.dropout if self.training else 0
                if _efficient_attention_backend == 'torch':
                    x = torch.nn.functional.scaled_dot_product_attention(
                        q, k, v, is_causal=attn_mask is not None, dropout_p=p)
                else:
                    x = ops.memory_efficient_attention(q, k, v, attn_mask, p=p)
            else:
                # We include the dot product as float32, for consistency
                # with the other implementations that include that step
                # as part of the attention. Note that when using `autocast`,
                # the einsums would be done as bfloat16, but the softmax
                # would be done as float32, so `attention_as_float32` will
                # extend a bit the range of operations done in float32,
                # although this should make no difference.
                q = q / q.shape[-1] ** 0.5
                key_layout = layout.replace('t', 'k')
                query_layout = layout
                if self._is_streaming and self.safe_streaming and q.device.type == 'cuda':
                    with torch.autocast(device_type=q.device.type, dtype=torch.float32):
                        pre_w = torch.einsum(f"{query_layout},{key_layout}-> b h t k", q, k)
                else:
                    pre_w = torch.einsum(f"{query_layout},{key_layout}-> b h t k", q, k)
                if attn_mask is not None:
                    pre_w = pre_w + attn_mask
                w = torch.softmax(pre_w, dim=-1)
                w = F.dropout(w, self.dropout, training=self.training).to(v)
                # Key and value have the same format.
                x = torch.einsum(f"b h t k, {key_layout} -> {layout}", w, v)
            x = x.to(dtype)
            x = rearrange(x, f"{layout} -> b t (h d)", h=self.num_heads)
            x = self.out_proj(x)
        else:
            # NOTE(review): `_complete_kv` concatenates along the backend time
            # axis, while the tensors here are still `[B, T, C]`; confirm that
            # streaming through this branch is only used with a backend whose
            # time axis is 1.
            key, value = self._complete_kv(key, value)
            if self.attention_as_float32:
                query, key, value = [x.float() for x in [query, key, value]]
            x, _ = self.mha(
                query, key, value, key_padding_mask,
                need_weights, attn_mask, average_attn_weights)
            x = x.to(dtype)

        return x, None
class StreamingTransformerLayer(nn.TransformerEncoderLayer):
    """Transformer layer with streaming / causal support.

    Cross attention is part of this same class and is enabled with
    `cross_attention=True`, unlike PyTorch which uses two separate classes.

    Args:
        d_model (int): Dimension of the data.
        num_heads (int): Number of heads.
        dim_feedforward (int): Intermediate dimension of FF module.
        dropout (float): Dropout both for MHA and FF.
        bias_ff (bool): Use bias for FF.
        bias_attn (bool): Use bias for MHA.
        causal (bool): Causal mask applied automatically.
        past_context (int, optional): Receptive field for the causal mask, infinite if None.
        custom (bool): Use custom MHA implementation, for testing / benchmarking.
        memory_efficient (bool): Use xformers based memory efficient attention.
        attention_as_float32 (bool): Perform the attention as float32
            (especially important with memory_efficient as autocast won't do this automatically).
        qk_layer_norm (bool): Layer normalization applied to queries and keys before dot product in attention.
        qk_layer_norm_cross (bool): Same for the cross attention.
        cross_attention (bool): If True, expect to get secondary input for cross-attention.
        layer_scale (float, optional): If not None, LayerScale will be used with
            the given value as initial scale.
        rope (`RotaryEmbedding`, optional): Rope embedding to use.
        attention_dropout (float, optional): If not None, dropout used inside the attention
            instead of `dropout`.
        kv_repeat (int): Forwarded to the self attention, see `StreamingMultiheadAttention`.
        norm (str): Normalization method, see `create_norm_fn`.
        device (torch.device, optional): Device on which to initialize.
        dtype (torch.dtype, optional): dtype to use.
        **kwargs: See `nn.TransformerEncoderLayer`.
    """
    def __init__(self, d_model: int, num_heads: int, dim_feedforward: int = 2048, dropout: float = 0.1,
                 bias_ff: bool = True, bias_attn: bool = True, causal: bool = False,
                 past_context: tp.Optional[int] = None, custom: bool = False,
                 memory_efficient: bool = False, attention_as_float32: bool = False,
                 qk_layer_norm: bool = False, qk_layer_norm_cross: bool = False,
                 cross_attention: bool = False, layer_scale: tp.Optional[float] = None,
                 rope: tp.Optional[RotaryEmbedding] = None, attention_dropout: tp.Optional[float] = None,
                 kv_repeat: int = 1, norm: str = 'layer_norm', device=None, dtype=None, **kwargs):
        super().__init__(d_model, num_heads, dim_feedforward, dropout,
                         device=device, dtype=dtype, batch_first=True, **kwargs)
        factory_kwargs = {'device': device, 'dtype': dtype}

        def build_layer_scale() -> nn.Module:
            # Identity when layer scale is disabled, a fresh LayerScale otherwise.
            if layer_scale is None:
                return nn.Identity()
            return LayerScale(d_model, layer_scale, **factory_kwargs)

        # Options shared by the self attention and the cross attention.
        attn_kwargs: tp.Dict[str, tp.Any] = dict(
            embed_dim=d_model,
            num_heads=num_heads,
            dropout=attention_dropout if attention_dropout is not None else dropout,
            bias=bias_attn,
            custom=custom,
            memory_efficient=memory_efficient,
            attention_as_float32=attention_as_float32,
        )
        # Replace the parent's self attention with the streaming implementation.
        self.self_attn: StreamingMultiheadAttention = StreamingMultiheadAttention(
            causal=causal, past_context=past_context, rope=rope, qk_layer_norm=qk_layer_norm,
            kv_repeat=kv_repeat, **attn_kwargs, **factory_kwargs)  # type: ignore
        # Replace the parent's feedforward linears so that the bias can be disabled.
        self.linear1 = nn.Linear(d_model, dim_feedforward, bias=bias_ff, **factory_kwargs)
        self.linear2 = nn.Linear(dim_feedforward, d_model, bias=bias_ff, **factory_kwargs)

        self.layer_scale_1: nn.Module = build_layer_scale()
        self.layer_scale_2: nn.Module = build_layer_scale()

        self.cross_attention: tp.Optional[nn.Module] = None
        if cross_attention:
            self.cross_attention = StreamingMultiheadAttention(
                cross_attention=True, qk_layer_norm=qk_layer_norm_cross,
                **attn_kwargs, **factory_kwargs)
            # Dropout and norm dedicated to the cross attention branch.
            self.dropout_cross = nn.Dropout(dropout)
            # eps value matching that used in PyTorch reference implementation.
            self.norm_cross = nn.LayerNorm(d_model, eps=1e-5, **factory_kwargs)
            self.layer_scale_cross: nn.Module = build_layer_scale()
        self.norm1 = create_norm_fn(norm, d_model, **factory_kwargs)  # type: ignore
        self.norm2 = create_norm_fn(norm, d_model, **factory_kwargs)  # type: ignore

    def _cross_attention_block(self, src: torch.Tensor,
                               cross_attention_src: torch.Tensor) -> torch.Tensor:
        """Attend from `src` (queries) to `cross_attention_src` (keys and values), then apply dropout."""
        assert self.cross_attention is not None
        attended, _ = self.cross_attention(
            src, cross_attention_src, cross_attention_src, need_weights=False)
        return self.dropout_cross(attended)  # type: ignore

    def forward(self, src: torch.Tensor, src_mask: tp.Optional[torch.Tensor] = None,  # type: ignore
                src_key_padding_mask: tp.Optional[torch.Tensor] = None,
                cross_attention_src: tp.Optional[torch.Tensor] = None):
        """Apply self attention, optional cross attention and the feedforward block.

        Args:
            src (torch.Tensor): Input of the layer.
            src_mask (torch.Tensor, optional): Forwarded to the self attention block.
            src_key_padding_mask (torch.Tensor, optional): Forwarded to the self attention block.
            cross_attention_src (torch.Tensor, optional): Keys/values source for the cross
                attention; must be given if and only if the layer has cross attention.
        Returns:
            torch.Tensor: Output of the layer.
        """
        with_cross = cross_attention_src is not None
        # The secondary input must be provided exactly when cross attention exists.
        assert (self.cross_attention is not None) == with_cross
        x = src
        if self.norm_first:
            # Pre-norm: normalize the input of each residual branch.
            attended = self._sa_block(self.norm1(x), src_mask, src_key_padding_mask)
            x = x + self.layer_scale_1(attended)
            if with_cross:
                crossed = self._cross_attention_block(self.norm_cross(x), cross_attention_src)
                x = x + self.layer_scale_cross(crossed)
            x = x + self.layer_scale_2(self._ff_block(self.norm2(x)))
            return x
        # Post-norm: normalize the output of each residual sum.
        attended = self._sa_block(x, src_mask, src_key_padding_mask)
        x = self.norm1(x + self.layer_scale_1(attended))
        if with_cross:
            # NOTE(review): the queries are the layer input `src`, not the self
            # attention output `x`; confirm this is intended before changing it.
            crossed = self._cross_attention_block(src, cross_attention_src)
            x = self.norm_cross(x + self.layer_scale_cross(crossed))
        x = self.norm2(x + self.layer_scale_2(self._ff_block(x)))
        return x
class StreamingTransformer(StreamingModule):
    """Stack of transformer layers with streaming / causal support.

    Args:
        d_model (int): Dimension of the data.
        num_heads (int): Number of heads.
        num_layers (int): Number of layers.
        dim_feedforward (int): Intermediate dimension of FF module.
        dropout (float): Dropout both for MHA and FF.
        bias_ff (bool): Use bias for FF.
        bias_attn (bool): Use bias for MHA.
        causal (bool): Causal mask applied automatically.
        past_context (int, optional): Receptive field for the causal mask, infinite if None.
        custom (bool): Use custom MHA implementation, for testing / benchmarking.
        memory_efficient (bool): Use xformers based memory efficient attention.
        attention_as_float32 (bool): Perform the attention as float32
            (especially important with memory_efficient as autocast won't do this automatically).
        cross_attention (bool): If True, expect to get secondary input for cross-attention.
        layer_scale (float, optional): If not None, LayerScale will be used
            with the given value as initial scale.
        positional_embedding (str): Positional embedding strategy (sin, rope, or sin_rope).
        max_period (float): Maximum period of the time embedding.
        positional_scale (float): Scale of positional embedding, set to 0 to deactivate.
        xpos (bool): Apply xpos exponential decay to positional embedding (rope only).
        lr (float, optional): learning rate override through the `make_optim_group` API.
        weight_decay (float, optional): Weight_decay override through the `make_optim_group` API.
        layer_class: (subclass of `StreamingTransformerLayer): class to use
            to initialize the layers, allowing further customization outside of AudioCraft.
        checkpointing (str): Checkpointing strategy to reduce memory usage.
            No checkpointing if set to 'none'. Per layer checkpointing using PyTorch
            if set to 'torch' (entire layer checkpointed, i.e. linears are evaluated twice,
            minimal memory usage, but maximal runtime). `xformers_default` and `xformers_mm`
            use xformers checkpointing and keep the output of some operations
            (see `_apply_layer`), a middle ground between speed and memory.
        device (torch.device, optional): Device on which to initialize.
        dtype (torch.dtype, optional): dtype to use.
        **kwargs: See `nn.TransformerEncoderLayer`.
    """
    def __init__(self, d_model: int, num_heads: int, num_layers: int, dim_feedforward: int = 2048,
                 dropout: float = 0.1, bias_ff: bool = True, bias_attn: bool = True,
                 causal: bool = False, past_context: tp.Optional[int] = None,
                 custom: bool = False, memory_efficient: bool = False, attention_as_float32: bool = False,
                 cross_attention: bool = False, layer_scale: tp.Optional[float] = None,
                 positional_embedding: str = 'sin', max_period: float = 10_000, positional_scale: float = 1.,
                 xpos: bool = False, lr: tp.Optional[float] = None, weight_decay: tp.Optional[float] = None,
                 layer_class: tp.Type[StreamingTransformerLayer] = StreamingTransformerLayer,
                 checkpointing: str = 'none', device=None, dtype=None, **kwargs):
        super().__init__()
        assert d_model % num_heads == 0

        self.positional_embedding = positional_embedding
        self.max_period = max_period
        self.positional_scale = positional_scale
        self.weight_decay = weight_decay
        self.lr = lr

        assert positional_embedding in ['sin', 'rope', 'sin_rope']
        self.rope: tp.Optional[RotaryEmbedding] = None
        uses_rope = self.positional_embedding in ['rope', 'sin_rope']
        if uses_rope:
            # Rope is applied inside the custom attention implementation only.
            assert _is_custom(custom, memory_efficient)
            self.rope = RotaryEmbedding(d_model // num_heads, max_period=max_period,
                                        xpos=xpos, scale=positional_scale, device=device)

        self.checkpointing = checkpointing
        assert checkpointing in ['none', 'torch', 'xformers_default', 'xformers_mm']
        if self.checkpointing.startswith('xformers'):
            _verify_xformers_internal_compat()

        # Every layer shares the same configuration and the same rope instance.
        self.layers = nn.ModuleList([
            layer_class(
                d_model=d_model, num_heads=num_heads, dim_feedforward=dim_feedforward,
                dropout=dropout, bias_ff=bias_ff, bias_attn=bias_attn,
                causal=causal, past_context=past_context, custom=custom,
                memory_efficient=memory_efficient, attention_as_float32=attention_as_float32,
                cross_attention=cross_attention, layer_scale=layer_scale, rope=self.rope,
                device=device, dtype=dtype, **kwargs)
            for _ in range(num_layers)
        ])

        if self.checkpointing != 'none':
            for layer in self.layers:
                # see audiocraft/optim/fsdp.py, magic signal to indicate this requires fixing the
                # backward hook inside of FSDP...
                layer._magma_checkpointed = True  # type: ignore
                # NOTE(review): `layer_drop` is not set by the layer class visible in
                # this file; confirm the layer classes used with checkpointing define it.
                assert layer.layer_drop == 0., "Need further checking"  # type: ignore

    def _apply_layer(self, layer, *args, **kwargs):
        """Run `layer`, through the configured checkpointing method if any.

        Raises:
            ValueError: If the checkpointing method is unknown.
        """
        method = self.checkpointing
        if method == 'none':
            return layer(*args, **kwargs)
        if method == 'torch':
            return torch_checkpoint(layer, *args, use_reentrant=False, **kwargs)
        if not method.startswith('xformers'):
            raise ValueError(f"Checkpointing method {method} is unknown.")

        from xformers.checkpoint_fairinternal import checkpoint, _get_default_policy
        # For each policy, the operations that will be saved, and not recomputed.
        # According to Francisco we can get smarter policies but this is a good start.
        saved_ops = {
            'xformers_default': [
                "xformers.efficient_attention_forward_cutlass.default",
                "xformers_flash.flash_fwd.default",
                "aten.addmm.default",
                "aten.mm.default",
            ],
            'xformers_mm': [
                "aten.addmm.default",
                "aten.mm.default",
            ],
        }
        if method not in saved_ops:
            raise ValueError(f"xformers checkpointing xformers policy {method} is not known.")
        policy_fn = _get_default_policy(saved_ops[method])
        return checkpoint(layer, *args, policy_fn=policy_fn, **kwargs)

    def forward(self, x: torch.Tensor, *args, **kwargs):
        """Add the positional embedding and run all the layers.

        Args:
            x (torch.Tensor): Input of shape `[B, T, C]`.
            *args, **kwargs: Forwarded to every layer.
        Returns:
            torch.Tensor: Output of the last layer.
        """
        B, T, C = x.shape

        # When streaming, positions continue from the steps already processed.
        if 'offsets' in self._streaming_state:
            offsets = self._streaming_state['offsets']
        else:
            offsets = torch.zeros(B, dtype=torch.long, device=x.device)

        if self.positional_embedding in ['sin', 'sin_rope']:
            time_steps = torch.arange(T, device=x.device).view(1, -1, 1)
            positions = time_steps + offsets.view(-1, 1, 1)
            pos_emb = create_sin_embedding(positions, C, max_period=self.max_period, dtype=x.dtype)
            x = x + self.positional_scale * pos_emb

        for layer in self.layers:
            x = self._apply_layer(layer, x, *args, **kwargs)

        if self._is_streaming:
            self._streaming_state['offsets'] = offsets + T

        return x

    def make_optim_group(self):
        """Return an optimizer parameter group with the `lr` / `weight_decay` overrides, if set."""
        group = {"params": list(self.parameters())}
        overrides = (("lr", self.lr), ("weight_decay", self.weight_decay))
        for option, override in overrides:
            if override is not None:
                group[option] = override
        return group
# special attention related function
def _verify_xformers_memory_efficient_compat():
    """Check that the xformers memory efficient attention ops can be imported.

    Raises:
        ImportError: If they cannot be imported, with installation instructions.
    """
    install_instructions = (
        "xformers is not installed. Please install it and try again.\n"
        "To install on AWS and Azure, run \n"
        "FORCE_CUDA=1 TORCH_CUDA_ARCH_LIST='8.0'\\\n"
        "pip install -U git+https://git@github.com/fairinternal/xformers.git#egg=xformers\n"
        "To install on FAIR Cluster, run \n"
        "FORCE_CUDA=1 TORCH_CUDA_ARCH_LIST='6.0;7.0'\\\n"
        "pip install -U git+https://git@github.com/fairinternal/xformers.git#egg=xformers\n")
    try:
        from xformers.ops import memory_efficient_attention, LowerTriangularMask  # noqa
    except ImportError:
        raise ImportError(install_instructions)
def _verify_xformers_internal_compat():
    """Check that the fairinternal xformers checkpointing API can be imported.

    Raises:
        ImportError: If it cannot be imported, with installation instructions.
    """
    install_instructions = (
        "Francisco's fairinternal xformers is not installed. Please install it and try again.\n"
        "To install on AWS and Azure, run \n"
        "FORCE_CUDA=1 TORCH_CUDA_ARCH_LIST='8.0'\\\n"
        "pip install -U git+https://git@github.com/fairinternal/xformers.git#egg=xformers\n"
        "To install on FAIR Cluster, run \n"
        "FORCE_CUDA=1 TORCH_CUDA_ARCH_LIST='6.0;7.0'\\\n"
        "pip install -U git+https://git@github.com/fairinternal/xformers.git#egg=xformers\n")
    try:
        from xformers.checkpoint_fairinternal import checkpoint, _get_default_policy  # noqa
    except ImportError:
        raise ImportError(install_instructions)
def _is_custom(custom: bool, memory_efficient: bool):
    """Tell whether the custom attention implementation must be used.

    It is used when explicitly requested, and is also required by the
    memory efficient attention.
    """
    return custom if custom else memory_efficient
================================================
FILE: audiocraft/optim/__init__.py
================================================
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
"""Optimization stuff. In particular, optimizers (DAdaptAdam), schedulers
and Exponential Moving Average.
"""
# flake8: noqa
from .cosine_lr_scheduler import CosineLRScheduler
from .dadam import DAdaptAdam
from .inverse_sqrt_lr_scheduler import InverseSquareRootLRScheduler
from .linear_warmup_lr_scheduler import LinearWarmupLRScheduler
from .polynomial_decay_lr_scheduler import PolynomialDecayLRScheduler
from .ema import ModuleDictEMA
================================================
FILE: audiocraft/optim/cosine_lr_scheduler.py
================================================
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import math
from torch.optim import Optimizer
from torch.optim.lr_scheduler import _LRScheduler
class CosineLRScheduler(_LRScheduler):
    """Cosine LR scheduler.

    The learning rate ramps up linearly from 0 during `warmup_steps`, then follows
    a cosine curve until `total_steps`, and afterwards stays at `lr_min_ratio`
    times the base learning rate.

    Args:
        optimizer (Optimizer): Torch optimizer.
        total_steps (int): Total number of steps.
        warmup_steps (int): Number of warmup steps.
        lr_min_ratio (float): Minimum learning rate, as a fraction of the base learning rate.
        cycle_length (float): Cycle length, dividing the decay progress inside the cosine.
            With 1.0 the cosine reaches its minimum exactly at `total_steps`.
    """
    def __init__(self, optimizer: Optimizer, total_steps: int, warmup_steps: int,
                 lr_min_ratio: float = 0.0, cycle_length: float = 1.0):
        # The schedule attributes must exist before calling the parent
        # constructor, which already queries the learning rate once.
        self.warmup_steps = warmup_steps
        assert self.warmup_steps >= 0
        self.total_steps = total_steps
        assert self.total_steps >= 0
        self.lr_min_ratio = lr_min_ratio
        self.cycle_length = cycle_length
        super().__init__(optimizer)

    def _get_sched_lr(self, lr: float, step: int):
        """Return the scheduled value of the base learning rate `lr` at `step`."""
        if step < self.warmup_steps:
            lr_ratio = step / self.warmup_steps
        elif step <= self.total_steps:
            decay_steps = self.total_steps - self.warmup_steps
            # When the warmup covers the whole schedule there is no decay phase:
            # use the start of the cosine instead of dividing by zero.
            s = (step - self.warmup_steps) / decay_steps if decay_steps > 0 else 0.
            lr_ratio = self.lr_min_ratio + 0.5 * (1 - self.lr_min_ratio) * \
                (1. + math.cos(math.pi * s / self.cycle_length))
        else:
            lr_ratio = self.lr_min_ratio
        return lr_ratio * lr

    def get_lr(self):
        """Return the scheduled learning rate of every parameter group."""
        return [self._get_sched_lr(lr, self.last_epoch) for lr in self.base_lrs]
================================================
FILE: audiocraft/optim/dadam.py
================================================
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import logging
from typing import TYPE_CHECKING, Any
import torch
import torch.optim
import torch.distributed as dist
# `_params_t` is only imported from torch for static type checkers; at runtime
# it is aliased to `Any`, so the torch-private name is never imported.
if TYPE_CHECKING:
    from torch.optim.optimizer import _params_t
else:
    _params_t = Any
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)
def to_real(x):
    """Return the real part of a complex tensor, and any other tensor unchanged."""
    return x.real if torch.is_complex(x) else x
class DAdaptAdam(torch.optim.Optimizer):
"""Adam with D-Adaptation automatic step-sizes.
Leave LR set to 1 unless you encounter instability.
Args:
params (iterable):
Iterable of parameters to optimize or dicts defining parameter groups.
lr (float):
Learning rate adjustment parameter. Increases or decreases the D-adapted learning rate.
betas (tuple[float, float], optional): coefficients used for computing
running averages of gradient and its square (default: (0.9, 0.999))
momentum (float):
Momentum value in the range [0,1) (default: 0.9).
eps (float):
Term added to the denominator outside of the root operation to improve numerical stability. (default: 1e-8).
weight_decay (float):
Weight decay, i.e. a L2 penalty (default: 0).
log_every (int):
Log using print every k steps, default 0 (no logging).
decouple (boolean):
Use AdamW style decoupled weight decay
d0 (float):
Initial D estimate for D-adaptation (default 1e-6). Rarely needs changing.
growth_rate (float):
prevent the D estimate from growing faster than this multiplicative rate.
Default is inf, for unrestricted. Values like 1.02 give a kind of learning
rate warmup effect.
fsdp_in_use (bool):
If you're using sharded parameters, this should be set to True. The optimizer
will attempt to auto-detect this, but if you're using an implementation other
than PyTorch's builtin version, the auto-detection won't work.
"""
def __init__(self, params, lr=1.0,
             betas=(0.9, 0.999),
             eps=1e-8,
             weight_decay=0,
             log_every=0,
             decouple=True,
             d0=1e-6,
             growth_rate=float('inf')):
    """Validate the hyper-parameters and register them as optimizer defaults.

    Raises:
        ValueError: If `d0`, `lr` or `eps` is not strictly positive, or if a
            beta is outside of `[0, 1)`.
    """
    if not 0.0 < d0:
        raise ValueError("Invalid d0 value: {}".format(d0))
    if not 0.0 < lr:
        raise ValueError("Invalid learning rate: {}".format(lr))
    if not 0.0 < eps:
        raise ValueError("Invalid epsilon value: {}".format(eps))
    beta1_candidate = betas[0]
    if not 0.0 <= beta1_candidate < 1.0:
        raise ValueError("Invalid beta parameter at index 0: {}".format(beta1_candidate))
    beta2_candidate = betas[1]
    if not 0.0 <= beta2_candidate < 1.0:
        raise ValueError("Invalid beta parameter at index 1: {}".format(beta2_candidate))

    if decouple:
        logger.info("Using decoupled weight decay")

    # Imported here rather than at module level, as in the original code.
    from .fsdp import is_fsdp_used
    fsdp_in_use = is_fsdp_used()

    # `d` is the running D estimate, starting at `d0`; `k` and `gsq_weighted`
    # are counters/accumulators initialised to zero and stored with the group.
    defaults = {
        'lr': lr,
        'betas': betas,
        'eps': eps,
        'weight_decay': weight_decay,
        'd': d0,
        'k': 0,
        'gsq_weighted': 0.0,
        'log_every': log_every,
        'decouple': decouple,
        'growth_rate': growth_rate,
        'fsdp_in_use': fsdp_in_use,
    }
    super().__init__(params, defaults)
@property
def supports_memory_efficient_fp16(self):
    """Flag telling callers that memory efficient fp16 is not supported (always False)."""
    return False
@property
def supports_flat_params(self):
    """Flag telling callers that flat parameters are supported (always True)."""
    return True
def step(self, closure=None):
    """Performs a single optimization step.

    The step runs in two passes over the parameter groups: the first one updates
    the Adam moving averages and accumulates the statistics used to re-estimate `d`,
    the second one stores the new estimate in every group and applies the update.

    Args:
        closure (callable, optional): A closure that reevaluates the model
            and returns the loss.
    Returns:
        The value returned by `closure`, or None when no closure is given.
    Raises:
        RuntimeError: if a group uses a learning rate that is neither 0 nor the
            largest learning rate across groups.
    """
    loss = None
    if closure is not None:
        loss = closure()

    # Statistics accumulated over every parameter during the first pass.
    g_sq = 0.0
    sksq_weighted = 0.0
    sk_l1 = 0.0

    lr = max(group['lr'] for group in self.param_groups)

    # Shared quantities are read from the first group; the second pass below
    # writes `gsq_weighted` and `d` back to every group.
    group = self.param_groups[0]
    gsq_weighted = group['gsq_weighted']
    d = group['d']
    # Effective step size. It is computed from the value of `d` at the start of
    # the step and is not refreshed after `d` is re-estimated below.
    dlr = d*lr

    growth_rate = group['growth_rate']
    decouple = group['decouple']
    fsdp_in_use = group['fsdp_in_use']
    log_every = group['log_every']

    beta1, beta2 = group['betas']

    # First pass: moving averages and D-adaptation statistics.
    for group in self.param_groups:
        group_lr = group['lr']
        decay = group['weight_decay']
        k = group['k']
        eps = group['eps']

        if group_lr not in [lr, 0.0]:
            raise RuntimeError("Setting different lr values in different parameter "
                               "groups is only supported for values of 0")

        for p in group['params']:
            if p.grad is None:
                continue
            # Parameters carrying this attribute switch on the distributed reduction below.
            if hasattr(p, "_fsdp_flattened"):
                fsdp_in_use = True
            grad = p.grad.data

            # Apply weight decay (coupled variant)
            if decay != 0 and not decouple:
                grad.add_(p.data, alpha=decay)

            state = self.state[p]

            # State initialization
            if 'step' not in state:
                state['step'] = 0
                state['s'] = torch.zeros_like(p.data, memory_format=torch.preserve_format).detach()
                # Exponential moving average of gradient values
                state['exp_avg'] = torch.zeros_like(p.data, memory_format=torch.preserve_format).detach()
                # Exponential moving average of squared gradient values
                state['exp_avg_sq'] = torch.zeros_like(
                    to_real(p.data), memory_format=torch.preserve_format).detach()

            exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']

            # Element-wise product of the gradient with its conjugate, passed through `to_real`.
            grad_grad = to_real(grad * grad.conj())

            # Adam EMA updates
            if group_lr > 0:
                # Note that the first moment is accumulated already scaled by `dlr`.
                exp_avg.mul_(beta1).add_(grad, alpha=dlr*(1-beta1))
                exp_avg_sq.mul_(beta2).add_(grad_grad, alpha=1-beta2)

                denom = exp_avg_sq.sqrt().add_(eps)

                # `div_` is in place: `grad_grad` is consumed here.
                g_sq += grad_grad.div_(denom).sum().item()

                s = state['s']
                s.mul_(beta2).add_(grad, alpha=dlr*(1-beta2))
                sksq_weighted += to_real(s * s.conj()).div_(denom).sum().item()
                sk_l1 += s.abs().sum().item()

        ######

    gsq_weighted = beta2*gsq_weighted + g_sq*(dlr**2)*(1-beta2)
    d_hat = d

    # if we have not made any progress, return
    # if we have any gradients available, will have sk_l1 > 0 (unless \|g\|=0)
    if sk_l1 == 0:
        return loss

    if lr > 0.0:
        if fsdp_in_use:
            # Sum the three statistics across workers with a single all-reduce.
            # The buffer is allocated on the 'cuda' device.
            dist_tensor = torch.zeros(3, device='cuda')
            dist_tensor[0] = sksq_weighted
            dist_tensor[1] = gsq_weighted
            dist_tensor[2] = sk_l1
            dist.all_reduce(dist_tensor, op=dist.ReduceOp.SUM)
            global_sksq_weighted = dist_tensor[0]
            global_gsq_weighted = dist_tensor[1]
            global_sk_l1 = dist_tensor[2]
        else:
            global_sksq_weighted = sksq_weighted
            global_gsq_weighted = gsq_weighted
            global_sk_l1 = sk_l1

        d_hat = (global_sksq_weighted/(1-beta2) - global_gsq_weighted)/global_sk_l1
        # `d` never decreases, and grows by at most a factor `growth_rate` per step.
        d = max(d, min(d_hat, d*growth_rate))

    # `k` is the counter of the last group visited in the first pass.
    if log_every > 0 and k % log_every == 0:
        logger.info(
            f"(k={k}) dlr: {dlr:1.1e} d_hat: {d_hat:1.1e}, d: {d:1.8}. "
            f"sksq_weighted={global_sksq_weighted:1.1e} gsq_weighted={global_gsq_weighted:1.1e} "
            f"sk_l1={global_sk_l1:1.1e}{' (FSDP)' if fsdp_in_use else ''}")

    # Second pass: store the new estimates and update the parameters.
    for group in self.param_groups:
        group['gsq_weighted'] = gsq_weighted
        group['d'] = d

        group_lr = group['lr']
        decay = group['weight_decay']
        k = group['k']
        eps = group['eps']

        for p in group['params']:
            if p.grad is None:
                continue
            grad = p.grad.data

            state = self.state[p]

            exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']

            state['step'] += 1

            denom = exp_avg_sq.sqrt().add_(eps)
            # Cast the denominator to the tensor type of the parameter.
            denom = denom.type(p.type())

            # Apply weight decay (decoupled variant)
            if decay != 0 and decouple and group_lr > 0:
                p.data.add_(p.data, alpha=-decay * dlr)

            # Take step
            # (`exp_avg` already includes the `dlr` factor, hence value=-1.)
            p.data.addcdiv_(exp_avg, denom, value=-1)

        group['k'] = k + 1

    return loss
================================================
FILE: audiocraft/optim/ema.py
================================================
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
# ModelEMA implementation is taken from
# https://github.com/facebookresearch/demucs
from collections import defaultdict
import typing as tp
import torch
import torch.nn as nn
def _get_all_non_persistent_buffers_set(module: nn.Module, root: str = "") -> set:
    """Return the qualified names of all the non persistent buffers found in `module`.

    Args:
        module (nn.Module): Module to inspect, along with all its sub-modules.
        root (str): Prefix prepended (with a dot) to every returned name.
    Returns:
        set: Dotted buffer names, relative to `root`.
    """
    def _qualify(local_name: str) -> str:
        # Only add the separator when there is an actual prefix.
        return f"{root}.{local_name}" if len(root) > 0 else local_name

    collected: set = set()
    for child_name, child in module.named_modules():
        if child_name == '':
            # The empty name designates `module` itself.
            collected |= {_qualify(buff_name) for buff_name in module._non_persistent_buffers_set}
        else:
            collected |= _get_all_non_persistent_buffers_set(child, _qualify(child_name))
    return collected
def _get_named_tensors(module: nn.Module):
    """List the `(name, tensor)` pairs of `module`: parameters first, then buffers.

    Buffers registered as non persistent are left out.
    """
    skipped = _get_all_non_persistent_buffers_set(module)
    named_tensors = list(module.named_parameters())
    named_tensors.extend(
        (name, buffer) for name, buffer in module.named_buffers() if name not in skipped)
    return named_tensors
class ModuleDictEMA:
    """Exponential Moving Average over a nn.ModuleDict.

    A copy of every floating point parameter and persistent buffer is kept in
    `self.state`, indexed by module name and then by tensor name.

    Args:
        module_dict (nn.ModuleDict): Modules to track.
        decay (float): EMA decay.
        unbias (bool): When True, the averaging weight is `1 / count` where `count`
            is itself a decayed counter, otherwise it is `1 - decay`.
        device (torch.device or str): Device on which the averaged tensors are stored.
            A falsy value keeps each tensor on its own device.
    """
    def __init__(self, module_dict: nn.ModuleDict, decay: float = 0.999,
                 unbias: bool = True, device: tp.Union[torch.device, str] = 'cpu'):
        self.decay = decay
        self.module_dict = module_dict
        self.state: dict = defaultdict(dict)
        self.count = 0
        self.device = device
        self.unbias = unbias
        self._init()

    def _init(self):
        """Snapshot the floating point tensors that are not tracked yet."""
        for module_name, module in self.module_dict.items():
            for key, val in _get_named_tensors(module):
                if not val.is_floating_point():
                    continue
                # Accessed lazily so that no entry is created for a module
                # that has no floating point tensor.
                tracked = self.state[module_name]
                if key not in tracked:
                    target_device = self.device or val.device
                    tracked[key] = val.detach().to(target_device, copy=True)

    def step(self):
        """Move the averaged tensors one step towards the current values of the modules."""
        if self.unbias:
            self.count = self.count * self.decay + 1
            w = 1 / self.count
        else:
            w = 1 - self.decay
        for module_name, module in self.module_dict.items():
            for key, val in _get_named_tensors(module):
                if not val.is_floating_point():
                    continue
                target_device = self.device or val.device
                averaged = self.state[module_name][key]
                # averaged <- (1 - w) * averaged + w * val
                averaged.mul_(1 - w)
                averaged.add_(val.detach().to(target_device), alpha=w)

    def state_dict(self):
        """Return the averaged tensors along with the step counter."""
        return {'state': self.state, 'count': self.count}

    def load_state_dict(self, state):
        """Copy in place the content of a dict produced by `state_dict`."""
        self.count = state['count']
        for module_name, tensors in state['state'].items():
            for key, val in tensors.items():
                self.state[module_name][key].copy_(val)
================================================
FILE: audiocraft/optim/fsdp.py
================================================
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
"""
Wrapper around FSDP for more convenient use in the training loops.
"""
from contextlib import contextmanager
import typing as tp
import dora
import torch
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp import (
MixedPrecision, ShardingStrategy, FullStateDictConfig, StateDictType)
from torch.distributed._shard.sharded_tensor.api import ShardedTensor
def is_fsdp_used() -> bool:
    """Return whether we are using FSDP.

    The answer is read from the `fsdp.use` entry of the config of the current Dora XP.
    False is returned outside of an XP, or when the config has no `fsdp` entry.
    """
    # A bit of a hack but should work from anywhere.
    if not dora.is_xp():
        return False
    cfg = dora.get_xp().cfg
    return cfg.fsdp.use if hasattr(cfg, 'fsdp') else False
def is_sharded_tensor(x: tp.Any) -> bool:
    """Return True when `x` is an instance of `ShardedTensor`."""
    return isinstance(x, ShardedTensor)
@contextmanager
def switch_to_full_state_dict(models: tp.List[FSDP]):
    """Context manager switching the given FSDP models to the full state dict type.

    Inside the context, each model uses `FULL_STATE_DICT`, offloaded to CPU and
    materialized on rank 0 only. On exit, even after an error, the models are
    set to `LOCAL_STATE_DICT`.
    """
    # Another bug in FSDP makes it that we cannot use the `state_dict_type` API,
    # so let's do things manually.
    for fsdp_model in models:
        full_config = FullStateDictConfig(offload_to_cpu=True, rank0_only=True)
        FSDP.set_state_dict_type(  # type: ignore
            fsdp_model, StateDictType.FULL_STATE_DICT, full_config)
    try:
        yield
    finally:
        for fsdp_model in models:
            FSDP.set_state_dict_type(fsdp_model, StateDictType.LOCAL_STATE_DICT)  # type: ignore
def wrap_with_fsdp(cfg, model: torch.nn.Module,
                   block_classes: tp.Optional[tp.Set[tp.Type]] = None) -> FSDP:
    """Wraps a model with FSDP.

    Args:
        cfg: FSDP config. The entries read here are `use`, `param_dtype`, `reduce_dtype`,
            `buffer_dtype`, `sharding_strategy` and `per_block`.
        model (torch.nn.Module): Model to wrap.
        block_classes (set of types, optional): Module classes wrapped individually when
            `cfg.per_block` is set. Defaults to the streaming transformer layer
            and the conditioning provider.
    Returns:
        FSDP: The wrapped model, with its state dict type set to `LOCAL_STATE_DICT`.
    """
    # Some of the typing is disabled until this gets integrated
    # into the stable version of PyTorch.
    from torch.distributed.fsdp.wrap import ModuleWrapPolicy  # type: ignore
    # we import this here to prevent circular import.
    from ..modules.transformer import StreamingTransformerLayer
    from ..modules.conditioners import ConditioningProvider

    _fix_post_backward_hook()

    assert cfg.use
    strategies = {
        "no_shard": ShardingStrategy.NO_SHARD,
        "shard_grad_op": ShardingStrategy.SHARD_GRAD_OP,
        "full_shard": ShardingStrategy.FULL_SHARD,
    }
    dtypes = {
        "float32": torch.float32,
        "float16": torch.float16,
        "bfloat16": torch.bfloat16,
    }

    precision = MixedPrecision(
        param_dtype=dtypes[cfg.param_dtype],
        reduce_dtype=dtypes[cfg.reduce_dtype],
        buffer_dtype=dtypes[cfg.buffer_dtype],
    )
    strategy = strategies[cfg.sharding_strategy]
    # The following is going to require being a bit smart
    # when doing LM, because this would flush the weights for every time step
    # during generation. One possibility is to use hybrid sharding:
    # See: https://pytorch.org/docs/master/fsdp.html#torch.distributed.fsdp.ShardingStrategy
    assert strategy != ShardingStrategy.FULL_SHARD, \
        "Not supported at the moment, requires a bit more work."

    local_rank = dora.distrib.get_distrib_spec().local_rank
    assert local_rank < torch.cuda.device_count(), "Please upgrade Dora!"

    if block_classes is None:
        block_classes = {StreamingTransformerLayer, ConditioningProvider}
    # Without `per_block`, no auto wrap policy is given to FSDP.
    auto_wrap_policy = ModuleWrapPolicy(block_classes) if cfg.per_block else None

    wrapped = _FSDPFixStateDict(
        model,
        sharding_strategy=strategy,
        mixed_precision=precision,
        device_id=local_rank,
        sync_module_states=True,
        use_orig_params=True,
        auto_wrap_policy=auto_wrap_policy,
    )  # type: ignore
    FSDP.set_state_dict_type(wrapped, StateDictType.LOCAL_STATE_DICT)  # type: ignore

    # Let the wrapped model know about the wrapping!
    # We use __dict__ to avoid it going into the state dict.
    # This is a bit dirty, but needed during generation, as otherwise
    # the wrapped model would call itself and bypass FSDP.
    for fsdp_module in FSDP.fsdp_modules(wrapped):
        fsdp_module._fsdp_wrapped_module.__dict__['_fsdp'] = fsdp_module
    return wrapped
def purge_fsdp(model: FSDP):
    """Purge the FSDP cached shard inside the model. This should
    allow setting the best state or switching to the EMA.

    FSDP modules without any handle, or whose padded unsharded flat parameter
    has an empty storage, are left untouched.
    """
    from torch.distributed.fsdp._runtime_utils import _reshard  # type: ignore
    for fsdp_module in FSDP.fsdp_modules(model):
        handles = fsdp_module._handles
        if not handles:
            continue
        # Only the first handle is inspected to decide whether something is cached.
        flat_param = handles[0]._get_padded_unsharded_flat_param()
        storage_size: int = flat_param._typed_storage()._size()  # type: ignore
        if storage_size != 0:
            # One `True` flag per handle.
            _reshard(fsdp_module, handles, [True] * len(handles))
class _FSDPFixStateDict(FSDP):
    """FSDP subclass with custom `state_dict` and `load_state_dict`.

    `state_dict` leaves out the sharded tensors, and `load_state_dict` copies the
    values in place into the current state and then purges the FSDP cached weights.
    """
    @staticmethod
    def _name_without_fsdp_prefix(name: str) -> str:
        """Remove every FSDP wrapper component from a dotted state key."""
        from torch.distributed.fsdp._common_utils import FSDP_WRAPPED_MODULE  # type: ignore
        return '.'.join(part for part in name.split('.') if part != FSDP_WRAPPED_MODULE)

    def state_dict(self) -> tp.Dict[str, tp.Any]:  # type: ignore
        """Return the state dict of the parent class, minus the sharded tensors."""
        full_state = super().state_dict()
        return {key: value for key, value in full_state.items()
                if not is_sharded_tensor(value)}

    def load_state_dict(self, state: tp.Dict[str, tp.Any]):  # type: ignore
        """Load `state`, then purge the weights cached by FSDP.

        Raises:
            RuntimeError: outside of the `FULL_STATE_DICT` mode, when a key of `state`
                does not exist in the current state.
        """
        if self._state_dict_type is StateDictType.FULL_STATE_DICT:
            super().load_state_dict(state)
        else:
            # Fix FSDP load state dict in all situations.
            # Use this only with LOCAL_STATE_DICT !!!
            current_state = dict(super().state_dict())
            for raw_key, value in state.items():
                key = _FSDPFixStateDict._name_without_fsdp_prefix(raw_key)
                if key not in current_state:
                    # Emulate strict loading manually.
                    raise RuntimeError(f"Unknown state key {key}")
                current_state[key].copy_(value)
        # Purging cached weights from previous forward.
        purge_fsdp(self)
# Set to True once the FSDP post backward hook has been replaced.
_hook_fixed = False


def _fix_post_backward_hook():
    """Replace, once per process, the post backward hook of the FSDP runtime utils.

    The replacement resets the training states for the wrapped modules that carry a
    truthy `_audiocraft_checkpointed` attribute, and then calls the original hook.
    """
    global _hook_fixed
    if _hook_fixed:
        return
    _hook_fixed = True

    from torch.distributed.fsdp import _runtime_utils
    from torch.distributed.fsdp._common_utils import TrainingState, HandleTrainingState
    original_hook = _runtime_utils._post_backward_hook

    def _post_backward_hook(state, handle, *args, **kwargs):
        if getattr(state._fsdp_wrapped_module, '_audiocraft_checkpointed', False):
            # there will be one more forward in the backward with checkpointing and that will
            # massively confuse FSDP, so we have to make it think everything
            # is going according to the plan.
            state.training_state = TrainingState.FORWARD_BACKWARD
            handle._training_state = HandleTrainingState.BACKWARD_PRE
        original_hook(state, handle, *args, **kwargs)

    _runtime_utils._post_backward_hook = _post_backward_hook
================================================
FILE: audiocraft/optim/inverse_sqrt_lr_scheduler.py
================================================
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import typing as tp
from torch.optim import Optimizer
from torch.optim.lr_scheduler import _LRScheduler
class InverseSquareRootLRScheduler(_LRScheduler):
    """Inverse square root LR scheduler.

    The learning rate goes linearly from `warmup_init_lr` to the base learning rate
    over `warmup_steps` steps, and then decays as the inverse square root of the step.

    Args:
        optimizer (Optimizer): Torch optimizer.
        warmup_steps (int): Number of warmup steps.
        warmup_init_lr (tp.Optional[float]): Initial learning rate
            during warmup phase. None is treated as 0.
    """
    def __init__(self, optimizer: Optimizer, warmup_steps: int, warmup_init_lr: tp.Optional[float] = 0):
        # The attributes must exist before the parent constructor runs.
        self.warmup_steps = warmup_steps
        self.warmup_init_lr = warmup_init_lr
        super().__init__(optimizer)

    def _get_sched_lr(self, lr: float, step: int):
        """Return the learning rate at `step` for a base learning rate `lr`."""
        if step >= self.warmup_steps:
            # Decay phase, the factor makes the value equal to `lr` at `step == warmup_steps`.
            decay_factor = lr * self.warmup_steps**0.5
            return decay_factor * step**-0.5
        # Warmup phase, linear interpolation.
        start_lr = self.warmup_init_lr or 0
        increment = (lr - start_lr) / self.warmup_steps
        return start_lr + step * increment

    def get_lr(self):
        """Return one learning rate per base learning rate, evaluated at `self._step_count`."""
        current_step = self._step_count
        return [self._get_sched_lr(base_lr, current_step) for base_lr in self.base_lrs]
================================================
FILE: audiocraft/optim/linear_warmup_lr_scheduler.py
================================================
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import typing as tp
from torch.optim import Optimizer
from torch.optim.lr_scheduler import _LRScheduler
class LinearWarmupLRScheduler(_LRScheduler):
    """Linear warmup LR scheduler.

    The learning rate goes linearly from `warmup_init_lr` to the base learning rate
    over `warmup_steps` steps, and is then kept at the base learning rate.

    Args:
        optimizer (Optimizer): Torch optimizer.
        warmup_steps (int): Number of warmup steps.
        warmup_init_lr (tp.Optional[float]): Initial learning rate
            during warmup phase. None is treated as 0.
    """
    def __init__(self, optimizer: Optimizer, warmup_steps: int, warmup_init_lr: tp.Optional[float] = 0):
        # The attributes must exist before the parent constructor runs.
        self.warmup_steps = warmup_steps
        self.warmup_init_lr = warmup_init_lr
        super().__init__(optimizer)

    def _get_sched_lr(self, lr: float, step: int):
        """Return the learning rate at `step` for a base learning rate `lr`."""
        if step < self.warmup_steps:
            warmup_init_lr = self.warmup_init_lr or 0
            lr_step = (lr - warmup_init_lr) / self.warmup_steps
            lr = warmup_init_lr + step * lr_step
        # Past the warmup, the base learning rate is returned untouched.
        return lr

    def get_lr(self):
        """Return one learning rate per base learning rate, evaluated at `self.last_epoch`."""
        return [self._get_sched_lr(base_lr, self.last_epoch) for base_lr in self.base_lrs]
================================================
FILE: audiocraft/optim/polynomial_decay_lr_scheduler.py
================================================
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
from torch.optim import Optimizer
from torch.optim.lr_scheduler import _LRScheduler
class PolynomialDecayLRScheduler(_LRScheduler):
    """Polynomial decay LR scheduler.

    The schedule goes through up to four phases: a learning rate of 0, a linear warmup
    up to the base learning rate, a polynomial decay towards `end_lr`, and finally
    a constant `end_lr` once `total_steps` is reached.

    Args:
        optimizer (Optimizer): Torch optimizer.
        warmup_steps (int): Number of warmup steps.
        total_steps (int): Total number of steps.
        end_lr (float): Final learning rate to achieve over total number of steps.
        zero_lr_warmup_steps (int): Number of steps with a learning rate of value 0.
        power (float): Decay exponent.
    """
    def __init__(self, optimizer: Optimizer, warmup_steps: int, total_steps: int,
                 end_lr: float = 0., zero_lr_warmup_steps: int = 0, power: float = 1.):
        # The attributes must exist before the parent constructor runs.
        self.warmup_steps = warmup_steps
        self.total_steps = total_steps
        self.end_lr = end_lr
        self.zero_lr_warmup_steps = zero_lr_warmup_steps
        self.power = power
        super().__init__(optimizer)

    def _get_sched_lr(self, lr: float, step: int):
        """Return the learning rate at `step` for a base learning rate `lr`."""
        zero_steps = self.zero_lr_warmup_steps
        ramp_steps = self.warmup_steps

        # Phase 1: the learning rate is held at zero.
        if zero_steps > 0 and step <= zero_steps:
            return 0
        # Phase 2: linear ramp, counted from the end of the zero phase.
        if ramp_steps > 0 and step <= ramp_steps + zero_steps:
            lr_ratio = (step - zero_steps) / float(ramp_steps)
            return lr_ratio * lr
        # Phase 4: the schedule is over.
        if step >= self.total_steps:
            return self.end_lr
        # Phase 3: polynomial decay from `lr` down to `end_lr`.
        total_warmup_steps = ramp_steps + zero_steps
        lr_range = lr - self.end_lr
        pct_remaining = 1 - (step - total_warmup_steps) / (self.total_steps - total_warmup_steps)
        return lr_range * pct_remaining ** self.power + self.end_lr

    def get_lr(self):
        """Return one learning rate per base learning rate, evaluated at `self.last_epoch`."""
        current_step = self.last_epoch
        return [self._get_sched_lr(base_lr, current_step) for base_lr in self.base_lrs]
================================================
FILE: audiocraft/py.typed
================================================
================================================
FILE: audiocraft/quantization/__init__.py
================================================
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
"""RVQ."""
# flake8: noqa
from .vq import ResidualVectorQuantizer
from .base import BaseQuantizer, DummyQuantizer, QuantizedResult
================================================
FILE: audiocraft/quantization/base.py
================================================
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
"""
Base class for all quantizers.
"""
from dataclasses import dataclass, field
import typing as tp
import torch
from torch import nn
@dataclass
class QuantizedResult:
    """Container for what a quantizer returns from its forward."""
    # Quantized (or approximately quantized) representation.
    x: torch.Tensor
    # Codes matching the quantized representation.
    codes: torch.Tensor
    # Bandwidth in kb/s used, per batch item.
    bandwidth: torch.Tensor
    # Optional penalty term for the loss.
    penalty: tp.Optional[torch.Tensor] = None
    # Metrics to update logging etc., a fresh dict for each instance.
    metrics: dict = field(default_factory=dict)
class BaseQuantizer(nn.Module):
    """Base class for quantizers.

    Every method and property here raises `NotImplementedError`,
    sub-classes are expected to override them.
    """

    def forward(self, x: torch.Tensor, frame_rate: int) -> QuantizedResult:
        """Quantize the input tensor `x`.

        The result holds the quantized (or approximately quantized) representation,
        the quantized codes, the bandwidth, any penalty term for the loss, and a dict
        of metrics to update logging etc.
        Frame rate must be passed so that the bandwidth is properly computed.
        """
        raise NotImplementedError()

    def encode(self, x: torch.Tensor) -> torch.Tensor:
        """Encode a given input tensor with the specified sample rate at the given bandwidth."""
        raise NotImplementedError()

    def decode(self, codes: torch.Tensor) -> torch.Tensor:
        """Decode the given codes to the quantized representation."""
        raise NotImplementedError()

    @property
    def total_codebooks(self):
        """Total number of codebooks."""
        raise NotImplementedError()

    @property
    def num_codebooks(self):
        """Number of active codebooks."""
        raise NotImplementedError()

    def set_num_codebooks(self, n: int):
        """Set the number of active codebooks."""
        raise NotImplementedError()
class DummyQuantizer(BaseQuantizer):
    """Fake quantizer that actually does not perform any quantization.

    The codes are the input itself with an extra dimension of size 1
    inserted at position 1, for its single codebook.
    """
    def __init__(self):
        super().__init__()

    def forward(self, x: torch.Tensor, frame_rate: int):
        """Return `x` untouched, its codes, and the matching bandwidth."""
        codes = x.unsqueeze(1)
        # Every value counts for 32 bits; result in kb/s, averaged over `len(x)`.
        bandwidth = torch.tensor(codes.numel() * 32 * frame_rate / 1000 / len(x)).to(x)
        return QuantizedResult(x, codes, bandwidth)

    def encode(self, x: torch.Tensor) -> torch.Tensor:
        """Encode a given input tensor with the specified sample rate at the given bandwidth.
        In the case of the DummyQuantizer, the codes are actually identical
        to the input and resulting quantized representation as no quantization is done.
        """
        return x.unsqueeze(1)

    def decode(self, codes: torch.Tensor) -> torch.Tensor:
        """Decode the given codes to the quantized representation.
        In the case of the DummyQuantizer, the codes are actually identical
        to the input and resulting quantized representation as no quantization is done.
        """
        return codes.squeeze(1)

    @property
    def total_codebooks(self):
        """Total number of codebooks."""
        return 1

    @property
    def num_codebooks(self):
        """Number of active codebooks, always equal to the total number of codebooks."""
        return self.total_codebooks

    def set_num_codebooks(self, n: int):
        """Set the number of active codebooks, not supported by the dummy quantizer."""
        raise AttributeError("Cannot override the number of codebooks for the dummy quantizer")
================================================
FILE: audiocraft/quantization/core_vq.py
================================================
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import typing as tp
from einops import rearrange, repeat
import flashy
import torch
from torch import nn, einsum
import torch.nn.functional as F
def exists(val: tp.Optional[tp.Any]) -> bool:
    """Return True for any value other than None, falsy values included."""
    return val is not None
def default(val: tp.Any, d: tp.Any) -> tp.Any:
    """Return `val`, or the fallback `d` when `val` is None."""
    if exists(val):
        return val
    return d
def l2norm(t):
    """Normalize `t` with the L2 norm over its last dimension."""
    return F.normalize(t, p=2, dim=-1)
def ema_inplace(moving_avg, new, decay: float):
    """Update `moving_avg` in place to `decay * moving_avg + (1 - decay) * new`."""
    averaged = moving_avg.data
    averaged.mul_(decay)
    averaged.add_(new, alpha=(1 - decay))
def laplace_smoothing(x, n_categories: int, epsilon: float = 1e-5):
    """Return `(x + epsilon) / (sum(x) + n_categories * epsilon)`."""
    smoothed = x + epsilon
    total = x.sum() + n_categories * epsilon
    return smoothed / total
def uniform_init(*shape: int):
    """Return a new tensor of the given shape, filled by `nn.init.kaiming_uniform_`."""
    return nn.init.kaiming_uniform_(torch.empty(shape))
def sample_vectors(samples, num: int):
    """Randomly pick `num` rows of `samples`.

    Rows are drawn without replacement when there are enough of them,
    and with replacement otherwise.
    """
    available, device = samples.shape[0], samples.device
    if available < num:
        picked = torch.randint(0, available, (num,), device=device)
    else:
        picked = torch.randperm(available, device=device)[:num]
    return samples[picked]
def kmeans(samples, num_clusters: int, num_iters: int = 10):
    """Run k-means over the rows of `samples`.

    Args:
        samples: 2D tensor, one sample per row (the rearrange patterns below require it).
        num_clusters (int): Number of centroids.
        num_iters (int): Number of k-means iterations.
    Returns:
        tuple: the centroids, and the number of samples assigned to each centroid
            during the last iteration. With `num_iters == 0`, the centroids are
            the randomly sampled initial ones and all the counts are zero.
    """
    dim, dtype = samples.shape[-1], samples.dtype

    means = sample_vectors(samples, num_clusters)
    # Defined before the loop so that `num_iters == 0` returns zero counts instead
    # of failing on an unbound name. Same integer dtype as `torch.bincount`.
    bins = torch.zeros(num_clusters, dtype=torch.long, device=samples.device)

    for _ in range(num_iters):
        # Pairwise differences between every sample and every centroid.
        diffs = rearrange(samples, "n d -> n () d") - rearrange(
            means, "c d -> () c d"
        )
        # Negated squared distances: the max selects the closest centroid.
        dists = -(diffs ** 2).sum(dim=-1)

        buckets = dists.max(dim=-1).indices
        bins = torch.bincount(buckets, minlength=num_clusters)
        zero_mask = bins == 0
        # Avoid dividing by zero for the centroids without any sample.
        bins_min_clamped = bins.masked_fill(zero_mask, 1)

        new_means = buckets.new_zeros(num_clusters, dim, dtype=dtype)
        new_means.scatter_add_(0, repeat(buckets, "n -> n d", d=dim), samples)
        new_means = new_means / bins_min_clamped[..., None]

        # Centroids without any sample keep their previous value.
        means = torch.where(zero_mask[..., None], means, new_means)

    return means, bins
def orthogonal_loss_fn(t):
    """Orthogonal regularization loss over the rows of `t`.

    This is the squared difference between the cosine similarity matrix of
    the rows and the identity, summed and divided by the square of the number of rows.
    """
    # eq (2) from https://arxiv.org/abs/2112.00384
    num_codes = t.shape[0]
    unit_codes = l2norm(t)
    cosine_sim = einsum("i d, j d -> i j", unit_codes, unit_codes)
    deviation = cosine_sim - torch.eye(num_codes, device=t.device)
    return (deviation ** 2).sum() / (num_codes ** 2)
class EuclideanCodebook(nn.Module):
    """Codebook with Euclidean distance.

    The codebook is stored in buffers and is updated with exponential moving averages
    inside `forward` when the module is in training mode.

    Args:
        dim (int): Dimension.
        codebook_size (int): Codebook size.
        kmeans_init (bool): Whether to use k-means to initialize the codebooks.
            If set to true, run the k-means algorithm on the first training batch and use
            the learned centroids as initialization.
        kmeans_iters (int): Number of iterations used for k-means algorithm at initialization.
        decay (float): Decay for exponential moving average over the codebooks.
        epsilon (float): Epsilon value for numerical stability.
        threshold_ema_dead_code (int): Threshold for dead code expiration. Replace any codes
            that have an exponential moving average cluster size less than the specified threshold with
            randomly selected vector from the current batch.
    """
    def __init__(
        self,
        dim: int,
        codebook_size: int,
        kmeans_init: bool = False,
        kmeans_iters: int = 10,
        decay: float = 0.8,
        epsilon: float = 1e-5,
        threshold_ema_dead_code: int = 2,
    ):
        super().__init__()
        self.decay = decay
        # With k-means init, the codebook starts as zeros and is filled by `init_embed_`.
        init_fn: tp.Union[tp.Callable[..., torch.Tensor], tp.Any] = uniform_init if not kmeans_init else torch.zeros
        embed = init_fn(codebook_size, dim)

        self.codebook_size = codebook_size

        self.kmeans_iters = kmeans_iters
        self.epsilon = epsilon
        self.threshold_ema_dead_code = threshold_ema_dead_code

        # `inited` holds 1.0 right away when no k-means init is requested.
        self.register_buffer("inited", torch.Tensor([not kmeans_init]))
        self.register_buffer("cluster_size", torch.zeros(codebook_size))
        self.register_buffer("embed", embed)
        self.register_buffer("embed_avg", embed.clone())

    @torch.jit.ignore
    def init_embed_(self, data):
        """Initialize the buffers with k-means over `data`, unless already initialized."""
        if self.inited:
            return

        embed, cluster_size = kmeans(data, self.codebook_size, self.kmeans_iters)
        self.embed.data.copy_(embed)
        self.embed_avg.data.copy_(embed.clone())
        self.cluster_size.data.copy_(cluster_size)
        self.inited.data.copy_(torch.Tensor([True]))
        # Make sure all buffers across workers are in sync after initialization
        flashy.distrib.broadcast_tensors(self.buffers())

    def replace_(self, samples, mask):
        """Replace the codes selected by `mask` with vectors sampled from `samples`."""
        modified_codebook = torch.where(
            mask[..., None], sample_vectors(samples, self.codebook_size), self.embed
        )
        self.embed.data.copy_(modified_codebook)

    def expire_codes_(self, batch_samples):
        """Replace the codes whose EMA cluster size is below the threshold.

        Nothing happens when the threshold is 0 or when no code is below it.
        """
        if self.threshold_ema_dead_code == 0:
            return

        expired_codes = self.cluster_size < self.threshold_ema_dead_code
        if not torch.any(expired_codes):
            return

        batch_samples = rearrange(batch_samples, "... d -> (...) d")
        self.replace_(batch_samples, mask=expired_codes)
        # Broadcast all the buffers after the replacement.
        flashy.distrib.broadcast_tensors(self.buffers())

    def preprocess(self, x):
        """Flatten all the leading dimensions of `x` into a single one."""
        x = rearrange(x, "... d -> (...) d")
        return x

    def quantize(self, x):
        """Return, for each row of `x`, the index of the closest code."""
        embed = self.embed.t()
        # Negated squared Euclidean distance, expanded as -(|x|^2 - 2 x.e + |e|^2).
        dist = -(
            x.pow(2).sum(1, keepdim=True)
            - 2 * x @ embed
            + embed.pow(2).sum(0, keepdim=True)
        )
        embed_ind = dist.max(dim=-1).indices
        return embed_ind

    def postprocess_emb(self, embed_ind, shape):
        """Reshape the flat indices to `shape` minus its last dimension."""
        return embed_ind.view(*shape[:-1])

    def dequantize(self, embed_ind):
        """Look up the codes matching the given indices."""
        quantize = F.embedding(embed_ind, self.embed)
        return quantize

    def encode(self, x):
        """Return the indices of the closest codes, with the shape of `x` minus its last dimension."""
        shape = x.shape
        # pre-process
        x = self.preprocess(x)
        # quantize
        embed_ind = self.quantize(x)
        # post-process
        embed_ind = self.postprocess_emb(embed_ind, shape)
        return embed_ind

    def decode(self, embed_ind):
        """Return the codes matching the given indices."""
        quantize = self.dequantize(embed_ind)
        return quantize

    def forward(self, x):
        """Quantize `x` and, in training mode, update the codebook buffers.

        Returns:
            tuple: the quantized tensor and the matching code indices.
        """
        shape, dtype = x.shape, x.dtype
        x = self.preprocess(x)
        self.init_embed_(x)

        embed_ind = self.quantize(x)
        embed_onehot = F.one_hot(embed_ind, self.codebook_size).type(dtype)
        embed_ind = self.postprocess_emb(embed_ind, shape)
        # The output is computed before the training update below changes `self.embed`.
        quantize = self.dequantize(embed_ind)

        if self.training:
            # We do the expiry of code at that point as buffers are in sync
            # and all the workers will take the same decision.
            self.expire_codes_(x)
            # EMA of the number of vectors assigned to each code.
            ema_inplace(self.cluster_size, embed_onehot.sum(0), self.decay)
            # EMA of the sum of the vectors assigned to each code.
            embed_sum = x.t() @ embed_onehot
            ema_inplace(self.embed_avg, embed_sum.t(), self.decay)
            # Smoothed cluster sizes, rescaled to the total cluster size.
            cluster_size = (
                laplace_smoothing(self.cluster_size, self.codebook_size, self.epsilon)
                * self.cluster_size.sum()
            )
            embed_normalized = self.embed_avg / cluster_size.unsqueeze(1)
            self.embed.data.copy_(embed_normalized)

        return quantize, embed_ind
class VectorQuantization(nn.Module):
    """Vector quantization implementation.
    Currently supports only euclidean distance.

    Args:
        dim (int): Dimension
        codebook_size (int): Codebook size
        codebook_dim (int): Codebook dimension. If not defined, uses the specified dimension in dim.
        decay (float): Decay for exponential moving average over the codebooks.
        epsilon (float): Epsilon value for numerical stability.
        kmeans_init (bool): Whether to use kmeans to initialize the codebooks.
        kmeans_iters (int): Number of iterations used for kmeans initialization.
        threshold_ema_dead_code (int): Threshold for dead code expiration. Replace any codes
            that have an exponential moving average cluster size less than the specified threshold with
            randomly selected vector from the current batch.
        channels_last (bool): Channels are the last dimension in the input tensors.
        commitment_weight (float): Weight for commitment loss.
        orthogonal_reg_weight (float): Orthogonal regularization weights.
        orthogonal_reg_active_codes_only (bool): Apply orthogonal regularization only on active codes.
        orthogonal_reg_max_codes (optional int): Maximum number of codes to consider
            for orthogonal regularization.
    """
    def __init__(
        self,
        dim: int,
        codebook_size: int,
        codebook_dim: tp.Optional[int] = None,
        decay: float = 0.8,
        epsilon: float = 1e-5,
        kmeans_init: bool = False,
        kmeans_iters: int = 10,
        threshold_ema_dead_code: int = 2,
        channels_last: bool = False,
        commitment_weight: float = 1.,
        orthogonal_reg_weight: float = 0.0,
        orthogonal_reg_active_codes_only: bool = False,
        orthogonal_reg_max_codes: tp.Optional[int] = None,
    ):
        super().__init__()
        _codebook_dim: int = default(codebook_dim, dim)

        # Linear projections are only needed when the codebook has its own dimension.
        if _codebook_dim != dim:
            self.project_in = nn.Linear(dim, _codebook_dim)
            self.project_out = nn.Linear(_codebook_dim, dim)
        else:
            self.project_in = nn.Identity()
            self.project_out = nn.Identity()

        self.epsilon = epsilon
        self.commitment_weight = commitment_weight

        self.orthogonal_reg_weight = orthogonal_reg_weight
        self.orthogonal_reg_active_codes_only = orthogonal_reg_active_codes_only
        self.orthogonal_reg_max_codes = orthogonal_reg_max_codes

        self._codebook = EuclideanCodebook(
            dim=_codebook_dim,
            codebook_size=codebook_size,
            kmeans_init=kmeans_init,
            kmeans_iters=kmeans_iters,
            decay=decay,
            epsilon=epsilon,
            threshold_ema_dead_code=threshold_ema_dead_code,
        )
        self.codebook_size = codebook_size

        self.channels_last = channels_last

    @property
    def codebook(self):
        """Codes of the underlying codebook."""
        return self._codebook.embed

    @property
    def inited(self):
        """Initialization flag of the underlying codebook."""
        return self._codebook.inited

    def _preprocess(self, x):
        """Move the channels to the last dimension, unless they already are."""
        if self.channels_last:
            return x
        return rearrange(x, "b d n -> b n d")

    def _postprocess(self, quantize):
        """Move the channels back to their original position."""
        if self.channels_last:
            return quantize
        return rearrange(quantize, "b n d -> b d n")

    def encode(self, x):
        """Return the code indices for `x`."""
        projected = self.project_in(self._preprocess(x))
        return self._codebook.encode(projected)

    def decode(self, embed_ind):
        """Return the quantized representation for the given code indices."""
        decoded = self._codebook.decode(embed_ind)
        decoded = self.project_out(decoded)
        return self._postprocess(decoded)

    def forward(self, x):
        """Quantize `x`.

        Returns:
            tuple: the quantized tensor, the code indices and the loss. The loss is
                a tensor with a single element, which stays at 0 outside of training.
        """
        device = x.device
        x = self.project_in(self._preprocess(x))
        quantize, embed_ind = self._codebook(x)

        loss = torch.tensor([0.0], device=device, requires_grad=self.training)

        if self.training:
            # Straight-through: values of `quantize`, gradient flowing to `x`.
            quantize = x + (quantize - x).detach()

            if self.commitment_weight > 0:
                commit_loss = F.mse_loss(quantize.detach(), x)
                loss = loss + commit_loss * self.commitment_weight

            if self.orthogonal_reg_weight > 0:
                codebook = self.codebook

                if self.orthogonal_reg_active_codes_only:
                    # only calculate orthogonal loss for the activated codes for this batch
                    codebook = codebook[torch.unique(embed_ind)]

                num_codes = codebook.shape[0]
                max_codes = self.orthogonal_reg_max_codes
                if exists(max_codes) and num_codes > max_codes:
                    # Random subset of at most `max_codes` codes.
                    kept_ids = torch.randperm(num_codes, device=device)[:max_codes]
                    codebook = codebook[kept_ids]

                orthogonal_reg_loss = orthogonal_loss_fn(codebook)
                loss = loss + orthogonal_reg_loss * self.orthogonal_reg_weight

        quantize = self._postprocess(self.project_out(quantize))
        return quantize, embed_ind, loss
class ResidualVectorQuantization(nn.Module):
    """Residual vector quantization implementation.

    Follows Algorithm 1. in https://arxiv.org/pdf/2107.03312.pdf

    Args:
        num_quantizers (int): Number of `VectorQuantization` layers.
        **kwargs: Passed as is to each `VectorQuantization` layer.
    """
    def __init__(self, *, num_quantizers, **kwargs):
        super().__init__()
        quantizers = [VectorQuantization(**kwargs) for _ in range(num_quantizers)]
        self.layers = nn.ModuleList(quantizers)

    def forward(self, x, n_q: tp.Optional[int] = None):
        """Quantize `x` with the first `n_q` layers (all of them when `n_q` is None or 0).

        Returns:
            tuple: the sum of the quantized outputs, the stacked indices
                and the stacked losses, one entry per layer used.
        """
        num_layers = n_q or len(self.layers)
        quantized_out = 0.0
        residual = x
        per_layer_indices = []
        per_layer_losses = []

        for layer in self.layers[:num_layers]:
            quantized, indices, loss = layer(residual)
            # Each layer quantizes what the previous ones could not represent.
            residual = residual - quantized
            quantized_out = quantized_out + quantized
            per_layer_indices.append(indices)
            per_layer_losses.append(loss)

        out_losses = torch.stack(per_layer_losses)
        out_indices = torch.stack(per_layer_indices)
        return quantized_out, out_indices, out_losses

    def encode(self, x: torch.Tensor, n_q: tp.Optional[int] = None) -> torch.Tensor:
        """Return the stacked indices given by the first `n_q` layers."""
        num_layers = n_q or len(self.layers)
        residual = x
        per_layer_indices = []
        for layer in self.layers[:num_layers]:
            indices = layer.encode(residual)
            residual = residual - layer.decode(indices)
            per_layer_indices.append(indices)
        return torch.stack(per_layer_indices)

    def decode(self, q_indices: torch.Tensor) -> torch.Tensor:
        """Sum the outputs decoded by each layer from its own slice of `q_indices`."""
        quantized_out = torch.tensor(0.0, device=q_indices.device)
        for layer_index, indices in enumerate(q_indices):
            quantized_out = quantized_out + self.layers[layer_index].decode(indices)
        return quantized_out
================================================
FILE: audiocraft/quantization/vq.py
================================================
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import math
import typing as tp
import torch
from .base import BaseQuantizer, QuantizedResult
from .core_vq import ResidualVectorQuantization
class ResidualVectorQuantizer(BaseQuantizer):
    """Residual Vector Quantizer.

    Args:
        dimension (int): Dimension of the codebooks.
        n_q (int): Number of residual vector quantizers used.
        q_dropout (bool): Random quantizer drop out at train time.
        bins (int): Codebook size.
        decay (float): Decay for exponential moving average over the codebooks.
        kmeans_init (bool): Whether to use kmeans to initialize the codebooks.
        kmeans_iters (int): Number of iterations used for kmeans initialization.
        threshold_ema_dead_code (int): Threshold for dead code expiration. Replace any codes
            that have an exponential moving average cluster size less than the specified threshold with
            randomly selected vector from the current batch.
        orthogonal_reg_weight (float): Orthogonal regularization weights.
        orthogonal_reg_active_codes_only (bool): Apply orthogonal regularization only on active codes.
        orthogonal_reg_max_codes (optional int): Maximum number of codes to consider
            for orthogonal regularization.
    """

    def __init__(
        self,
        dimension: int = 256,
        n_q: int = 8,
        q_dropout: bool = False,
        bins: int = 1024,
        decay: float = 0.99,
        kmeans_init: bool = True,
        kmeans_iters: int = 10,
        threshold_ema_dead_code: int = 2,
        orthogonal_reg_weight: float = 0.0,
        orthogonal_reg_active_codes_only: bool = False,
        orthogonal_reg_max_codes: tp.Optional[int] = None,
    ):
        super().__init__()
        # `max_n_q` is the number of codebooks built; `n_q` is the number currently in use.
        self.max_n_q = n_q
        self.n_q = n_q
        self.q_dropout = q_dropout
        self.dimension = dimension
        self.bins = bins
        self.decay = decay
        self.kmeans_init = kmeans_init
        self.kmeans_iters = kmeans_iters
        self.threshold_ema_dead_code = threshold_ema_dead_code
        self.orthogonal_reg_weight = orthogonal_reg_weight
        self.orthogonal_reg_active_codes_only = orthogonal_reg_active_codes_only
        self.orthogonal_reg_max_codes = orthogonal_reg_max_codes
        vq_kwargs = dict(
            dim=self.dimension,
            codebook_size=self.bins,
            num_quantizers=self.n_q,
            decay=self.decay,
            kmeans_init=self.kmeans_init,
            kmeans_iters=self.kmeans_iters,
            threshold_ema_dead_code=self.threshold_ema_dead_code,
            orthogonal_reg_weight=self.orthogonal_reg_weight,
            orthogonal_reg_active_codes_only=self.orthogonal_reg_active_codes_only,
            orthogonal_reg_max_codes=self.orthogonal_reg_max_codes,
            channels_last=False,
        )
        self.vq = ResidualVectorQuantization(**vq_kwargs)

    def forward(self, x: torch.Tensor, frame_rate: int):
        """Quantize `x` and report the codes, the bandwidth and the penalty.

        When training with `q_dropout`, the number of quantizers used for this call
        is drawn uniformly between 1 and `self.n_q` (inclusive).

        Args:
            x (torch.Tensor): Input passed to the residual quantizer.
            frame_rate (int): Frame rate used to compute the bandwidth value.
        Returns:
            QuantizedResult: Quantized tensor, codes, bandwidth and mean commitment loss.
        """
        use_dropout = self.training and self.q_dropout
        n_q = int(torch.randint(1, self.n_q + 1, (1,)).item()) if use_dropout else self.n_q
        bw_per_q = math.log2(self.bins) * frame_rate / 1000
        quantized, codes, commit_loss = self.vq(x, n_q=n_q)
        bw = torch.tensor(n_q * bw_per_q).to(x)
        # codes is [B, K, T], with T frames, K nb of codebooks.
        batch_first_codes = codes.transpose(0, 1)
        return QuantizedResult(quantized, batch_first_codes, bw, penalty=torch.mean(commit_loss))

    def encode(self, x: torch.Tensor) -> torch.Tensor:
        """Encode `x` with the `self.n_q` currently active quantizers.

        Returns the indices of each quantizer, with the first two axes swapped
        so that the batch axis comes first.
        """
        codes = self.vq.encode(x, n_q=self.n_q)
        # codes is [B, K, T], with T frames, K nb of codebooks.
        return codes.transpose(0, 1)

    def decode(self, codes: torch.Tensor) -> torch.Tensor:
        """Decode the given codes to the quantized representation."""
        # codes is [B, K, T], with T frames, K nb of codebooks, vq.decode expects [K, B, T].
        return self.vq.decode(codes.transpose(0, 1))

    @property
    def total_codebooks(self):
        """Number of codebooks the quantizer was built with."""
        return self.max_n_q

    @property
    def num_codebooks(self):
        """Number of codebooks currently in use."""
        return self.n_q

    def set_num_codebooks(self, n: int):
        """Set the number of active codebooks; `n` must be in `[1, total_codebooks]`."""
        assert n > 0 and n <= self.max_n_q
        self.n_q = n
================================================
FILE: audiocraft/solvers/__init__.py
================================================
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
"""
Solvers. A Solver is a training recipe, combining the dataloaders, models,
optimizer, losses etc into a single convenient object.
"""
# flake8: noqa
from .audiogen import AudioGenSolver
from .builders import get_solver
from .base import StandardSolver
from .compression import CompressionSolver
from .musicgen import MusicGenSolver
from .diffusion import DiffusionSolver
================================================
FILE: audiocraft/solvers/audiogen.py
================================================
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
from . import builders, musicgen
class AudioGenSolver(musicgen.MusicGenSolver):
    """Solver for AudioGen re-implementation training task.

    Note that this implementation does not strictly follow
    the method proposed in https://arxiv.org/abs/2209.15352
    but is derived from MusicGen's training pipeline.

    More information can be found in the AudioGen model card.
    """
    # Only the dataset type differs from the parent MusicGen solver here.
    DATASET_TYPE: builders.DatasetType = builders.DatasetType.SOUND
================================================
FILE: audiocraft/solvers/base.py
================================================
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
from abc import ABC, abstractmethod
from contextlib import contextmanager
from pathlib import Path
import typing as tp
import flashy
import omegaconf
import torch
from torch import nn
from .. import optim
from ..optim import fsdp
from ..utils import checkpoint
from ..utils.autocast import TorchAutocast
from ..utils.best_state import BestStateDictManager
from ..utils.deadlock import DeadlockDetect
from ..utils.profiler import Profiler
from ..utils.utils import copy_state, dict_from_config, model_hash, with_rank_rng
class StandardSolver(ABC, flashy.BaseSolver):
"""Standard solver for AudioCraft.
The standard solver implements a base training loop with the following stages:
train, valid, evaluate and generate that are expected to be all defined for
solvers in AudioCraft. It also provides a nice default management of Dora history replay,
checkpoint management across epoch, and logging configuration.
AudioCraft solvers must inherit from the StandardSolver and define the methods
associated to each stage as well as the show, build_model and build_dataloaders methods.
"""
def __init__(self, cfg: omegaconf.DictConfig):
    """Build the solver from its configuration.

    Sets up, in this order: logging backends, the best-state manager, the
    dataloaders, the model, the profiler, the EMA and the deadlock detector.

    Args:
        cfg (omegaconf.DictConfig): Solver configuration.
    """
    super().__init__()
    self.logger.info(f"Instantiating solver {self.__class__.__name__} for XP {self.xp.sig}")
    self.logger.info(f"All XP logs are stored in {self.xp.folder}")
    self.cfg = cfg
    self.device = cfg.device
    # declared only; presumably assigned by `build_model()` below
    self.model: nn.Module
    # keys kept from a checkpoint that does not come from the current XP (see `load_checkpoints`)
    self._continue_best_source_keys = ['best_state', 'fsdp_best_state']
    self._fsdp_modules: tp.List[fsdp.FSDP] = []
    self._ema_sources: nn.ModuleDict = nn.ModuleDict()
    self.ema: tp.Optional[optim.ModuleDictEMA] = None
    self.dataloaders: tp.Dict[str, torch.utils.data.DataLoader] = dict()
    self._log_updates = self.cfg.logging.get('log_updates', 10)
    if self.cfg.logging.log_tensorboard:
        self.init_tensorboard(**self.cfg.get('tensorboard'))
    # NOTE(review): the trailing `and self` looks redundant - confirm before removing.
    if self.cfg.logging.log_wandb and self:
        self.init_wandb(**self.cfg.get('wandb'))
    # keep a copy of the best performing state for stateful objects
    # used for evaluation and generation stages
    dtype_best: tp.Optional[torch.dtype] = None
    if self.cfg.fsdp.use:
        dtype_best = getattr(torch, self.cfg.fsdp.param_dtype)  # type: ignore
        assert isinstance(dtype_best, torch.dtype)
    elif self.cfg.autocast:
        dtype_best = getattr(torch, self.cfg.autocast_dtype)  # type: ignore
        assert isinstance(dtype_best, torch.dtype)
    self.best_state: BestStateDictManager = BestStateDictManager(dtype=dtype_best)
    # Hacky support for keeping a copy of the full best state in rank0.
    self.fsdp_best_state: tp.Dict[str, tp.Any] = {}
    self.register_stateful('best_state', 'fsdp_best_state')  # register best_state object to keep it in state_dict
    self._new_best_state: bool = False  # should save a new checkpoint
    # instantiate datasets and appropriate number of updates per epoch
    self.build_dataloaders()
    if self.cfg.execute_only is None:
        assert 'train' in self.dataloaders, "The train dataset split must be provided."
        assert 'valid' in self.dataloaders, "The valid dataset split must be provided."
    self.train_updates_per_epoch = len(self.dataloaders['train']) if 'train' in self.dataloaders else 0
    # an explicit `optim.updates_per_epoch` takes precedence over the loader length
    if self.cfg.optim.updates_per_epoch:
        self.train_updates_per_epoch = self.cfg.optim.updates_per_epoch
    self.total_updates = self.train_updates_per_epoch * self.cfg.optim.epochs
    # instantiate model & exponential moving average on the model
    self.build_model()
    self.logger.info("Model hash: %s", model_hash(self.model))
    assert 'model' in self.stateful.sources, \
        "Please register the model to stateful with self.register_stateful('model') in build_model."
    self.profiler = Profiler(self.model, **self.cfg.profiler)
    self.initialize_ema()
    self.register_stateful('ema')
    assert self.ema is None or 'ema' in self.stateful.sources, \
        "Please register the ema to stateful with self.register_stateful('ema') in build_model."
    self.deadlock_detect = DeadlockDetect(**self.cfg.deadlock)
    # basic statistics on the trained model (trainable parameters only, in millions)
    model_size = sum(p.numel() for p in self.model.parameters() if p.requires_grad) / 1e6
    # one copy of grad, one copy of momentum, one copy of denominator and model weights.
    # and 4 bytes for each float!
    mem_usage = model_size * 4 * 4 / 1000
    self.logger.info("Model size: %.2f M params", model_size)
    self.logger.info("Base memory usage, with model, grad and optim: %.2f GB", mem_usage)
@property
def autocast(self):
    """Build a `TorchAutocast` from the solver config, device and `autocast_dtype`."""
    enabled = self.cfg.autocast
    return TorchAutocast(enabled=enabled, device_type=self.device, dtype=self.autocast_dtype)
def _get_state_source(self, name) -> flashy.state.StateDictSource:
    """Look up the state source registered under `name` in `self.stateful.sources`."""
    registered_sources = self.stateful.sources
    return registered_sources[name]
@property
def best_metric_name(self) -> tp.Optional[str]:
    """Metric name used to identify the best state. This metric should be stored in the metrics
    used on the stage for best state identification (most likely, `valid`). If None, then
    no metric comparison is done and `update_best_state_from_stage` treats the latest
    state as the best one.

    This base implementation always returns None.
    """
    return None
def register_best_state(self, *args: str):
    """Register state sources in `BestStateDictManager` to keep their best states along with their
    latest states. The best state will be used at evaluation stages instead of the latest states.

    Shortcut around `BestStateDictManager.register` method. You can pass any number of
    attribute, included nested attributes and those will be included into the checkpoints
    and automatically restored when `BaseSolver.restore` is called.

    Args:
        *args (str): Names of state sources already registered in `self.stateful`.
    Raises:
        AssertionError: If a name has not been registered in `self.stateful` first.
    """
    for name in args:
        # Check membership before the lookup: `_get_state_source` indexes
        # `self.stateful.sources`, so an unknown name would otherwise fail with a
        # bare KeyError and this message would never be reported.
        assert name in self.stateful.sources, "Registered states in best should be registered in stateful first!"
        state_source = self._get_state_source(name)
        self.best_state.register(name, state_source)
def register_ema(self, *args: str):
    """Record solver attributes as sources for the exponential moving average.

    Each name is resolved with `getattr(self, name)` and stored in
    `self._ema_sources`, which `initialize_ema` later hands to the EMA builder.
    Must be called before the EMA object exists.

    Usage:
        self.register_ema('model')
    """
    assert self.ema is None, "Cannot register state source to already instantiated EMA."
    for source_name in args:
        source = getattr(self, source_name)
        self._ema_sources[source_name] = source
def wrap_with_fsdp(self, model: torch.nn.Module, *args, **kwargs):
    """Wrap `model` through `fsdp.wrap_with_fsdp` using `cfg.fsdp`.

    The result is remembered in `self._fsdp_modules` when it is an
    `fsdp.FSDP` instance, and returned in all cases.
    """
    wrapped = fsdp.wrap_with_fsdp(self.cfg.fsdp, model, *args, **kwargs)
    if isinstance(wrapped, fsdp.FSDP):
        self._fsdp_modules.append(wrapped)
    return wrapped
def update_best_state_from_stage(self, stage_name: str = 'valid'):
    """Refresh the best state from the pending metrics of `stage_name`.

    The current score is compared with the same metric in every past epoch of
    `self.history`; lower is better. When the current state is (or ties) the
    best, every state registered in `self.best_state` is updated from its source.
    Without a `best_metric_name`, the current state is always taken as best.
    """
    if self.best_metric_name is None:
        # when no best metric is defined, the last state is always the best
        self._new_best_state = True
        self.logger.info("Updating best state with current state.")
    else:
        metric = self.best_metric_name
        assert stage_name in self._pending_metrics, f"Metrics for stage {stage_name} not found."
        stage_metrics = self._pending_metrics[stage_name]
        assert metric in stage_metrics, \
            f"Best metric not found in {stage_name} metrics. Cannot register best state"
        current_score = stage_metrics[metric]
        past_scores = [epoch_metrics[stage_name][metric] for epoch_metrics in self.history]
        best_score = min(past_scores + [current_score])
        self._new_best_state = current_score == best_score
        if self._new_best_state:
            # `inf` stands in for "no previous score" on the very first epoch
            old_best = min(past_scores + [float('inf')])
            self.logger.info(
                f"New best state with {self.best_metric_name}={current_score:.3f} (was {old_best:.3f})")

    if not self._new_best_state:
        return

    if self.cfg.fsdp.use:
        # this will give an empty state dict on all ranks but the rank 0
        # which will have a copy in memory of the full model.
        with fsdp.switch_to_full_state_dict(self._fsdp_modules):
            for name in self.best_state.states.keys():
                self.best_state.update(name, self._get_state_source(name))
            # we save to a different dict.
            self.fsdp_best_state.update(self.best_state.state_dict())
    # We cannot efficiently load fsdp_best_state when using FSDP,
    # so we have to do a second pass, with the local shards.
    for name in self.best_state.states.keys():
        self.best_state.update(name, self._get_state_source(name))
def _load_new_state_dict(self, state_dict: dict) -> dict:
    """Load `state_dict` into the matching state sources.

    Returns:
        dict: For each name, a copy of the state the source held before loading.
    """
    previous = {}
    for source_name, incoming in state_dict.items():
        source = self._get_state_source(source_name)
        # snapshot first so that the caller can roll back later
        previous[source_name] = copy_state(source.state_dict())
        source.load_state_dict(incoming)
    return previous
@contextmanager
def swap_best_state(self):
    """Context manager loading the best states, and restoring the previous ones on exit."""
    names = ', '.join(self.best_state.state_dict().keys())
    self.logger.debug(f"Swapping to best state for: {names}")
    previous = self._load_new_state_dict(self.best_state.state_dict())
    try:
        yield
    finally:
        self.logger.debug("Swapping back from best to original state")
        for source_name, saved in previous.items():
            self._get_state_source(source_name).load_state_dict(saved)
@contextmanager
def swap_ema_state(self):
    """Context manager loading the EMA states, and restoring the previous ones on exit.

    Does nothing besides yielding when no EMA is configured.
    """
    if self.ema is None:
        yield
        return

    ema_state_dict = self.ema.state_dict()['state']
    names = ', '.join(ema_state_dict.keys())
    self.logger.debug(f"Swapping to EMA state for: {names}")
    previous = self._load_new_state_dict(ema_state_dict)
    try:
        yield
    finally:
        self.logger.debug("Swapping back from EMA state to original state")
        for source_name, saved in previous.items():
            self._get_state_source(source_name).load_state_dict(saved)
@property
def is_training(self):
    """True when the current stage is 'train'."""
    stage = self.current_stage
    return stage == 'train'
def log_model_summary(self, model: nn.Module):
    """Log the model itself, then its size in MB counting 4 bytes per parameter."""
    self.logger.info(model)
    total_params = sum(param.numel() for param in model.parameters())
    mb = total_params * 4 / 2 ** 20
    self.logger.info("Size: %.1f MB", mb)
@abstractmethod
def build_model(self):
    """Method to implement to initialize model.

    `__init__` asserts right after this call that 'model' is registered in
    `self.stateful`, so implementations are expected to call
    `self.register_stateful('model')`.
    """
    ...
def initialize_ema(self):
    """Initialize exponential moving average with the registered sources.

    The EMA object (or None) is whatever `get_ema` returns for the registered
    sources and the `optim.ema` config; the outcome is logged.
    """
    from .builders import get_ema

    self.ema = get_ema(self._ema_sources, self.cfg.optim.ema)
    if self.ema is not None:
        assert self.cfg.optim.ema.updates > 0
        self.logger.info(
            f'Initializing EMA on the model with decay = {self.ema.decay}'
            f' every {self.cfg.optim.ema.updates} updates'
        )
        return
    self.logger.info('No EMA on the model.')
@abstractmethod
def build_dataloaders(self):
    """Method to implement to initialize dataloaders.

    `__init__` checks `self.dataloaders` afterwards: unless `cfg.execute_only`
    is set, both the 'train' and 'valid' splits must be present.
    """
    ...
@abstractmethod
def show(self):
    """Method to log any information without running the job.

    Must be implemented by concrete solvers.
    """
    ...
@property
def log_updates(self):
    """Read-only view of `_log_updates`, set from the logging config in `__init__`."""
    value = self._log_updates
    return value
def checkpoint_path(self, **kwargs):
    """Path of the solver checkpoint inside the XP folder.

    `use_fsdp` defaults to `cfg.fsdp.use`; all keyword arguments are forwarded
    to `checkpoint.checkpoint_name`.
    """
    default_use_fsdp = self.cfg.fsdp.use
    kwargs.setdefault('use_fsdp', default_use_fsdp)
    filename = checkpoint.checkpoint_name(**kwargs)
    return self.folder / filename
def epoch_checkpoint_path(self, epoch: int, **kwargs):
    """Path of the checkpoint kept for `epoch` inside the XP folder.

    `use_fsdp` defaults to `cfg.fsdp.use`; the epoch is passed as a string,
    along with the keyword arguments, to `checkpoint.checkpoint_name`.
    """
    default_use_fsdp = self.cfg.fsdp.use
    kwargs.setdefault('use_fsdp', default_use_fsdp)
    filename = checkpoint.checkpoint_name(str(epoch), **kwargs)
    return self.folder / filename
def checkpoint_path_with_name(self, name: str, **kwargs):
    """Path of the checkpoint called `name` inside the XP folder.

    `use_fsdp` defaults to `cfg.fsdp.use`; `name` and the keyword arguments
    are forwarded to `checkpoint.checkpoint_name`.
    """
    default_use_fsdp = self.cfg.fsdp.use
    kwargs.setdefault('use_fsdp', default_use_fsdp)
    filename = checkpoint.checkpoint_name(name=name, **kwargs)
    return self.folder / filename
def save_checkpoints(self):
    """Save checkpoint, optionally keeping a copy for a given epoch.

    Only rank zero saves, unless FSDP is used in which case every rank
    goes through the saving logic with `is_sharded` set.
    """
    is_sharded = self.cfg.fsdp.use
    if not (flashy.distrib.is_rank_zero() or is_sharded):
        return
    self.logger.info("Model hash: %s", model_hash(self.model))
    state = self.state_dict()
    epoch = self.epoch - 1  # pushing metrics will increase the epoch in Flashy, so we do -1 here

    # save minimal state_dict as new checkpoint every X epoch
    ckpt_cfg = self.cfg.checkpoint
    if ckpt_cfg.save_every and epoch % ckpt_cfg.save_every == 0:
        kept_names = ckpt_cfg.keep_every_states
        if kept_names is not None and len(kept_names) > 0:
            # restrict the per-epoch copy to the configured state names
            minimal_state = {
                name: source for name, source in state.items()
                if name in kept_names
            }
        else:
            minimal_state = state
        checkpoint.save_checkpoint(minimal_state, self.epoch_checkpoint_path(epoch), is_sharded)

    # save checkpoint as latest checkpoint
    if ckpt_cfg.save_last:
        checkpoint.save_checkpoint(state, self.checkpoint_path(), is_sharded)

    # flush any stale checkpoint to reduce disk footprint
    checkpoint.flush_stale_checkpoints(self.checkpoint_path())
def load_from_pretrained(self, name: str) -> dict:
    """Hook for solvers able to load pretrained models; unsupported by default.

    Raises:
        NotImplementedError: Always, in this base implementation.
    """
    message = "Solver does not provide a way to load pretrained models."
    raise NotImplementedError(message)
def load_checkpoints(self, load_best: bool = False, ignore_state_keys: tp.List[str] = []) -> tp.Optional[dict]:
    """Load last checkpoint or the one specified in continue_from.

    The source is chosen in this order: an existing checkpoint of the current XP,
    a checkpoint resolved from `cfg.continue_from`, or a pretrained model when
    `cfg.continue_from` starts with '//pretrained/'. Whenever the source is not
    the current XP, only the best state is kept and `load_best` is forced to True.

    Args:
        load_best (bool): Whether to load from best state dict or not.
            Best state dict is always used when not loading the current xp.
        ignore_state_keys (list of str): List of sources to ignore when loading the state, e.g. `optimizer`.
            The default list is only read, never mutated.
    Returns:
        state (dict, optional): The loaded state dictionary.
    Raises:
        RuntimeError: If `cfg.continue_from` cannot be resolved, or if the workers
            ended up with different epoch numbers after loading.
        AssertionError: If the best state is requested but absent from the checkpoint.
    """
    # load checkpoints from xp folder or cfg.continue_from
    is_sharded = self.cfg.fsdp.use
    load_from_path: tp.Optional[Path] = None
    checkpoint_source: tp.Optional[checkpoint.CheckpointSource] = None

    if load_best:
        self.logger.info("Trying to load state_dict from best state.")

    state: tp.Optional[dict] = None
    rank0_checkpoint_path = self.checkpoint_path(use_fsdp=False)
    current_checkpoint_path = self.checkpoint_path()
    _pretrained_prefix = '//pretrained/'
    continue_pretrained = (self.cfg.continue_from or '').startswith(_pretrained_prefix)
    if rank0_checkpoint_path.exists():
        self.logger.info(f"Loading existing checkpoint: {current_checkpoint_path}")
        load_from_path = current_checkpoint_path
        checkpoint.check_sharded_checkpoint(current_checkpoint_path, rank0_checkpoint_path)
        checkpoint_source = checkpoint.CheckpointSource.CURRENT_XP
    elif self.cfg.continue_from and not continue_pretrained:
        self.logger.info(f"Continuing from provided checkpoint: {self.cfg.continue_from}")
        # we're always continuing from consolidated checkpoints: self.cfg.use_fsdp and not continue_best
        load_from_path = checkpoint.resolve_checkpoint_path(self.cfg.continue_from, use_fsdp=False)
        if load_from_path is None:
            self.logger.error('Could not resolve the continue_from checkpoint %s', self.cfg.continue_from)
            raise RuntimeError(f'Could not resolve continue_from checkpoint {self.cfg.continue_from}')
        checkpoint_source = checkpoint.CheckpointSource.OTHER

    if load_from_path is not None:
        state = checkpoint.load_checkpoint(load_from_path, is_sharded)
    elif continue_pretrained:
        self.logger.info("Loading a pretrained model. Ignoring 'load_best' and 'ignore_state_keys' params.")
        state = self.load_from_pretrained(self.cfg.continue_from[len(_pretrained_prefix):])
        checkpoint_source = checkpoint.CheckpointSource.PRETRAINED
        load_best = True

    # checkpoints are not from the current xp, we only retrieve the best state
    if checkpoint_source is not None and checkpoint_source != checkpoint.CheckpointSource.CURRENT_XP:
        assert state is not None
        self.logger.info("Checkpoint source is not the current xp: Load state_dict from best state.")
        load_best = True
        state = {key: state[key] for key in self._continue_best_source_keys if key in state}
        # loaded checkpoints are FSDP checkpoints: we're reading the best state
        # from FSDP and we drop the regular best_state
        if 'fsdp_best_state' in state and state['fsdp_best_state']:
            state.pop('best_state', None)
            self.logger.info("... Loaded checkpoint has FSDP best state")
        # FSDP is enabled in the solver, if the loaded checkpoints do not have FSDP support
        # then we're initializing FSDP best state with the regular best state
        elif self.cfg.fsdp.use:
            if 'fsdp_best_state' not in state or not state['fsdp_best_state']:
                # we swap non-FSDP checkpoints best_state to FSDP-compatible best state
                state['fsdp_best_state'] = state.pop('best_state')
                self.logger.info("... Loaded checkpoint does not have FSDP best state. Use regular best state")

    if state is not None:
        if load_best:
            self.logger.info("Ignoring keys when loading best %r", ignore_state_keys)
            for key in set(ignore_state_keys):
                if key in state:
                    state.pop(key)
            has_best_state = 'best_state' in state or 'fsdp_best_state' in state
            # The message must be one string: a tuple of fragments would be
            # reported as a tuple repr rather than as a readable sentence.
            assert has_best_state, ("Trying to load best state but neither 'best_state'"
                                    " or 'fsdp_best_state' found in checkpoints.")
        self.load_state_dict(state)

    # for FSDP, let's make extra sure nothing bad happened with out of sync
    # checkpoints across workers.
    epoch = float(self.epoch)
    avg_epoch = flashy.distrib.average_metrics({'epoch': epoch})['epoch']
    if avg_epoch != epoch:
        raise RuntimeError(
            f"Inconsistent loading of checkpoints happened, our epoch is {epoch} "
            f"but average of epochs is {avg_epoch}, at least one gpu must have a "
            "different epoch number.")

    # on load_best, properly reinitialize state_dict, best states and ema
    # otherwise we load from the current xp and don't alter anything
    if load_best:
        self.logger.info("Loading state_dict from best state.")
        if not self.cfg.fsdp.use and self.fsdp_best_state:
            # loading from an FSDP checkpoint but with FSDP deactivated
            self.logger.info("... Loading from FSDP best state dict.")
            self.best_state.load_state_dict(self.fsdp_best_state)

        # if load_best, we permanently override the regular state_dict with the best state
        if self.cfg.fsdp.use:
            self.logger.info("FSDP is used, loading from FSDP best state.")
            with fsdp.switch_to_full_state_dict(self._fsdp_modules):
                # this might be really fragile but okay for now.
                self.load_state_dict(self.fsdp_best_state)
        else:
            # we permanently swap the stateful objects to their best state
            self._load_new_state_dict(self.best_state.state_dict())

        # the EMA modules should also be instantiated with best state.
        # the easiest way to do so is to reinitialize a new EMA with best state loaded.
        if self.ema is not None:
            self.logger.info("Re-initializing EMA from best state")
            self.initialize_ema()

        if self.cfg.fsdp.use:
            self.logger.info("Re-initializing best state after using FSDP best state.")
            for name in self.best_state.states.keys():
                state_source = self._get_state_source(name)
                self.best_state.update(name, state_source)

    return state
def restore(self, load_best: bool = False, replay_metrics: bool = False,
            ignore_state_keys: tp.List[str] = []) -> bool:
    """Restore the status of a solver for a given xp.

    Args:
        load_best (bool): if `True`, load the best state from the checkpoint.
        replay_metrics (bool): if `True`, logs all the metrics from past epochs.
        ignore_state_keys (list of str): list of sources to ignore when loading the state, e.g. `optimizer`.
    Returns:
        bool: Whether `load_checkpoints` returned a state.
    """
    self.logger.info("Restoring weights and history.")
    restored_state = self.load_checkpoints(load_best, ignore_state_keys)
    self.logger.info("Model hash: %s", model_hash(self.model))

    if replay_metrics and len(self.history) > 0:
        self.logger.info("Replaying past metrics...")
        # epochs are reported 1-based
        for epoch_number, stages in enumerate(self.history, start=1):
            for stage_name, metrics in stages.items():
                # We manually log the metrics summary to the result logger
                # as we don't want to add them to the pending metrics
                formatter = self.get_formatter(stage_name)
                self.result_logger._log_summary(stage_name, metrics, step=epoch_number, step_name='epoch',
                                                formatter=formatter)
    return restored_state is not None
def commit(self, save_checkpoints: bool = True):
    """Commit metrics to dora and save checkpoints at the end of an epoch.

    Args:
        save_checkpoints (bool): Whether to call `save_checkpoints()` after
            the pending metrics were appended to the history.
    """
    # we override commit to introduce more complex checkpoint saving behaviors
    self.history.append(self._pending_metrics)  # This will increase self.epoch
    if save_checkpoints:
        self.save_checkpoints()
    self._start_epoch()
    if not flashy.distrib.is_rank_zero():
        return
    self.xp.link.update_history(self.history)
def run_epoch(self):
    """Run a single epoch with all stages.

    Metrics for a given stage are stored in _pending_metrics and committed by the solver afterwards.
    Children solvers can extend this method with custom behavior, e.g.:

        def run_epoch(self):
            ... # custom code
            super().run_epoch()
            ... # custom code
    """
    self.run_stage('train', self.train)
    with torch.no_grad():
        with self.swap_ema_state():
            self.run_stage('valid', self.valid)
            # the best state is updated with EMA states if available
            self.update_best_state_from_stage('valid')
        with self.swap_best_state():
            wants_evaluate = self.should_run_stage('evaluate')
            if wants_evaluate:
                self.run_stage('evaluate', self.evaluate)
            wants_generate = self.should_run_stage('generate')
            if wants_generate:
                generate_fn = with_rank_rng()(self.generate)
                self.run_stage('generate', generate_fn)
def run(self):
    """Training loop.

    Restores the solver (replaying past metrics), logs the hyper-parameters,
    then runs and commits epochs until `should_stop_training()` says to stop.
    """
    assert len(self.state_dict()) > 0
    self.restore(replay_metrics=True)  # load checkpoint and replay history
    self.log_hyperparams(dict_from_config(self.cfg))
    last_epoch = self.cfg.optim.epochs
    for _ in range(self.epoch, last_epoch + 1):
        if self.should_stop_training():
            return
        self.run_epoch()
        # Commit will send the metrics to Dora and save checkpoints by default.
        self.commit()
def should_stop_training(self) -> bool:
    """Tell whether the current epoch is past `cfg.optim.epochs`."""
    max_epochs = self.cfg.optim.epochs
    return self.epoch > max_epochs
def should_run_stage(self, stage_name) -> bool:
    """Check whether we want to run the specified stages.

    A stage runs on the last epoch, and on every epoch that is a multiple of
    its `every` setting when that setting is present and non-zero.

    Args:
        stage_name: Key of the stage section in the config (e.g. 'evaluate').
    Returns:
        bool: True when the stage should run at the current epoch.
    """
    stage_every = self.cfg[stage_name].get('every', None)
    is_last_epoch = self.epoch == self.cfg.optim.epochs
    # `bool(...)` keeps the annotated return type: with a missing or zero `every`
    # the bare `and` expression would evaluate to None or 0 rather than False.
    is_epoch_every = bool(stage_every) and self.epoch % stage_every == 0
    return is_last_epoch or is_epoch_every
@abstractmethod
def run_step(self, idx: int, batch: tp.Any, metrics: dict):
    """Perform one training or valid step on a given batch.

    Called by `common_train_valid` for each batch with a fresh, empty `metrics`
    dict; the returned value is used as the metrics of that step.
    """
    ...
def common_train_valid(self, dataset_split: str, **kwargs: tp.Any):
    """Common logic for train and valid stages.

    Iterates over the dataloader of `dataset_split`, calls `run_step` on each batch,
    steps the EMA when training, and averages the step metrics over the epoch.

    Args:
        dataset_split (str): Key of the dataloader to iterate over in `self.dataloaders`.
        **kwargs: Accepted but not used in this implementation.
    Returns:
        dict: Epoch metrics after `flashy.distrib.average_metrics`.
    """
    # train mode only for the 'train' stage, eval mode otherwise
    self.model.train(self.is_training)
    loader = self.dataloaders[dataset_split]
    # get a different order for distributed training, otherwise this will get ignored
    if flashy.distrib.world_size() > 1 \
            and isinstance(loader.sampler, torch.utils.data.distributed.DistributedSampler):
        loader.sampler.set_epoch(self.epoch)
    updates_per_epoch = self.train_updates_per_epoch if self.is_training else len(loader)
    if self.cfg.benchmark_no_load:
        self.logger.warning("Fake loading for benchmarking: re-using first batch")
        batch = next(iter(loader))
        loader = [batch] * updates_per_epoch  # type: ignore
    lp = self.log_progress(self.current_stage, loader, total=updates_per_epoch, updates=self.log_updates)
    average = flashy.averager()  # epoch wise average
    instant_average = flashy.averager()  # average between two logging
    metrics: dict = {}
    with self.profiler, self.deadlock_detect:  # profiler will only run for the first 20 updates.
        for idx, batch in enumerate(lp):
            self.deadlock_detect.update('batch')
            # the loader may hold more batches than the configured number of updates
            if idx >= updates_per_epoch:
                break
            metrics = {}
            metrics = self.run_step(idx, batch, metrics)
            self.deadlock_detect.update('step')
            # run EMA step
            if self.ema is not None and self.is_training and (idx + 1) % self.cfg.optim.ema.updates == 0:
                self.logger.debug("EMA model step")
                self.ema.step()
            self.deadlock_detect.update('ema')
            self.profiler.step()
            instant_metrics = instant_average(metrics)
            if lp.update(**instant_metrics):
                instant_average = flashy.averager()  # reset averager between two logging
            metrics = average(metrics)  # epoch wise average
            self.deadlock_detect.update('end_batch')
    metrics = flashy.distrib.average_metrics(metrics, updates_per_epoch)
    return metrics
def train(self):
    """Train stage: run `common_train_valid` on the 'train' split."""
    metrics = self.common_train_valid('train')
    return metrics
def valid(self):
    """Valid stage: run `common_train_valid` on the 'valid' split."""
    metrics = self.common_train_valid('valid')
    return metrics
@abstractmethod
def evaluate(self):
    """Evaluate stage.

    `run_epoch` and `run_one_stage` run it under `torch.no_grad()` with the
    best state swapped in.
    """
    ...
@abstractmethod
def generate(self):
    """Generate stage.

    `run_epoch` and `run_one_stage` wrap it with `with_rank_rng()` and run it
    under `torch.no_grad()` with the best state swapped in.
    """
    ...
def run_one_stage(self, stage_name: str):
    """Run only the specified stage.

    This method is useful to only generate samples from a trained experiment
    or rerun the validation or evaluation stages.

    Args:
        stage_name (str): One of 'generate', 'evaluate' or 'valid'.
    Raises:
        ValueError: If `stage_name` is not one of the supported stages.
    """
    fn = {
        'generate': with_rank_rng()(self.generate),
        'evaluate': self.evaluate,
        'valid': self.valid,
    }
    if stage_name not in fn:
        raise ValueError(f'Trying to run stage {stage_name} is not supported.')
    stage_fn = fn[stage_name]
    assert len(self.state_dict()) > 0
    self._start_epoch()
    with torch.no_grad():
        with self.swap_best_state():
            self.run_stage(stage_name, stage_fn)
    # metrics are only committed (without checkpointing) when not executing in place
    if self.cfg.execute_inplace:
        return
    self.commit(save_checkpoints=False)
@staticmethod
def get_eval_solver_from_sig(sig: str, dtype: tp.Optional[str] = None,
                             device: tp.Optional[str] = None, autocast: bool = True,
                             batch_size: tp.Optional[int] = None,
                             override_cfg: tp.Optional[tp.Union[dict, omegaconf.DictConfig]] = None,
                             **kwargs):
    """Mostly a convenience function around audiocraft.train.get_solver_from_sig,
    populating all the proper param, deactivating EMA, FSDP, loading the best state,
    basically all you need to get a solver ready to "play" with in single GPU mode
    and with minimal memory overhead.

    Args:
        sig (str): signature to load.
        dtype (str or None): potential dtype, as a string, i.e. 'float16'.
        device (str or None): potential device, as a string, i.e. 'cuda'.
        autocast (bool): value written to the `autocast` config entry.
        batch_size (int or None): when given, written to `dataset.batch_size`.
        override_cfg (dict or omegaconf.DictConfig or None): extra config overrides.
            They are merged before the overrides built by this function.
        **kwargs: forwarded to `train.get_solver_from_sig`.
    Returns:
        The loaded solver, with its model switched to eval mode.
    """
    from audiocraft import train

    our_override_cfg: tp.Dict[str, tp.Any] = {
        'optim': {'ema': {'use': False}},
        'autocast': autocast,
    }
    for key, value in (('dtype', dtype), ('device', device)):
        if value is not None:
            our_override_cfg[key] = value
    if batch_size is not None:
        our_override_cfg['dataset'] = {'batch_size': batch_size}

    user_cfg = {} if override_cfg is None else override_cfg
    merged_cfg = omegaconf.OmegaConf.merge(
        omegaconf.DictConfig(user_cfg), omegaconf.DictConfig(our_override_cfg))  # type: ignore
    solver = train.get_solver_from_sig(
        sig, override_cfg=merged_cfg,
        load_best=True, disable_fsdp=True,
        ignore_state_keys=['optimizer', 'ema'], **kwargs)
    solver.model.eval()
    return solver
================================================
FILE: audiocraft/solvers/builders.py
================================================
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
"""
All the functions to build the relevant solvers and used objects
from the Hydra config.
"""
from enum import Enum
import logging
import typing as tp
import dora
import flashy
import omegaconf
import torch
from torch import nn
from torch.optim import Optimizer
# LRScheduler was renamed in some torch versions
try:
from torch.optim.lr_scheduler import LRScheduler # type: ignore
except ImportError:
from torch.optim.lr_scheduler import _LRScheduler as LRScheduler
from .base import StandardSolver
from .. import adversarial, data, losses, metrics, optim
from ..utils.utils import dict_from_config, get_loader
logger = logging.getLogger(__name__)
class DatasetType(Enum):
    """Kind of dataset that `get_audio_datasets` should instantiate."""
    AUDIO = "audio"  # built through InfoAudioDataset
    MUSIC = "music"  # built through MusicDataset
    SOUND = "sound"  # built through SoundDataset
def get_solver(cfg: omegaconf.DictConfig) -> StandardSolver:
    """Instantiate the solver selected by `cfg.solver`.

    Raises:
        KeyError: If `cfg.solver` is not a known solver name.
    """
    # Solver modules are imported lazily, at call time.
    from .audiogen import AudioGenSolver
    from .compression import CompressionSolver
    from .musicgen import MusicGenSolver
    from .diffusion import DiffusionSolver

    solver_classes = {
        'compression': CompressionSolver,
        'musicgen': MusicGenSolver,
        'audiogen': AudioGenSolver,
        'lm': MusicGenSolver,  # backward compatibility
        'diffusion': DiffusionSolver,
        'sound_lm': AudioGenSolver,  # backward compatibility
    }
    solver_cls = solver_classes[cfg.solver]
    return solver_cls(cfg)  # type: ignore
def get_optim_parameter_groups(model: nn.Module):
    """Split the model parameters into optimizer parameter groups.

    Every submodule exposing a `make_optim_group` method contributes the group
    returned by that method. All remaining parameters are gathered in a
    default group placed first in the returned list.

    Args:
        model (nn.Module): torch model
    Returns:
        List of parameter groups
    """
    claimed: tp.Set[nn.parameter.Parameter] = set()
    custom_groups = []
    for module in model.modules():
        if not hasattr(module, 'make_optim_group'):
            continue
        group = module.make_optim_group()
        group_params = set(group['params'])
        # A parameter must not be claimed by two different groups.
        assert group_params.isdisjoint(claimed)
        claimed.update(group_params)
        custom_groups.append(group)
    leftover = [param for param in model.parameters() if param not in claimed]
    return [{'params': leftover}, *custom_groups]
def get_optimizer(params: tp.Union[nn.Module, tp.Iterable[torch.Tensor]], cfg: omegaconf.DictConfig) -> Optimizer:
    """Build torch optimizer from config and set of parameters.
    Supported optimizers (value of `cfg.optimizer`): 'adam', 'adamw', 'dadam'.

    Args:
        params (nn.Module or iterable of torch.Tensor): Parameters to optimize.
            When a module is given, its parameter groups are built with
            `get_optim_parameter_groups`.
        cfg (DictConfig): Optimization-related configuration.
    Returns:
        torch.optim.Optimizer.
    Raises:
        KeyError: If `cfg` has no 'optimizer' entry.
        ValueError: If `cfg.optimizer` is not a supported optimizer name.
    """
    if 'optimizer' not in cfg:
        if getattr(cfg, 'optim', None) is not None:
            # The caller most likely passed the root config instead of cfg.optim.
            raise KeyError("Optimizer not found in config. Try instantiating optimizer from cfg.optim?")
        raise KeyError("Optimizer not found in config.")

    parameters = get_optim_parameter_groups(params) if isinstance(params, nn.Module) else params
    optimizer: torch.optim.Optimizer
    # Every supported optimizer reads its extra keyword arguments from cfg.adam.
    if cfg.optimizer == 'adam':
        optimizer = torch.optim.Adam(parameters, lr=cfg.lr, **cfg.adam)
    elif cfg.optimizer == 'adamw':
        optimizer = torch.optim.AdamW(parameters, lr=cfg.lr, **cfg.adam)
    elif cfg.optimizer == 'dadam':
        optimizer = optim.DAdaptAdam(parameters, lr=cfg.lr, **cfg.adam)
    else:
        # Report the offending optimizer name (not the LR scheduler, which may
        # not even be defined on this config).
        raise ValueError(f"Unsupported Optimizer: {cfg.optimizer}")
    return optimizer
def get_lr_scheduler(optimizer: torch.optim.Optimizer,
                     cfg: omegaconf.DictConfig,
                     total_updates: int) -> tp.Optional[LRScheduler]:
    """Build the learning rate scheduler selected by `cfg.lr_scheduler`.
    Supported values: 'step', 'exponential', 'cosine', 'polynomial_decay',
    'inverse_sqrt', 'linear_warmup', or None for no scheduler.

    Args:
        optimizer (torch.optim.Optimizer): Optimizer.
        cfg (DictConfig): Schedule-related configuration.
        total_updates (int): Total number of updates, used as `total_steps`
            by the cosine and polynomial decay schedulers.
    Returns:
        LRScheduler or None: The scheduler, or None when `cfg.lr_scheduler` is None.
    Raises:
        KeyError: If `cfg` has no 'lr_scheduler' entry.
        ValueError: If `cfg.lr_scheduler` is not a supported name.
    """
    if 'lr_scheduler' not in cfg:
        raise KeyError("LR Scheduler not found in config")

    if cfg.lr_scheduler == 'step':
        return torch.optim.lr_scheduler.StepLR(optimizer, **cfg.step)
    if cfg.lr_scheduler == 'exponential':
        return torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=cfg.exponential)
    if cfg.lr_scheduler == 'cosine':
        options = dict_from_config(cfg.cosine)
        warmup = options.pop('warmup')
        return optim.CosineLRScheduler(
            optimizer, warmup_steps=warmup, total_steps=total_updates, **options)
    if cfg.lr_scheduler == 'polynomial_decay':
        options = dict_from_config(cfg.polynomial_decay)
        warmup = options.pop('warmup')
        return optim.PolynomialDecayLRScheduler(
            optimizer, warmup_steps=warmup, total_steps=total_updates, **options)
    if cfg.lr_scheduler == 'inverse_sqrt':
        options = dict_from_config(cfg.inverse_sqrt)
        warmup = options.pop('warmup')
        return optim.InverseSquareRootLRScheduler(optimizer, warmup_steps=warmup, **options)
    if cfg.lr_scheduler == 'linear_warmup':
        options = dict_from_config(cfg.linear_warmup)
        warmup = options.pop('warmup')
        return optim.LinearWarmupLRScheduler(optimizer, warmup_steps=warmup, **options)
    if cfg.lr_scheduler is not None:
        raise ValueError(f"Unsupported LR Scheduler: {cfg.lr_scheduler}")
    # Explicit None in the config: no scheduler requested.
    return None
def get_ema(module_dict: nn.ModuleDict, cfg: omegaconf.DictConfig) -> tp.Optional[optim.ModuleDictEMA]:
    """Initialize Exponential Moving Average.

    Args:
        module_dict (nn.ModuleDict): ModuleDict for which to compute the EMA.
        cfg (omegaconf.DictConfig): Optim EMA configuration. The 'use', 'decay'
            and 'device' entries are read, each of them being optional.
    Returns:
        optim.ModuleDictEMA or None: EMA version of the ModuleDict, or None
            when the 'use' entry is missing or falsy.
    Raises:
        ValueError: If EMA is enabled but `module_dict` is empty.
    """
    settings: tp.Dict[str, tp.Any] = dict(cfg)
    if not settings.get('use', False):
        return None
    if len(module_dict) == 0:
        raise ValueError("Trying to build EMA but an empty module_dict source is provided!")
    return optim.ModuleDictEMA(
        module_dict, decay=settings.get('decay', None), device=settings.get('device', None))
def get_loss(loss_name: str, cfg: omegaconf.DictConfig):
    """Instantiate the loss registered under `loss_name`.

    The constructor keyword arguments are read from the config attribute
    carrying the same name as the loss.

    Raises:
        KeyError: If `loss_name` is not a known loss.
    """
    loss_classes = {
        'l1': torch.nn.L1Loss,
        'l2': torch.nn.MSELoss,
        'mel': losses.MelSpectrogramL1Loss,
        'mrstft': losses.MRSTFTLoss,
        'msspec': losses.MultiScaleMelSpectrogramLoss,
        'sisnr': losses.SISNR,
    }
    loss_cls = loss_classes[loss_name]
    loss_kwargs = dict(getattr(cfg, loss_name))
    return loss_cls(**loss_kwargs)
def get_balancer(loss_weights: tp.Dict[str, float], cfg: omegaconf.DictConfig) -> losses.Balancer:
    """Instantiate the loss balancer for the given weights.

    Every entry of `cfg` is forwarded to `losses.Balancer` as a keyword argument.
    """
    balancer_kwargs: tp.Dict[str, tp.Any] = dict_from_config(cfg)
    balancer = losses.Balancer(loss_weights, **balancer_kwargs)
    return balancer
def get_adversary(name: str, cfg: omegaconf.DictConfig) -> nn.Module:
    """Instantiate the adversary (discriminator) registered under `name`.

    The constructor keyword arguments are read from the config attribute
    carrying the same name as the adversary.

    Raises:
        KeyError: If `name` is not one of 'msd', 'mpd' or 'msstftd'.
    """
    adversary_classes = {
        'msd': adversarial.MultiScaleDiscriminator,
        'mpd': adversarial.MultiPeriodDiscriminator,
        'msstftd': adversarial.MultiScaleSTFTDiscriminator,
    }
    adversary_cls = adversary_classes[name]
    adversary_kwargs: tp.Dict[str, tp.Any] = dict(getattr(cfg, name))
    return adversary_cls(**adversary_kwargs)
def get_adversarial_losses(cfg) -> nn.ModuleDict:
    """Build one `AdversarialLoss` per adversary listed in `cfg.adversarial`.

    Each adversary is moved to `cfg.device` and gets its own optimizer built
    from `cfg.optim`. The generator/real/fake criteria and the optional
    feature matching loss are shared by all adversaries.

    Returns:
        nn.ModuleDict: Adversarial losses keyed by adversary name.
    """
    device = cfg.device
    adv_cfg = cfg.adversarial
    adversary_names = adv_cfg.get('adversaries', [])
    adv_loss_name = adv_cfg['adv_loss']
    feat_loss_name = adv_cfg.get('feat_loss')
    normalize = adv_cfg.get('normalize', True)

    feat_loss: tp.Optional[adversarial.FeatureMatchingLoss] = None
    if feat_loss_name:
        assert feat_loss_name in ['l1', 'l2'], f"Feature loss only support L1 or L2 but {feat_loss_name} found."
        feat_criterion = get_loss(feat_loss_name, cfg)
        feat_loss = adversarial.FeatureMatchingLoss(feat_criterion, normalize)

    gen_criterion = adversarial.get_adv_criterion(adv_loss_name)
    real_criterion = adversarial.get_real_criterion(adv_loss_name)
    fake_criterion = adversarial.get_fake_criterion(adv_loss_name)

    adv_losses = nn.ModuleDict()
    for adv_name in adversary_names:
        adversary = get_adversary(adv_name, cfg).to(device)
        adv_optimizer = get_optimizer(adversary.parameters(), cfg.optim)
        adv_losses[adv_name] = adversarial.AdversarialLoss(
            adversary,
            adv_optimizer,
            loss=gen_criterion,
            loss_real=real_criterion,
            loss_fake=fake_criterion,
            loss_feat=feat_loss,
            normalize=normalize
        )
    return adv_losses
def get_visqol(cfg: omegaconf.DictConfig) -> metrics.ViSQOL:
    """Instantiate the ViSQOL metric, forwarding every config entry as a keyword argument."""
    visqol_kwargs = dict_from_config(cfg)
    visqol = metrics.ViSQOL(**visqol_kwargs)
    return visqol
def get_fad(cfg: omegaconf.DictConfig) -> metrics.FrechetAudioDistanceMetric:
    """Instantiate the Frechet Audio Distance metric from `cfg.tf`.

    The `log_folder` argument is always set to the folder of the current Dora XP.
    """
    fad_kwargs = dict_from_config(cfg.tf)
    fad_kwargs['log_folder'] = dora.get_xp().folder
    return metrics.FrechetAudioDistanceMetric(**fad_kwargs)
def get_kldiv(cfg: omegaconf.DictConfig) -> metrics.KLDivergenceMetric:
    """Instantiate the KL-Divergence metric selected by `cfg.model`.

    The constructor keyword arguments are read from the config entry carrying
    the same name as the model.

    Raises:
        KeyError: If `cfg.model` is not a known model name (only 'passt' is).
    """
    available = {
        'passt': metrics.PasstKLDivergenceMetric,
    }
    metric_cls = available[cfg.model]
    metric_kwargs = dict_from_config(cfg.get(cfg.model))
    return metric_cls(**metric_kwargs)
def get_text_consistency(cfg: omegaconf.DictConfig) -> metrics.TextConsistencyMetric:
    """Instantiate the Text Consistency metric selected by `cfg.model`.

    The constructor keyword arguments are read from the config entry carrying
    the same name as the model.

    Raises:
        KeyError: If `cfg.model` is not a known model name (only 'clap' is).
    """
    available = {
        'clap': metrics.CLAPTextConsistencyMetric
    }
    metric_cls = available[cfg.model]
    metric_kwargs = dict_from_config(cfg.get(cfg.model))
    return metric_cls(**metric_kwargs)
def get_chroma_cosine_similarity(cfg: omegaconf.DictConfig) -> metrics.ChromaCosineSimilarityMetric:
    """Instantiate the Chroma Cosine Similarity metric.

    Only `cfg.model == 'chroma_base'` is accepted. The constructor keyword
    arguments are read from the config entry carrying that name.
    """
    assert cfg.model == 'chroma_base', "Only support 'chroma_base' method for chroma cosine similarity metric"
    metric_kwargs = dict_from_config(cfg.get(cfg.model))
    metric = metrics.ChromaCosineSimilarityMetric(**metric_kwargs)
    return metric
def get_audio_datasets(cfg: omegaconf.DictConfig,
                       dataset_type: DatasetType = DatasetType.AUDIO) -> tp.Dict[str, torch.utils.data.DataLoader]:
    """Build one dataloader per data split declared in `cfg.datasource`.

    Only datasource entries whose value is a string are treated as splits.
    When `cfg.execute_only` is set, only that split is built.

    Args:
        cfg (omegaconf.DictConfig): Configuration.
        dataset_type: The type of dataset to create.
    Returns:
        dict[str, torch.utils.data.DataLoader]: Map of dataloader for each data split.
    Raises:
        ValueError: If `dataset_type` is not supported.
    """
    loaders_by_split: dict = {}

    sample_rate = cfg.sample_rate
    channels = cfg.channels
    seed = cfg.seed
    max_sample_rate = cfg.datasource.max_sample_rate
    max_channels = cfg.datasource.max_channels

    assert cfg.dataset is not None, "Could not find dataset definition in config"

    # Shared dataset options, with the per-split sections taken out of them.
    shared_cfg = dict_from_config(cfg.dataset)
    per_split_cfg: dict = {
        stage: shared_cfg.pop(stage) for stage in ('train', 'valid', 'evaluate', 'generate')
    }
    only_stage = cfg.get('execute_only', None)

    for split, path in cfg.datasource.items():
        # Skip entries that are not paths, and splits excluded by `execute_only`.
        if not isinstance(path, str) or (only_stage is not None and split != only_stage):
            continue
        logger.info(f"Loading audio data split {split}: {str(path)}")
        assert (
            cfg.sample_rate <= max_sample_rate
        ), f"Expecting a max sample rate of {max_sample_rate} for datasource but {sample_rate} found."
        assert (
            cfg.channels <= max_channels
        ), f"Expecting a max number of channels of {max_channels} for datasource but {channels} found."

        # Split specific options override the shared dataset options.
        options = dict(shared_cfg)
        options.update(per_split_cfg[split].items())
        options['sample_rate'] = sample_rate
        options['channels'] = channels

        if options.get('permutation_on_files') and cfg.optim.updates_per_epoch:
            options['num_samples'] = (
                flashy.distrib.world_size() * cfg.dataset.batch_size * cfg.optim.updates_per_epoch)

        # num_samples and shuffle stay in the dataset options; the three entries
        # popped below are removed from them.
        num_samples = options['num_samples']
        shuffle = options['shuffle']
        return_info = options.pop('return_info')
        batch_size = options.pop('batch_size', None)
        num_workers = options.pop('num_workers')

        if dataset_type == DatasetType.MUSIC:
            dataset = data.music_dataset.MusicDataset.from_meta(path, **options)
        elif dataset_type == DatasetType.SOUND:
            dataset = data.sound_dataset.SoundDataset.from_meta(path, **options)
        elif dataset_type == DatasetType.AUDIO:
            dataset = data.info_audio_dataset.InfoAudioDataset.from_meta(path, return_info=return_info, **options)
        else:
            raise ValueError(f"Dataset type is unsupported: {dataset_type}")

        loaders_by_split[split] = get_loader(
            dataset,
            num_samples,
            batch_size=batch_size,
            num_workers=num_workers,
            seed=seed,
            collate_fn=dataset.collater if return_info else None,
            shuffle=shuffle,
        )

    return loaders_by_split
================================================
FILE: audiocraft/solvers/compression.py
================================================
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import logging
import multiprocessing
from pathlib import Path
import typing as tp
import flashy
import omegaconf
import torch
from torch import nn
from . import base, builders
from .. import models, quantization
from ..utils import checkpoint
from ..utils.samples.manager import SampleManager
from ..utils.utils import get_pool_executor
logger = logging.getLogger(__name__)
class CompressionSolver(base.StandardSolver):
    """Solver for the compression task.

    The compression task combines a set of perceptual and objective losses
    to train an EncodecModel (composed of an encoder-decoder and a quantizer)
    to perform high fidelity audio reconstruction.
    """
    def __init__(self, cfg: omegaconf.DictConfig):
        super().__init__(cfg)
        self.rng: torch.Generator  # set at each epoch
        self.adv_losses = builders.get_adversarial_losses(self.cfg)
        self.aux_losses = nn.ModuleDict()
        self.info_losses = nn.ModuleDict()
        assert not cfg.fsdp.use, "FSDP not supported by CompressionSolver."
        balancer_weights = {}
        for loss_name, weight in self.cfg.losses.items():
            if loss_name in ('adv', 'feat'):
                # One balancer entry per adversary for the generic 'adv' / 'feat' weights.
                for adv_name in self.adv_losses.keys():
                    balancer_weights[f'{loss_name}_{adv_name}'] = weight
            elif weight > 0:
                # Positively weighted losses are optimized through the balancer.
                self.aux_losses[loss_name] = builders.get_loss(loss_name, self.cfg)
                balancer_weights[loss_name] = weight
            else:
                # Remaining losses are only computed for reporting.
                self.info_losses[loss_name] = builders.get_loss(loss_name, self.cfg)
        self.balancer = builders.get_balancer(balancer_weights, self.cfg.balancer)
        self.register_stateful('adv_losses')

    @property
    def best_metric_name(self) -> tp.Optional[str]:
        # best model is the last for the compression model
        return None

    def build_model(self):
        """Instantiate the compression model and its optimizer, then register them."""
        self.model = models.builders.get_compression_model(self.cfg).to(self.device)
        self.optimizer = builders.get_optimizer(self.model.parameters(), self.cfg.optim)
        self.register_stateful('model', 'optimizer')
        self.register_best_state('model')
        self.register_ema('model')

    def build_dataloaders(self):
        """Instantiate audio dataloaders for each stage."""
        self.dataloaders = builders.get_audio_datasets(self.cfg)

    def show(self):
        """Log the compression model along with the adversarial, auxiliary and info losses."""
        self.logger.info(f"Compression model with {self.model.quantizer.total_codebooks} codebooks:")
        self.log_model_summary(self.model)
        self.logger.info("Adversarial loss:")
        self.log_model_summary(self.adv_losses)
        self.logger.info("Auxiliary losses:")
        self.logger.info(self.aux_losses)
        self.logger.info("Info losses:")
        self.logger.info(self.info_losses)

    def _squared_grad_norm(self):
        """Sum of the squared L2 norms of the gradients currently held by the model parameters."""
        return sum(p.grad.data.norm(p=2).pow(2)
                   for p in self.model.parameters() if p.grad is not None)

    def run_step(self, idx: int, batch: torch.Tensor, metrics: dict):
        """Perform one training or valid step on a given batch.

        Args:
            idx (int): Index of the batch (not used by this solver).
            batch (torch.Tensor): Batch of audio.
            metrics (dict): Metrics dict, updated in place.
        Returns:
            dict: The same `metrics` dict, filled with the step metrics.
        """
        audio = batch.to(self.device)
        target = audio.clone()

        qres = self.model(audio)
        assert isinstance(qres, quantization.QuantizedResult)
        recon = qres.x
        # Log bandwidth in kb/s
        metrics['bandwidth'] = qres.bandwidth.mean()

        if self.is_training:
            disc_losses: dict = {}
            # The adversaries are only trained on a random subset of the steps; the
            # random draw is skipped altogether when there is no adversary.
            train_adversaries = (
                len(self.adv_losses) > 0
                and torch.rand(1, generator=self.rng).item() <= 1 / self.cfg.adversarial.every
            )
            if train_adversaries:
                for adv_name, adversary in self.adv_losses.items():
                    disc_losses[f'd_{adv_name}'] = adversary.train_adv(recon, target)
                metrics['d_loss'] = torch.sum(torch.stack(list(disc_losses.values())))
            metrics.update(disc_losses)

        balanced_losses: dict = {}
        other_losses: dict = {}

        # penalty term from the quantizer
        if qres.penalty is not None and qres.penalty.requires_grad:
            other_losses['penalty'] = qres.penalty

        # adversarial losses
        for adv_name, adversary in self.adv_losses.items():
            adv_loss, feat_loss = adversary(recon, target)
            balanced_losses[f'adv_{adv_name}'] = adv_loss
            balanced_losses[f'feat_{adv_name}'] = feat_loss

        # auxiliary losses
        for loss_name, criterion in self.aux_losses.items():
            balanced_losses[loss_name] = criterion(recon, target)

        # weighted losses
        metrics.update(balanced_losses)
        metrics.update(other_losses)
        metrics.update(qres.metrics)

        if self.is_training:
            # backprop losses that are not handled by balancer
            unbalanced = torch.tensor(0., device=self.device)
            if 'penalty' in other_losses:
                unbalanced += other_losses['penalty']
            if unbalanced.requires_grad:
                unbalanced.backward(retain_graph=True)
                ratio1 = self._squared_grad_norm()
                assert isinstance(ratio1, torch.Tensor)
                metrics['ratio1'] = ratio1.sqrt()

            # balancer losses backward, returns effective training loss
            # with effective weights at the current batch.
            metrics['g_loss'] = self.balancer.backward(balanced_losses, recon)
            # add metrics corresponding to weight ratios
            metrics.update(self.balancer.metrics)
            ratio2 = self._squared_grad_norm()
            assert isinstance(ratio2, torch.Tensor)
            metrics['ratio2'] = ratio2.sqrt()

            # optim
            flashy.distrib.sync_model(self.model)
            if self.cfg.optim.max_norm:
                torch.nn.utils.clip_grad_norm_(
                    self.model.parameters(), self.cfg.optim.max_norm
                )
            self.optimizer.step()
            self.optimizer.zero_grad()

        # informative losses only
        reported: dict = {}
        with torch.no_grad():
            for loss_name, criterion in self.info_losses.items():
                reported[loss_name] = criterion(recon, target)
        metrics.update(reported)

        # aggregated GAN losses: this is useful to report adv and feat across different adversarial loss setups
        adv_terms = [value for key, value in metrics.items() if key.startswith('adv')]
        if len(adv_terms) > 0:
            metrics['adv'] = torch.sum(torch.stack(adv_terms))
        feat_terms = [value for key, value in metrics.items() if key.startswith('feat')]
        if len(feat_terms) > 0:
            metrics['feat'] = torch.sum(torch.stack(feat_terms))

        return metrics

    def run_epoch(self):
        """Re-seed the solver random generator from the epoch number, then run the epoch."""
        # reset random seed at the beginning of the epoch
        self.rng = torch.Generator()
        self.rng.manual_seed(1234 + self.epoch)
        # run epoch
        super().run_epoch()

    def evaluate(self):
        """Evaluate stage. Runs audio reconstruction evaluation.

        Reconstructions are computed in the current process while the metrics
        are computed by `evaluate_audio_reconstruction` in a pool of workers.
        """
        self.model.eval()
        stage_name = str(self.current_stage)

        loader = self.dataloaders['evaluate']
        num_updates = len(loader)
        progress = self.log_progress(f'{stage_name} inference', loader, total=num_updates, updates=self.log_updates)
        average = flashy.averager()

        futures = []
        mp_ctx = multiprocessing.get_context('spawn')
        with get_pool_executor(self.cfg.evaluate.num_workers, mp_context=mp_ctx) as pool:
            for batch in progress:
                audio = batch.to(self.device)
                with torch.no_grad():
                    qres = self.model(audio)

                recon = qres.x.cpu()
                reference = batch.cpu()  # should already be on CPU but just in case
                futures.append(pool.submit(evaluate_audio_reconstruction, recon, reference, self.cfg))

            # Collect the results while the pool is still alive.
            metrics_progress = self.log_progress(f'{stage_name} metrics', futures, updates=self.log_updates)
            for future in metrics_progress:
                metrics = future.result()
                metrics = average(metrics)

        metrics = flashy.distrib.average_metrics(metrics, len(loader))
        return metrics

    def generate(self):
        """Generate stage: reconstruct the reference audio and hand the samples to the sample manager."""
        self.model.eval()
        sample_manager = SampleManager(self.xp, map_reference_to_sample_id=True)
        stage_name = str(self.current_stage)

        loader = self.dataloaders['generate']
        num_updates = len(loader)
        progress = self.log_progress(stage_name, loader, total=num_updates, updates=self.log_updates)

        for batch in progress:
            reference, _ = batch
            reference = reference.to(self.device)
            with torch.no_grad():
                qres = self.model(reference)
            assert isinstance(qres, quantization.QuantizedResult)

            reference = reference.cpu()
            estimate = qres.x.cpu()
            sample_manager.add_samples(estimate, self.epoch, ground_truth_wavs=reference)

        flashy.distrib.barrier()

    def load_from_pretrained(self, name: str) -> dict:
        """Return a state of the form `{'best_state': {'model': ...}}` built from a pretrained model.

        Args:
            name (str): Name passed to `models.CompressionModel.get_pretrained`.
        Raises:
            RuntimeError: If the pretrained model is a DAC model or of an unsupported type.
        """
        model = models.CompressionModel.get_pretrained(name)
        if isinstance(model, models.DAC):
            raise RuntimeError("Cannot fine tune a DAC model.")
        if isinstance(model, models.HFEncodecCompressionModel):
            self.logger.warning('Trying to automatically convert a HuggingFace model '
                                'to AudioCraft, this might fail!')
            # Applied in this exact order to every key of the HuggingFace state dict.
            renames = (
                ('encoder.layers.', 'encoder.model.'),
                ('decoder.layers.', 'decoder.model.'),
                ('conv.', 'conv.conv.'),
                ('convtr.', 'convtr.convtr.'),
                ('quantizer.layers.', 'quantizer.vq.layers.'),
                ('.codebook.', '._codebook.'),
            )
            converted = {}
            for key, value in model.model.state_dict().items():
                if key.startswith('decoder.layers') and '.conv.' in key and '.block.' not in key:
                    # We need to determine if this a convtr or a regular conv.
                    layer = int(key.split('.')[2])
                    if isinstance(model.model.decoder.layers[layer].conv, torch.nn.ConvTranspose1d):
                        key = key.replace('.conv.', '.convtr.')
                for old, new in renames:
                    key = key.replace(old, new)
                converted[key] = value
            state = converted
        elif isinstance(model, models.EncodecModel):
            state = model.state_dict()
        else:
            raise RuntimeError(f"Cannot fine tune model type {type(model)}.")
        return {
            'best_state': {'model': state}
        }

    @staticmethod
    def model_from_checkpoint(checkpoint_path: tp.Union[Path, str],
                              device: tp.Union[torch.device, str] = 'cpu') -> models.CompressionModel:
        """Instantiate a CompressionModel from a given checkpoint path or dora sig.
        This method is a convenient endpoint to load a CompressionModel to use in other solvers.

        Args:
            checkpoint_path (Path or str): Path to checkpoint or dora sig from where the checkpoint is resolved.
                This also supports pre-trained models by using a path of the form //pretrained/NAME.
            device (torch.device or str): Device on which the model is loaded.
        Returns:
            models.CompressionModel: The loaded model. When loaded from a checkpoint,
                it holds the checkpoint best state and is set to eval mode.
        """
        checkpoint_path = str(checkpoint_path)
        if checkpoint_path.startswith('//pretrained/'):
            pretrained_name = checkpoint_path.split('/', 3)[-1]
            return models.CompressionModel.get_pretrained(pretrained_name, device)

        logger = logging.getLogger(__name__)
        logger.info(f"Loading compression model from checkpoint: {checkpoint_path}")
        resolved_path = checkpoint.resolve_checkpoint_path(checkpoint_path, use_fsdp=False)
        assert resolved_path is not None, f"Could not resolve compression model checkpoint path: {checkpoint_path}"
        state = checkpoint.load_checkpoint(resolved_path)
        assert state is not None and 'xp.cfg' in state, f"Could not load compression model from ckpt: {checkpoint_path}"

        # Rebuild the model from the config stored in the checkpoint, on the requested device.
        cfg = state['xp.cfg']
        cfg.device = device
        compression_model = models.builders.get_compression_model(cfg).to(device)
        assert compression_model.sample_rate == cfg.sample_rate, "Compression model sample rate should match"

        assert 'best_state' in state and state['best_state'] != {}
        assert 'exported' not in state, "When loading an exported checkpoint, use the //pretrained/ prefix."
        compression_model.load_state_dict(state['best_state']['model'])
        compression_model.eval()
        logger.info("Compression model loaded!")
        return compression_model

    @staticmethod
    def wrapped_model_from_checkpoint(cfg: omegaconf.DictConfig,
                                      checkpoint_path: tp.Union[Path, str],
                                      device: tp.Union[torch.device, str] = 'cpu') -> models.CompressionModel:
        """Instantiate a wrapped CompressionModel from a given checkpoint path or dora sig.

        Args:
            cfg (omegaconf.DictConfig): Configuration to read from for wrapped mode.
            checkpoint_path (Path or str): Path to checkpoint or dora sig from where the checkpoint is resolved.
            device (torch.device or str): Device on which the model is loaded.
        """
        base_model = CompressionSolver.model_from_checkpoint(checkpoint_path, device)
        wrapped_model = models.builders.get_wrapped_compression_model(base_model, cfg)
        return wrapped_model
def evaluate_audio_reconstruction(y_pred: torch.Tensor, y: torch.Tensor, cfg: omegaconf.DictConfig) -> dict:
    """Audio reconstruction evaluation method that can be conveniently pickled.

    Args:
        y_pred (torch.Tensor): Reconstructed audio.
        y (torch.Tensor): Reference audio.
        cfg (omegaconf.DictConfig): Solver configuration.
    Returns:
        dict: The 'sisnr' metric, plus 'visqol' when `cfg.evaluate.metrics.visqol` is enabled.
    """
    results = {}
    if cfg.evaluate.metrics.visqol:
        visqol_metric = builders.get_visqol(cfg.metrics.visqol)
        results['visqol'] = visqol_metric(y_pred, y, cfg.sample_rate)
    sisnr_metric = builders.get_loss('sisnr', cfg)
    results['sisnr'] = sisnr_metric(y_pred, y)
    return results
================================================
FILE: audiocraft/solvers/diffusion.py
================================================
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import typing as tp
import flashy
import julius
import omegaconf
import torch
import torch.nn.functional as F
from . import builders
from . import base
from .. import models
from ..modules.diffusion_schedule import NoiseSchedule
from ..metrics import RelativeVolumeMel
from ..models.builders import get_processor
from ..utils.samples.manager import SampleManager
from ..solvers.compression import CompressionSolver
class PerStageMetrics:
    """Handle prompting the metrics per stage.
    It outputs the metrics per range of diffusion states.
    e.g. avg loss when t in [250, 500]

    Args:
        num_steps (int): Total number of diffusion steps.
        num_stages (int): Number of equal ranges the steps are split into.
    """
    def __init__(self, num_steps: int, num_stages: int = 4):
        self.num_steps = num_steps
        self.num_stages = num_stages

    def __call__(self, losses: dict, step: tp.Union[int, torch.Tensor]):
        """Suffix each loss name with the index of the stage the step falls in.

        Args:
            losses (dict): Losses by name. With a tensor `step`, each loss must
                be broadcastable against `step` (one value per item).
            step (int or torch.Tensor): Diffusion step, either shared by the
                whole batch (int) or given per item (tensor).
        Returns:
            dict: `{name}_{stage}` entries. With a tensor `step`, each value is
                the loss averaged over the items of that stage, and stages
                without any item are left out.
        Raises:
            TypeError: If `step` is neither an int nor a tensor.
        """
        # isinstance (rather than an exact type check) so that Tensor
        # subclasses are handled as tensors instead of falling through.
        if isinstance(step, int):
            stage = int((step / self.num_steps) * self.num_stages)
            return {f"{name}_{stage}": loss for name, loss in losses.items()}
        if isinstance(step, torch.Tensor):
            stage_tensor = ((step / self.num_steps) * self.num_stages).long()
            out: tp.Dict[str, torch.Tensor] = {}
            for stage_idx in range(self.num_stages):
                mask = (stage_tensor == stage_idx)
                count = mask.sum()
                if count > 0:  # pass if no elements in the stage
                    for name, loss in losses.items():
                        out[f"{name}_{stage_idx}"] = (mask * loss).sum() / count
            return out
        raise TypeError(f"step must be an int or a torch.Tensor, got {type(step)}.")
class DataProcess:
    """Apply optional boosting, band filtering and resampling to audio.

    Args:
        initial_sr (int): Sample rate of the dataset.
        target_sr (int): Sample rate after resampling.
        use_resampling (bool): Whether or not to perform resampling.
        use_filter (bool): When True, filter the data to keep only one frequency band.
        n_bands (int): Number of bands used.
        idx_band (int): Index of the frequency band. 0 are lows ... (n_bands - 1) highs.
        device (torch.device or str): Device the band-splitting filter is moved to.
        cutoffs (None or list): The cutoff frequencies of the band filtering.
            If None, `n_bands` is passed to the filter instead.
        boost (bool): Make the data scale match our music dataset.
    """
    def __init__(self, initial_sr: int = 24000, target_sr: int = 16000, use_resampling: bool = False,
                 use_filter: bool = False, n_bands: int = 4,
                 idx_band: int = 0, device: torch.device = torch.device('cpu'), cutoffs=None, boost=False):
        assert idx_band < n_bands
        self.idx_band = idx_band
        # The filter is only built (and `self.filter` only exists) when filtering is enabled.
        if use_filter:
            if cutoffs is not None:
                self.filter = julius.SplitBands(sample_rate=initial_sr, cutoffs=cutoffs).to(device)
            else:
                self.filter = julius.SplitBands(sample_rate=initial_sr, n_bands=n_bands).to(device)
        self.use_filter = use_filter
        self.use_resampling = use_resampling
        self.target_sr = target_sr
        self.initial_sr = initial_sr
        self.boost = boost

    def process_data(self, x, metric=False):
        """Boost, band-filter and resample `x`, each step depending on the configuration.

        Args:
            x (torch.Tensor or None): Audio to process. The input tensor is left untouched.
                With `boost`, it must have at least 3 dimensions (std is taken over dims 1 and 2).
            metric (bool): When True, the band filtering is skipped.
        Returns:
            torch.Tensor or None: The processed audio, None if `x` is None.
        """
        if x is None:
            return None
        if self.boost:
            # Out-of-place on purpose: the caller's tensor must not be modified.
            # The std is clamped to avoid blowing up near-silent items.
            x = x / torch.clamp(x.std(dim=(1, 2), keepdim=True), min=1e-4)
            x = x * 0.22
        if self.use_filter and not metric:
            x = self.filter(x)[self.idx_band]
        if self.use_resampling:
            x = julius.resample_frac(x, old_sr=self.initial_sr, new_sr=self.target_sr)
        return x

    def inverse_process(self, x):
        """Undo the resampling only: bring `x` back from `target_sr` to `initial_sr`."""
        if self.use_resampling:
            x = julius.resample_frac(x, old_sr=self.target_sr, new_sr=self.initial_sr)
        return x
class DiffusionSolver(base.StandardSolver):
"""Solver for diffusion task.
The diffusion task allows for MultiBand diffusion model training.
Args:
cfg (DictConfig): Configuration.
"""
def __init__(self, cfg: omegaconf.DictConfig):
    """Set up the codec model, sample processor, noise schedule and data processor.

    Args:
        cfg (DictConfig): Configuration.
    """
    super().__init__(cfg)
    self.cfg = cfg
    self.device = cfg.device
    self.sample_rate: int = self.cfg.sample_rate
    # Compression model used by `get_condition`, limited to cfg.n_q codebooks.
    self.codec_model = CompressionSolver.model_from_checkpoint(
        cfg.compression_model_checkpoint, device=self.device)
    self.codec_model.set_num_codebooks(cfg.n_q)
    # self.sample_rate is cfg.sample_rate, hence a single check covers both.
    assert self.codec_model.sample_rate == self.cfg.sample_rate, (
        f"Codec model sample rate is {self.codec_model.sample_rate} but "
        f"Solver sample rate is {self.cfg.sample_rate}."
    )

    self.sample_processor = get_processor(cfg.processor, sample_rate=self.sample_rate)
    self.register_stateful('sample_processor')
    self.sample_processor.to(self.device)

    self.schedule = NoiseSchedule(
        **cfg.schedule, device=self.device, sample_processor=self.sample_processor)

    self.eval_metric: tp.Optional[torch.nn.Module] = None

    self.rvm = RelativeVolumeMel()
    self.data_processor = DataProcess(initial_sr=self.sample_rate, target_sr=cfg.resampling.target_sr,
                                      use_resampling=cfg.resampling.use, cutoffs=cfg.filter.cutoffs,
                                      use_filter=cfg.filter.use, n_bands=cfg.filter.n_bands,
                                      idx_band=cfg.filter.idx_band, device=self.device)
@property
def best_metric_name(self) -> tp.Optional[str]:
    """Name of the metric tracked for the best state: 'rvm' during the evaluate stage, 'loss' otherwise."""
    return 'rvm' if self._current_stage == "evaluate" else 'loss'
@torch.no_grad()
def get_condition(self, wav: torch.Tensor) -> torch.Tensor:
    """Encode `wav` with the codec model and return the decoded latent of the resulting codes.

    Runs without gradient tracking.
    """
    codes, scale = self.codec_model.encode(wav)
    assert scale is None, "Scaled compression models not supported."
    return self.codec_model.decode_latent(codes)
def build_model(self):
    """Build the diffusion model and its optimizer.

    Both are registered as stateful; the model is also registered for best
    state tracking and for EMA.
    """
    diffusion_model = models.builders.get_diffusion_model(self.cfg)
    self.model = diffusion_model.to(self.device)
    self.optimizer = builders.get_optimizer(self.model.parameters(), self.cfg.optim)
    self.register_stateful('model', 'optimizer')
    self.register_best_state('model')
    self.register_ema('model')
def build_dataloaders(self):
    """Build the audio dataloaders of every stage from the solver configuration."""
    loaders = builders.get_audio_datasets(self.cfg)
    self.dataloaders = loaders
def show(self):
    """Not implemented for this solver yet: always raises NotImplementedError."""
    # TODO
    raise NotImplementedError()
def run_step(self, idx: int, batch: torch.Tensor, metrics: dict):
    """Perform one training or valid step on a given batch.

    Args:
        idx (int): Index of the batch (not used by this solver).
        batch (torch.Tensor): Batch of audio.
        metrics (dict): Ignored, a fresh metrics dict is built and returned.
    Returns:
        dict: 'loss', 'normed_loss', their per-stage variants, 'std_in' and 'std_out'.
    """
    wav = batch.to(self.device)
    criterion = F.mse_loss if self.cfg.loss.kind == 'mse' else F.l1_loss

    # The condition is computed before the data processor is applied.
    condition = self.get_condition(wav)
    sample = self.data_processor.process_data(wav)

    input_, target, step = self.schedule.get_training_item(
        sample, tensor_step=self.cfg.schedule.variable_step_batch)
    out = self.model(input_, step, condition=condition).sample

    # Per item losses: the model error, and the error of the untouched input as reference.
    base_loss = criterion(out, target, reduction='none').mean(dim=(1, 2))
    reference_loss = criterion(input_, target, reduction='none').mean(dim=(1, 2))
    loss = base_loss / reference_loss ** self.cfg.loss.norm_power

    if self.is_training:
        loss.mean().backward()
        flashy.distrib.sync_model(self.model)
        self.optimizer.step()
        self.optimizer.zero_grad()

    normed_loss = base_loss / reference_loss
    step_metrics = {'loss': loss.mean(), 'normed_loss': normed_loss.mean()}
    step_metrics.update(self.per_stage({'loss': loss, 'normed_loss': normed_loss}, step))
    step_metrics.update({'std_in': input_.std(), 'std_out': out.std()})
    return step_metrics
def run_epoch(self):
    """Reseed the solver generator, reset the per-stage metrics, then run the epoch."""
    # the seed is derived from the epoch number, so it is reset at the beginning of each epoch
    generator = torch.Generator()
    generator.manual_seed(1234 + self.epoch)
    self.rng = generator
    self.per_stage = PerStageMetrics(self.schedule.num_steps, self.cfg.metrics.num_stage)
    # run epoch
    super().run_epoch()
def evaluate(self):
    """Evaluate stage.

    Runs audio reconstruction evaluation: every batch of the 'evaluate'
    loader is regenerated by the model and scored against the reference with
    `self.rvm`. The per-batch scores are folded into a running mean, which is
    then averaged across workers with `flashy.distrib.average_metrics`.

    Returns:
        dict: the averaged metrics (empty if the loader yields no batch).
    """
    self.model.eval()
    evaluate_stage_name = f'{self.current_stage}'
    loader = self.dataloaders['evaluate']
    updates = len(loader)
    lp = self.log_progress(f'{evaluate_stage_name} estimate', loader, total=updates, updates=self.log_updates)

    metrics = {}
    n = 0  # number of batches already folded into `metrics`
    for batch in lp:
        x = batch.to(self.device)
        with torch.no_grad():
            y_pred = self.regenerate(x)

        y_pred = y_pred.cpu()
        y = batch.cpu()  # should already be on CPU but just in case
        rvm = self.rvm(y_pred, y)
        lp.update(**rvm)
        if n == 0:
            # copy so that the running mean never aliases the per-batch dict
            metrics = dict(rvm)
        else:
            for key in rvm.keys():
                metrics[key] = (metrics[key] * n + rvm[key]) / (n + 1)
        # the counter must advance, otherwise every batch after the second
        # would be weighted as if only one batch had been seen so far.
        n += 1
    metrics = flashy.distrib.average_metrics(metrics)
    return metrics
@torch.no_grad()
def regenerate(self, wav: torch.Tensor, step_list: tp.Optional[list] = None):
    """Regenerate the given waveform.

    The waveform provides both the condition (see `get_condition`) and the
    shape of the initial noise; the schedule then generates a sample that is
    mapped back with the data processor's `inverse_process`.
    """
    condition = self.get_condition(wav)
    processed = self.data_processor.process_data(wav)  # sampling rate changes.
    initial = self.schedule.get_initial_noise(processed)
    generated = self.schedule.generate_subsampled(
        self.model, initial=initial, condition=condition, step_list=step_list)
    return self.data_processor.inverse_process(generated)
def generate(self):
    """Generate stage: regenerate each reference batch and hand the result to the sample manager."""
    sample_manager = SampleManager(self.xp)
    self.model.eval()
    stage_name = f'{self.current_stage}'
    loader = self.dataloaders['generate']
    progress = self.log_progress(stage_name, loader, total=len(loader), updates=self.log_updates)
    for reference, _ in progress:
        estimate = self.regenerate(reference.to(self.device))
        # both the estimate and the reference are stored from CPU
        sample_manager.add_samples(estimate.cpu(), self.epoch, ground_truth_wavs=reference.cpu())
    flashy.distrib.barrier()
================================================
FILE: audiocraft/solvers/musicgen.py
================================================
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
from pathlib import Path
import time
import typing as tp
import flashy
import math
import omegaconf
import torch
from torch.nn import functional as F
from . import base, builders
from .compression import CompressionSolver
from .. import metrics as eval_metrics
from .. import models
from ..data.audio_dataset import AudioDataset
from ..data.music_dataset import MusicDataset, MusicInfo, AudioInfo
from ..data.audio_utils import normalize_audio
from ..modules.conditioners import JointEmbedCondition, SegmentWithAttributes, WavCondition
from ..utils.cache import CachedBatchWriter, CachedBatchLoader
from ..utils.samples.manager import SampleManager
from ..utils.utils import get_dataset_from_loader, is_jsonable, warn_once
class MusicGenSolver(base.StandardSolver):
    """Solver for MusicGen training task.

    Trains a language model over the tokens of a compression model.

    Used in: https://arxiv.org/abs/2306.05284
    """
    # Dataset type passed to `builders.get_audio_datasets` when building the dataloaders.
    DATASET_TYPE: builders.DatasetType = builders.DatasetType.MUSIC
def __init__(self, cfg: omegaconf.DictConfig):
    """Set up sampling parameters and the optional cached-batch writer or loader.

    When `cfg.cache.path` is set, the solver either writes batches to the
    cache (`cfg.cache.write`) or replaces the 'train' dataloader with a
    cached-batch loader, keeping the original one under 'original_train'.
    """
    super().__init__(cfg)
    # easier access to sampling parameters
    lm_cfg = self.cfg.generate.lm
    self.generation_params = {
        name: getattr(lm_cfg, name) for name in ('use_sampling', 'temp', 'top_k', 'top_p')
    }
    self._best_metric_name: tp.Optional[str] = 'ce'
    self._cached_batch_writer = None
    self._cached_batch_loader = None

    if not cfg.cache.path:
        return
    cache_path = Path(cfg.cache.path)
    if cfg.cache.write:
        self._cached_batch_writer = CachedBatchWriter(cache_path)
        if self.cfg.cache.write_num_shards:
            self.logger.warning("Multiple shard cache, best_metric_name will be set to None.")
            self._best_metric_name = None
        return
    self._cached_batch_loader = CachedBatchLoader(
        cache_path, cfg.dataset.batch_size, cfg.dataset.num_workers,
        min_length=self.cfg.optim.updates_per_epoch or 1)
    self.dataloaders['original_train'] = self.dataloaders['train']
    self.dataloaders['train'] = self._cached_batch_loader  # type: ignore
@staticmethod
def get_eval_solver_from_sig(sig: str, dtype: tp.Optional[str] = None,
                             device: tp.Optional[str] = None, autocast: bool = True,
                             batch_size: tp.Optional[int] = None,
                             override_cfg: tp.Optional[tp.Union[dict, omegaconf.DictConfig]] = None,
                             **kwargs):
    """Convenience wrapper around `audiocraft.train.get_solver_from_sig`.

    It requests a solver with EMA deactivated, FSDP disabled and the best
    state loaded (the 'optimizer' and 'ema' state keys are ignored), and
    puts the model in eval mode.

    Args:
        sig (str): signature to load.
        dtype (str or None): potential dtype, as a string, i.e. 'float16'.
        device (str or None): potential device, as a string, i.e. 'cuda'.
        autocast (bool): value for the `autocast` config entry.
        batch_size (int or None): if set, value for `dataset.batch_size`.
        override_cfg (dict or omegaconf.DictConfig or None): extra config overrides.
            The overrides built from the arguments above are merged on top of it.
        **kwargs: forwarded to `train.get_solver_from_sig`.
    Returns:
        The solver, with its model in eval mode.
    """
    from audiocraft import train

    forced_cfg: tp.Dict[str, tp.Any] = {'optim': {'ema': {'use': False}}, 'autocast': autocast}
    for key, value in (('dtype', dtype), ('device', device)):
        if value is not None:
            forced_cfg[key] = value
    if batch_size is not None:
        forced_cfg['dataset'] = {'batch_size': batch_size}

    user_cfg = {} if override_cfg is None else override_cfg
    merged_cfg = omegaconf.OmegaConf.merge(
        omegaconf.DictConfig(user_cfg), omegaconf.DictConfig(forced_cfg))  # type: ignore
    solver = train.get_solver_from_sig(
        sig, override_cfg=merged_cfg,
        load_best=True, disable_fsdp=True,
        ignore_state_keys=['optimizer', 'ema'], **kwargs)
    solver.model.eval()
    return solver
def get_formatter(self, stage_name: str) -> flashy.Formatter:
    """Return the metrics formatter; the same formats are used whatever the stage name."""
    formats = {'lr': '.2E', 'ce': '.3f', 'ppl': '.3f', 'grad_norm': '.3E'}
    # per-codebook metrics are excluded
    excluded = ['ce_q*', 'ppl_q*']
    return flashy.Formatter(formats, exclude_keys=excluded)
@property
def best_metric_name(self) -> tp.Optional[str]:
    """Name of the metric used to select the best state, or None.

    The value is the one stored in `_best_metric_name` by `__init__`.
    """
    name = self._best_metric_name
    return name
def build_model(self) -> None:
    """Instantiate models and optimizer.

    Loads the compression model from `cfg.compression_model_checkpoint`,
    checks that its sample rate, cardinality and number of codebooks match
    the solver / LM configuration, builds the LM (optionally wrapped with
    FSDP), the optimizer, the LR scheduler and, when required, a gradient
    scaler. All of them are registered as stateful sources.

    Raises:
        AssertionError: on a mismatch between the compression model and the
            configuration, or when autocast is requested together with FSDP.
    """
    # we can potentially not use all quantizers with which the EnCodec model was trained
    # (e.g. we trained the model with quantizers dropout)
    self.compression_model = CompressionSolver.wrapped_model_from_checkpoint(
        self.cfg, self.cfg.compression_model_checkpoint, device=self.device)
    assert self.compression_model.sample_rate == self.cfg.sample_rate, (
        f"Compression model sample rate is {self.compression_model.sample_rate} but "
        f"Solver sample rate is {self.cfg.sample_rate}."
    )
    # ensure we have matching configuration between LM and compression model.
    # The messages rely on implicit string concatenation: separating the parts
    # with commas would make the assert message a tuple.
    assert self.cfg.transformer_lm.card == self.compression_model.cardinality, (
        "Cardinalities of the LM and compression model don't match: "
        f"LM cardinality is {self.cfg.transformer_lm.card} vs "
        f"compression model cardinality is {self.compression_model.cardinality}"
    )
    assert self.cfg.transformer_lm.n_q == self.compression_model.num_codebooks, (
        "Numbers of codebooks of the LM and compression models don't match: "
        f"LM number of codebooks is {self.cfg.transformer_lm.n_q} vs "
        f"compression model number of codebooks is {self.compression_model.num_codebooks}"
    )
    self.logger.info("Compression model has %d codebooks with %d cardinality, and a framerate of %d",
                     self.compression_model.num_codebooks, self.compression_model.cardinality,
                     self.compression_model.frame_rate)
    # instantiate LM model
    self.model: models.LMModel = models.builders.get_lm_model(self.cfg).to(self.device)
    if self.cfg.fsdp.use:
        assert not self.cfg.autocast, "Cannot use autocast with fsdp"
        self.model = self.wrap_with_fsdp(self.model)
    self.register_ema('model')
    # initialize optimization
    self.optimizer = builders.get_optimizer(builders.get_optim_parameter_groups(self.model), self.cfg.optim)
    self.lr_scheduler = builders.get_lr_scheduler(self.optimizer, self.cfg.schedule, self.total_updates)
    self.register_stateful('compression_model', 'model', 'optimizer', 'lr_scheduler')
    self.register_best_state('model')
    self.autocast_dtype = {
        'float16': torch.float16, 'bfloat16': torch.bfloat16
    }[self.cfg.autocast_dtype]
    self.scaler: tp.Optional[torch.cuda.amp.GradScaler] = None
    # a gradient scaler is only created for float16 (FSDP param dtype or autocast dtype)
    if self.cfg.fsdp.use:
        need_scaler = self.cfg.fsdp.param_dtype == 'float16'
    else:
        need_scaler = self.cfg.autocast and self.autocast_dtype is torch.float16
    if need_scaler:
        if self.cfg.fsdp.use:
            from torch.distributed.fsdp.sharded_grad_scaler import ShardedGradScaler
            self.scaler = ShardedGradScaler()  # type: ignore
        else:
            self.scaler = torch.cuda.amp.GradScaler()
        self.register_stateful('scaler')
def build_dataloaders(self) -> None:
    """Create the audio dataloaders for every stage, using the solver's `DATASET_TYPE`."""
    loaders = builders.get_audio_datasets(self.cfg, dataset_type=self.DATASET_TYPE)
    self.dataloaders = loaders
def show(self) -> None:
    """Log a summary of the compression model, then of the LM model."""
    summaries = (
        ("Compression model:", self.compression_model),
        ("LM model:", self.model),
    )
    for title, module in summaries:
        self.logger.info(title)
        self.log_model_summary(module)
def load_state_dict(self, state: dict) -> None:
    """Load the solver state, folding a separate 'condition_provider' entry into the model state.

    When `state` has a 'condition_provider' key, it is removed from `state`
    and its items are added to `state['model']` with the
    'condition_provider.' prefix, before delegating to the parent class.
    """
    if 'condition_provider' in state:
        model_state = state['model']
        provider_state = state.pop('condition_provider')
        for name, value in provider_state.items():
            prefixed = 'condition_provider.' + name
            assert prefixed not in model_state
            model_state[prefixed] = value
    super().load_state_dict(state)
def load_from_pretrained(self, name: str):
    """Build a solver state holding the best state of a pretrained LM checkpoint.

    Returns:
        dict: `{'best_state': {'model': <best state of the checkpoint>}}`.
    """
    # TODO: support native HF versions of MusicGen.
    lm_pkg = models.loaders.load_lm_model_ckpt(name)
    best_state = {'model': lm_pkg['best_state']}
    return {'best_state': best_state}
def _compute_cross_entropy(
self, logits: torch.Tensor, targets: torch.Tensor, mask: torch.Tensor
) -> tp.Tuple[torch.Tensor, tp.List[torch.Tensor]]:
"""Compute cross entropy between multi-codebook targets and model's logits.
The cross entropy is computed per codebook to provide codebook-level cross entropy.
Valid timesteps for each of the codebook are pulled from the mask, where invalid
timesteps are set to 0.
Args:
logits (torch.Tensor): Model's logits of shape [B, K, T, card].
targets (torch.Tensor): Target codes, of shape [B, K, T].
mask (torch.Tensor): Mask for valid target codes, of shape [B, K, T].
Returns:
ce (torch.Tensor): Cross entropy averaged over the codebooks
ce_per_codebook (list of torch.Tensor): Cross entropy per codebook (detached).
"""
B, K, T = targets.shape
assert logits.shape[:-1] == targets.shape
assert mask.shape == targets.shape
ce = torch.zeros([], device=targets.device)
ce_per_codebook: tp.List[torch.Tensor] = []
for k in range(K):
logits_k = logits[:, k, ...].contiguous().view(-1, logits.size(-1)) # [B x T, card]
targets_k = targets[:, k, ...].contiguous().view(-1) # [B x T]
mask_k = mask[:, k, ...].contiguous().view(-1) # [B x T]
ce_targets = targets_k[mask_k]
ce_logits = logits_k[mask_k]
q_ce = F.cross_entropy(ce_logits, ce_targets)
ce += q_ce
ce_per_codebook.append(q_ce.detach())
# average cross entropy across codebooks
ce = ce / K
return ce, ce_per_codebook
@torch.no_grad()
def _prepare_tokens_and_attributes(
    self, batch: tp.Tuple[torch.Tensor, tp.List[SegmentWithAttributes]],
    check_synchronization_points: bool = False
) -> tp.Tuple[dict, torch.Tensor, torch.Tensor]:
    """Prepare input batches for language model training.

    Outside of the cached-batch training path, the audio is encoded with the
    compression model. When a cached-batch loader is used during the train
    stage, the batch only holds the infos, which carry pre-computed tokens.
    When a cached-batch writer is set, the tokens of the train stage are
    stored on the infos and saved.

    Args:
        batch (tuple[torch.Tensor, list[SegmentWithAttributes]]): Input batch with audio tensor of shape [B, C, T]
            and corresponding metadata as SegmentWithAttributes (with B items).
        check_synchronization_points (bool): Whether to check for synchronization points slowing down training.
    Returns:
        Condition tensors (dict[str, any]): Preprocessed condition attributes.
        Tokens (torch.Tensor): Audio tokens from compression model, of shape [B, K, T_s],
            with B the batch size, K the number of codebooks, T_s the token timesteps.
        Padding mask (torch.Tensor): Mask with valid positions in the tokens tensor, of shape [B, K, T_s].
    """
    if self._cached_batch_loader is None or self.current_stage != "train":
        audio, infos = batch
        audio = audio.to(self.device)
        audio_tokens = None
        # implicit string concatenation: a comma here would turn the message into a tuple
        assert audio.size(0) == len(infos), (
            f"Mismatch between number of items in audio batch ({audio.size(0)})"
            f" and in metadata ({len(infos)})"
        )
    else:
        audio = None
        # In that case the batch will be a tuple coming from the _cached_batch_writer bit below.
        infos, = batch  # type: ignore
        assert all([isinstance(info, AudioInfo) for info in infos])
        assert all([info.audio_tokens is not None for info in infos])  # type: ignore
        audio_tokens = torch.stack([info.audio_tokens for info in infos]).to(self.device)  # type: ignore
        audio_tokens = audio_tokens.long()
        for info in infos:
            if isinstance(info, MusicInfo):
                # Careful here, if you want to use this condition_wav (e.b. chroma conditioning),
                # then you must be using the chroma cache! otherwise the code will try
                # to use this segment and fail (by that I mean you will see NaN everywhere).
                info.self_wav = WavCondition(
                    torch.full([1, info.channels, info.total_frames], float('NaN')),
                    length=torch.tensor([info.n_frames]),
                    sample_rate=[info.sample_rate],
                    path=[info.meta.path],
                    seek_time=[info.seek_time])
                dataset = get_dataset_from_loader(self.dataloaders['original_train'])
                assert isinstance(dataset, MusicDataset), type(dataset)
                if dataset.paraphraser is not None and info.description is not None:
                    # Hackingly reapplying paraphraser when using cache.
                    info.description = dataset.paraphraser.sample_paraphrase(
                        info.meta.path, info.description)
    # prepare attributes
    attributes = [info.to_condition_attributes() for info in infos]
    attributes = self.model.cfg_dropout(attributes)
    attributes = self.model.att_dropout(attributes)
    tokenized = self.model.condition_provider.tokenize(attributes)

    # Now we should be synchronization free.
    if self.device == "cuda" and check_synchronization_points:
        torch.cuda.set_sync_debug_mode("warn")

    if audio_tokens is None:
        with torch.no_grad():
            audio_tokens, scale = self.compression_model.encode(audio)
            assert scale is None, "Scaled compression model not supported with LM."

    with self.autocast:
        condition_tensors = self.model.condition_provider(tokenized)

    # create a padding mask to hold valid vs invalid positions
    padding_mask = torch.ones_like(audio_tokens, dtype=torch.bool, device=audio_tokens.device)
    # replace encodec tokens from padded audio with special_token_id
    if self.cfg.tokens.padding_with_special_token:
        audio_tokens = audio_tokens.clone()
        padding_mask = padding_mask.clone()
        token_sample_rate = self.compression_model.frame_rate
        B, K, T_s = audio_tokens.shape
        for i in range(B):
            n_samples = infos[i].n_frames
            audio_sample_rate = infos[i].sample_rate
            # take the last token generated from actual audio frames (non-padded audio)
            valid_tokens = math.floor(float(n_samples) / audio_sample_rate * token_sample_rate)
            audio_tokens[i, :, valid_tokens:] = self.model.special_token_id
            padding_mask[i, :, valid_tokens:] = 0

    if self.device == "cuda" and check_synchronization_points:
        torch.cuda.set_sync_debug_mode("default")

    if self._cached_batch_writer is not None and self.current_stage == 'train':
        assert self._cached_batch_loader is None
        assert audio_tokens is not None
        for info, one_audio_tokens in zip(infos, audio_tokens):
            assert isinstance(info, AudioInfo)
            if isinstance(info, MusicInfo):
                assert not info.joint_embed, "joint_embed and cache not supported yet."
                info.self_wav = None
            # tokens are stored as int16 (`.short()`), hence the bound check
            assert one_audio_tokens.max() < 2**15, one_audio_tokens.max().item()
            info.audio_tokens = one_audio_tokens.short().cpu()
        self._cached_batch_writer.save(infos)

    return condition_tensors, audio_tokens, padding_mask
def run_step(self, idx: int, batch: tp.Tuple[torch.Tensor, tp.List[SegmentWithAttributes]], metrics: dict) -> dict:
    """Perform one training or valid step on a given batch.

    Computes the cross entropy of the LM predictions against the audio
    tokens and, when training, runs the backward pass and the optimizer
    update (with optional gradient scaling and gradient clipping).

    Args:
        idx (int): Index of the step; synchronization points are only checked for `idx == 1` on 'cuda'.
        batch (tuple[torch.Tensor, list[SegmentWithAttributes]]): Batch passed to
            `_prepare_tokens_and_attributes`.
        metrics (dict): Metrics dict, updated in place and returned.
    Returns:
        dict: `metrics`, with 'ce', 'ppl' and per-codebook 'ce_q{k}' / 'ppl_q{k}' entries, plus
            'lr' and possibly 'grad_norm' / 'grad_scale' when training.
    Raises:
        RuntimeError: when training and the loss is not finite.
    """
    check_synchronization_points = idx == 1 and self.device == 'cuda'
    condition_tensors, audio_tokens, padding_mask = self._prepare_tokens_and_attributes(
        batch, check_synchronization_points)
    self.deadlock_detect.update('tokens_and_conditions')
    if check_synchronization_points:
        torch.cuda.set_sync_debug_mode('warn')
    with self.autocast:
        model_output = self.model.compute_predictions(audio_tokens, [], condition_tensors) # type: ignore
        logits = model_output.logits
        # a position counts only if it is valid for both the padding mask and the model's own mask
        mask = padding_mask & model_output.mask
        ce, ce_per_codebook = self._compute_cross_entropy(logits, audio_tokens, mask)
        loss = ce
    self.deadlock_detect.update('loss')
    if check_synchronization_points:
        torch.cuda.set_sync_debug_mode('default')
    if self.is_training:
        metrics['lr'] = self.optimizer.param_groups[0]['lr']
        if self.scaler is not None:
            # from here on `loss` is the scaled loss; `ce` keeps the unscaled value
            loss = self.scaler.scale(loss)
        self.deadlock_detect.update('scale')
        if self.cfg.fsdp.use:
            loss.backward()
            flashy.distrib.average_tensors(self.model.buffers())
        elif self.cfg.optim.eager_sync:
            with flashy.distrib.eager_sync_model(self.model):
                loss.backward()
        else:
            # this should always be slower but can be useful
            # for weird use cases like multiple backwards.
            loss.backward()
            flashy.distrib.sync_model(self.model)
        self.deadlock_detect.update('backward')
        if self.scaler is not None:
            # gradients are unscaled before clipping
            self.scaler.unscale_(self.optimizer)
        if self.cfg.optim.max_norm:
            if self.cfg.fsdp.use:
                metrics['grad_norm'] = self.model.clip_grad_norm_(self.cfg.optim.max_norm) # type: ignore
            else:
                metrics['grad_norm'] = torch.nn.utils.clip_grad_norm_(
                    self.model.parameters(), self.cfg.optim.max_norm
                )
        if self.scaler is None:
            self.optimizer.step()
        else:
            self.scaler.step(self.optimizer)
            self.scaler.update()
        if self.lr_scheduler:
            self.lr_scheduler.step()
        self.optimizer.zero_grad()
        self.deadlock_detect.update('optim')
        if self.scaler is not None:
            scale = self.scaler.get_scale()
            metrics['grad_scale'] = scale
        # checked after the update; with a scaler this tests the scaled loss
        if not loss.isfinite().all():
            raise RuntimeError("Model probably diverged.")
    metrics['ce'] = ce
    metrics['ppl'] = torch.exp(ce)
    # per-codebook metrics are numbered from 1
    for k, ce_q in enumerate(ce_per_codebook):
        metrics[f'ce_q{k + 1}'] = ce_q
        metrics[f'ppl_q{k + 1}'] = torch.exp(ce_q)
    return metrics
@torch.no_grad()
def run_generate_step(self, batch: tp.Tuple[torch.Tensor, tp.List[SegmentWithAttributes]],
                      gen_duration: float, prompt_duration: tp.Optional[float] = None,
                      remove_prompt: bool = False,
                      **generation_params) -> dict:
    """Run generate step on a batch of optional audio tensor and corresponding attributes.

    Args:
        batch (tuple[torch.Tensor, list[SegmentWithAttributes]]): Audio and its metadata; the audio
            is used as the source of the prompt when `prompt_duration` is given.
        gen_duration (float): Target audio duration for the generation.
        prompt_duration (float, optional): Duration for the audio prompt to use for continuation.
            No prompt is used when None.
        remove_prompt (bool, optional): Whether to remove the prompt from the generated audio.
            Currently not used by this method.
        generation_params: Additional generation parameters. They take precedence over the
            solver-level `self.generation_params` defaults.
    Returns:
        gen_outputs (dict): Generation outputs, consisting in audio, audio tokens from both the generation
            and the prompt along with additional information.
    """
    bench_start = time.time()
    audio, meta = batch
    # implicit string concatenation: a comma here would turn the message into a tuple
    assert audio.size(0) == len(meta), (
        f"Mismatch between number of items in audio batch ({audio.size(0)})"
        f" and in metadata ({len(meta)})"
    )
    # prepare attributes
    attributes = [x.to_condition_attributes() for x in meta]
    # TODO: Add dropout for chroma?

    # prepare audio prompt
    if prompt_duration is None:
        prompt_audio = None
    else:
        assert prompt_duration < gen_duration, "Prompt duration must be lower than target generation duration"
        prompt_audio_frames = int(prompt_duration * self.compression_model.sample_rate)
        prompt_audio = audio[..., :prompt_audio_frames]

    # get audio tokens from compression model
    if prompt_audio is None or prompt_audio.nelement() == 0:
        num_samples = len(attributes)
        prompt_tokens = None
    else:
        num_samples = None
        prompt_audio = prompt_audio.to(self.device)
        prompt_tokens, scale = self.compression_model.encode(prompt_audio)
        assert scale is None, "Compression model in MusicGen should not require rescaling."

    # Parameters given by the caller override the solver defaults. Previously
    # `generation_params` was accepted but silently ignored.
    sampling_params = {**self.generation_params, **generation_params}

    # generate by sampling from the LM
    with self.autocast:
        total_gen_len = math.ceil(gen_duration * self.compression_model.frame_rate)
        gen_tokens = self.model.generate(
            prompt_tokens, attributes, max_gen_len=total_gen_len,
            num_samples=num_samples, **sampling_params)

    # generate audio from tokens
    assert gen_tokens.dim() == 3
    gen_audio = self.compression_model.decode(gen_tokens, None)

    bench_end = time.time()
    gen_outputs = {
        'rtf': (bench_end - bench_start) / gen_duration,
        'ref_audio': audio,
        'gen_audio': gen_audio,
        'gen_tokens': gen_tokens,
        'prompt_audio': prompt_audio,
        'prompt_tokens': prompt_tokens,
    }
    return gen_outputs
def generate_audio(self) -> dict:
    """Audio generation stage.

    For every batch of the 'generate' loader, stores unprompted samples
    (or the ground truth when `generate.lm.gen_gt_samples` is set) and/or
    prompted continuations through the sample manager, depending on
    `generate.lm.unprompted_samples` and `generate.lm.prompted_samples`.

    Returns:
        dict: averaged metrics, holding the real time factor under 'rtf'
            when at least one generation ran.
    """
    generate_stage_name = f'{self.current_stage}'
    sample_manager = SampleManager(self.xp)
    self.logger.info(f"Generating samples in {sample_manager.base_folder}")
    loader = self.dataloaders['generate']
    updates = len(loader)
    lp = self.log_progress(generate_stage_name, loader, total=updates, updates=self.log_updates)

    dataset = get_dataset_from_loader(loader)
    dataset_duration = dataset.segment_duration
    assert dataset_duration is not None
    assert isinstance(dataset, AudioDataset)
    target_duration = self.cfg.generate.lm.gen_duration
    prompt_duration = self.cfg.generate.lm.prompt_duration
    if target_duration is None:
        target_duration = dataset_duration
    if prompt_duration is None:
        prompt_duration = dataset_duration / 4
    # implicit string concatenation: a comma here would turn the message into a tuple
    assert prompt_duration < dataset_duration, (
        f"Specified prompt duration ({prompt_duration}s) is longer"
        f" than reference audio duration ({dataset_duration}s)"
    )

    def get_hydrated_conditions(meta: tp.List[SegmentWithAttributes]):
        """Build one loggable dict of conditions per sample, restricted to the model's conditioners."""
        hydrated_conditions = []
        for sample in [x.to_condition_attributes() for x in meta]:
            cond_dict = {}
            for cond_type in sample.__annotations__.keys():
                for cond_key, cond_val in getattr(sample, cond_type).items():
                    if cond_key not in self.model.condition_provider.conditioners.keys():
                        continue
                    if is_jsonable(cond_val):
                        cond_dict[cond_key] = cond_val
                    elif isinstance(cond_val, WavCondition):
                        cond_dict[cond_key] = cond_val.path
                    elif isinstance(cond_val, JointEmbedCondition):
                        cond_dict[cond_key] = cond_val.text  # only support text at inference for now
                    else:
                        # if we reached this point, it is not clear how to log the condition
                        # so we just log the type.
                        cond_dict[cond_key] = str(type(cond_val))
            hydrated_conditions.append(cond_dict)
        return hydrated_conditions

    metrics: dict = {}
    average = flashy.averager()
    for batch in lp:
        audio, meta = batch
        # metadata for sample manager
        hydrated_conditions = get_hydrated_conditions(meta)
        sample_generation_params = {
            **{f'classifier_free_guidance_{k}': v for k, v in self.cfg.classifier_free_guidance.items()},
            **self.generation_params
        }
        # real time factor of this batch; stays None if nothing was generated
        rtf = None
        if self.cfg.generate.lm.unprompted_samples:
            if self.cfg.generate.lm.gen_gt_samples:
                # get the ground truth instead of generation
                self.logger.warning(
                    "Use ground truth instead of audio generation as generate.lm.gen_gt_samples=true")
                gen_unprompted_audio = audio
                rtf = 1.
            else:
                # unprompted generation must not receive an audio prompt
                gen_unprompted_outputs = self.run_generate_step(
                    batch, gen_duration=target_duration, prompt_duration=None,
                    **self.generation_params)
                gen_unprompted_audio = gen_unprompted_outputs['gen_audio'].cpu()
                rtf = gen_unprompted_outputs['rtf']
            sample_manager.add_samples(
                gen_unprompted_audio, self.epoch, hydrated_conditions,
                ground_truth_wavs=audio, generation_args=sample_generation_params)

        if self.cfg.generate.lm.prompted_samples:
            gen_outputs = self.run_generate_step(
                batch, gen_duration=target_duration, prompt_duration=prompt_duration,
                **self.generation_params)
            gen_audio = gen_outputs['gen_audio'].cpu()
            prompt_audio = gen_outputs['prompt_audio'].cpu()
            sample_manager.add_samples(
                gen_audio, self.epoch, hydrated_conditions,
                prompt_wavs=prompt_audio, ground_truth_wavs=audio,
                generation_args=sample_generation_params)
            if rtf is None:
                # only prompted samples were requested: report their real time factor
                rtf = gen_outputs['rtf']

        if rtf is not None:
            metrics['rtf'] = rtf
        metrics = average(metrics)

    flashy.distrib.barrier()
    return metrics
def generate(self) -> dict:
    """Generate stage: put the model in eval mode and run `generate_audio` without gradients."""
    self.model.eval()
    with torch.no_grad():
        generation_metrics = self.generate_audio()
    return generation_metrics
def run_epoch(self):
    """Run the epoch, unless the cache is being written and this epoch belongs to another shard.

    When `cfg.cache.write` is set, only the epochs for which
    `(epoch - 1) % cfg.cache.write_num_shards` equals `cfg.cache.write_shard` are run.
    """
    cache_cfg = self.cfg.cache
    if cache_cfg.write and ((self.epoch - 1) % cache_cfg.write_num_shards) != cache_cfg.write_shard:
        return
    super().run_epoch()
def train(self):
    """Train stage.

    Notifies the cached-batch writer and loader (when set) of the new epoch;
    without a cached-batch loader, the epoch is set on the train dataset instead.
    """
    epoch = self.epoch
    if self._cached_batch_writer is not None:
        self._cached_batch_writer.start_epoch(epoch)
    if self._cached_batch_loader is not None:
        self._cached_batch_loader.start_epoch(epoch)
    else:
        dataset = get_dataset_from_loader(self.dataloaders['train'])
        assert isinstance(dataset, AudioDataset)
        dataset.current_epoch = epoch
    return super().train()
def evaluate_audio_generation(self) -> dict:
    """Evaluate audio generation with off-the-shelf metrics.

    The metrics enabled under `cfg.evaluate.metrics` (fad, kld,
    text_consistency, chroma_cosine) are instantiated, then fed with audio
    generated for every batch of the 'evaluate' loader. A metric whose
    `use_gt` option is set is fed with ground truth based audio instead of
    the generated one; this choice is local to that metric.

    Returns:
        dict: the computed metrics, empty when no metric is enabled.
    """
    evaluate_stage_name = f'{self.current_stage}_generation'
    # instantiate evaluation metrics, if at least one metric is defined, run audio generation evaluation
    fad: tp.Optional[eval_metrics.FrechetAudioDistanceMetric] = None
    kldiv: tp.Optional[eval_metrics.KLDivergenceMetric] = None
    text_consistency: tp.Optional[eval_metrics.TextConsistencyMetric] = None
    chroma_cosine: tp.Optional[eval_metrics.ChromaCosineSimilarityMetric] = None
    should_run_eval = False
    eval_chroma_wavs: tp.Optional[torch.Tensor] = None
    if self.cfg.evaluate.metrics.fad:
        fad = builders.get_fad(self.cfg.metrics.fad).to(self.device)
        should_run_eval = True
    if self.cfg.evaluate.metrics.kld:
        kldiv = builders.get_kldiv(self.cfg.metrics.kld).to(self.device)
        should_run_eval = True
    if self.cfg.evaluate.metrics.text_consistency:
        text_consistency = builders.get_text_consistency(self.cfg.metrics.text_consistency).to(self.device)
        should_run_eval = True
    if self.cfg.evaluate.metrics.chroma_cosine:
        chroma_cosine = builders.get_chroma_cosine_similarity(self.cfg.metrics.chroma_cosine).to(self.device)
        # if we have predefined wavs for chroma we should purge them for computing the cosine metric
        has_predefined_eval_chromas = 'self_wav' in self.model.condition_provider.conditioners and \
            self.model.condition_provider.conditioners['self_wav'].has_eval_wavs()
        if has_predefined_eval_chromas:
            warn_once(self.logger, "Attempting to run cosine eval for config with pre-defined eval chromas! "
                      'Resetting eval chromas to None for evaluation.')
            eval_chroma_wavs = self.model.condition_provider.conditioners.self_wav.eval_wavs  # type: ignore
            self.model.condition_provider.conditioners.self_wav.reset_eval_wavs(None)  # type: ignore
        should_run_eval = True

    def get_compressed_audio(audio: torch.Tensor) -> torch.Tensor:
        """Encode then decode `audio` with the compression model, trimmed to the input length."""
        audio_tokens, scale = self.compression_model.encode(audio.to(self.device))
        compressed_audio = self.compression_model.decode(audio_tokens, scale)
        return compressed_audio[..., :audio.shape[-1]]

    metrics: dict = {}
    if should_run_eval:
        loader = self.dataloaders['evaluate']
        updates = len(loader)
        lp = self.log_progress(f'{evaluate_stage_name} inference', loader, total=updates, updates=self.log_updates)
        average = flashy.averager()
        dataset = get_dataset_from_loader(loader)
        assert isinstance(dataset, AudioDataset)
        self.logger.info(f"Computing evaluation metrics on {len(dataset)} samples")

        for batch in lp:
            audio, meta = batch
            assert all([self.cfg.sample_rate == m.sample_rate for m in meta])

            target_duration = audio.shape[-1] / self.cfg.sample_rate
            if self.cfg.evaluate.fixed_generation_duration:
                target_duration = self.cfg.evaluate.fixed_generation_duration

            gen_outputs = self.run_generate_step(
                batch, gen_duration=target_duration,
                **self.generation_params
            )
            y_pred = gen_outputs['gen_audio'].detach()
            y_pred = y_pred[..., :audio.shape[-1]]

            normalize_kwargs = dict(self.cfg.generate.audio)
            normalize_kwargs.pop('format', None)
            y_pred = torch.stack([normalize_audio(w, **normalize_kwargs) for w in y_pred], dim=0).cpu()
            y = audio.cpu()  # should already be on CPU but just in case
            sizes = torch.tensor([m.n_frames for m in meta])  # actual sizes without padding
            sample_rates = torch.tensor([m.sample_rate for m in meta])  # sample rates for audio samples
            audio_stems = [Path(m.meta.path).stem + f"_{m.seek_time}" for m in meta]

            # Each metric picks its own estimate: a `use_gt` substitution for one
            # metric must not replace the generated audio seen by the next ones.
            if fad is not None:
                fad_pred = get_compressed_audio(y).cpu() if self.cfg.metrics.fad.use_gt else y_pred
                fad.update(fad_pred, y, sizes, sample_rates, audio_stems)
            if kldiv is not None:
                kld_pred = get_compressed_audio(y).cpu() if self.cfg.metrics.kld.use_gt else y_pred
                kldiv.update(kld_pred, y, sizes, sample_rates)
            if text_consistency is not None:
                texts = [m.description for m in meta]
                text_pred = y if self.cfg.metrics.text_consistency.use_gt else y_pred
                text_consistency.update(text_pred, texts, sizes, sample_rates)
            if chroma_cosine is not None:
                chroma_pred = get_compressed_audio(y).cpu() if self.cfg.metrics.chroma_cosine.use_gt else y_pred
                chroma_cosine.update(chroma_pred, y, sizes, sample_rates)

        # restore chroma conditioner's eval chroma wavs, once every batch has been generated
        if eval_chroma_wavs is not None:
            self.model.condition_provider.conditioners['self_wav'].reset_eval_wavs(eval_chroma_wavs)

        flashy.distrib.barrier()
        if fad is not None:
            metrics['fad'] = fad.compute()
        if kldiv is not None:
            kld_metrics = kldiv.compute()
            metrics.update(kld_metrics)
        if text_consistency is not None:
            metrics['text_consistency'] = text_consistency.compute()
        if chroma_cosine is not None:
            metrics['chroma_cosine'] = chroma_cosine.compute()
        metrics = average(metrics)
        metrics = flashy.distrib.average_metrics(metrics, len(loader))

    return metrics
def evaluate(self) -> dict:
    """Evaluate stage.

    Combines the base metrics (only when `cfg.evaluate.metrics.base` is set)
    with the audio generation metrics; the latter win on key conflicts.
    """
    self.model.eval()
    with torch.no_grad():
        combined: dict = {}
        if self.cfg.evaluate.metrics.base:
            combined.update(self.common_train_valid('evaluate'))
        gen_metrics = self.evaluate_audio_generation()
    combined = dict(combined)
    combined.update(gen_metrics)
    return combined
================================================
FILE: audiocraft/train.py
================================================
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
"""
Entry point for dora to launch solvers for running training loops.
See more info on how to use dora: https://github.com/facebookresearch/dora
"""
import logging
import multiprocessing
import os
import sys
import typing as tp
from dora import git_save, hydra_main, XP
import flashy
import hydra
import omegaconf
from .environment import AudioCraftEnvironment
from .utils.cluster import get_slurm_parameters
logger = logging.getLogger(__name__)
def resolve_config_dset_paths(cfg):
    """Enable Dora to load manifest from git clone repository.

    Every string value of `cfg.datasource` is replaced, in place, by the
    result of `git_save.to_absolute_path`; other values are left untouched.
    """
    # manifest files for the different splits
    datasource = cfg.datasource
    for split, manifest in datasource.items():
        if not isinstance(manifest, str):
            continue
        datasource[split] = git_save.to_absolute_path(manifest)
def get_solver(cfg):
    """Build the solver for `cfg`, after converting batch sizes to per-GPU values.

    The global `dataset.batch_size`, and any per-split `batch_size`, must be
    divisible by the world size and are divided by it in place.
    """
    from . import solvers

    world_size = flashy.distrib.world_size()
    dataset_cfg = cfg.dataset
    # Convert batch size to batch size for each GPU
    assert dataset_cfg.batch_size % world_size == 0
    dataset_cfg.batch_size //= world_size
    for split in ('train', 'valid', 'evaluate', 'generate'):
        if not hasattr(dataset_cfg, split):
            continue
        split_cfg = dataset_cfg[split]
        if not hasattr(split_cfg, 'batch_size'):
            continue
        assert split_cfg.batch_size % world_size == 0
        split_cfg.batch_size //= world_size
    resolve_config_dset_paths(cfg)
    return solvers.get_solver(cfg)
def get_solver_from_xp(xp: XP, override_cfg: tp.Optional[tp.Union[dict, omegaconf.DictConfig]] = None,
                       restore: bool = True, load_best: bool = True,
                       ignore_state_keys: tp.Optional[tp.List[str]] = None, disable_fsdp: bool = True):
    """Given a XP, return the Solver object.

    Args:
        xp (XP): Dora experiment for which to retrieve the solver.
        override_cfg (dict or None): If not None, should be a dict used to
            override some values in the config of `xp`. This will not impact
            the XP signature or folder. The format is different
            than the one used in Dora grids, nested keys should actually be nested dicts,
            not flattened, e.g. `{'optim': {'batch_size': 32}}`.
        restore (bool): If `True` (the default), restore state from the last checkpoint.
        load_best (bool): If `True` (the default), load the best state from the checkpoint.
        ignore_state_keys (list[str] or None): List of sources to ignore when loading the state,
            e.g. `optimizer`. None (the default) means no source is ignored. The list given by
            the caller is never modified.
        disable_fsdp (bool): if True, disables FSDP entirely. This will
            also automatically skip loading the EMA. For solver specific
            state sources, like the optimizer, you might want to
            use along `ignore_state_keys=['optimizer']`. Must be used with `load_best=True`.
    Returns:
        The solver built for the (possibly overridden) config of `xp`.
    """
    logger.info(f"Loading solver from XP {xp.sig}. "
                f"Overrides used: {xp.argv}")
    # Work on a private copy: a mutable default would be shared across calls,
    # and the list is handed over to `solver.restore`.
    ignore_state_keys = [] if ignore_state_keys is None else list(ignore_state_keys)
    cfg = xp.cfg
    if override_cfg is not None:
        cfg = omegaconf.OmegaConf.merge(cfg, omegaconf.DictConfig(override_cfg))
    if disable_fsdp and cfg.fsdp.use:
        cfg.fsdp.use = False
        assert load_best is True
        # ignoring some keys that were FSDP sharded like model, ema, and best_state.
        # fsdp_best_state will be used in that case. When using a specific solver,
        # one is responsible for adding the relevant keys, e.g. 'optimizer'.
        # We could make something to automatically register those inside the solver, but that
        # seem overkill at this point.
        ignore_state_keys = ignore_state_keys + ['model', 'ema', 'best_state']

    try:
        with xp.enter():
            solver = get_solver(cfg)
            if restore:
                solver.restore(load_best=load_best, ignore_state_keys=ignore_state_keys)
        return solver
    finally:
        # always clear the global Hydra instance, whether or not loading succeeded
        hydra.core.global_hydra.GlobalHydra.instance().clear()
def get_solver_from_sig(sig: str, *args, **kwargs):
    """Return the Solver for a Dora signature, e.g. to experiment from a notebook.

    The signature is turned into an XP with `main.get_xp_from_sig`, and every
    extra argument is forwarded to `get_solver_from_xp`; see that function for
    the available options.
    """
    return get_solver_from_xp(main.get_xp_from_sig(sig), *args, **kwargs)
def init_seed_and_system(cfg):
    """Apply process-wide settings read from `cfg`.

    Sets the multiprocessing start method, seeds `random`, NumPy and torch,
    limits the number of threads (torch, MKL and OpenMP environment
    variables) and selects the efficient attention backend.

    Args:
        cfg: Configuration providing `mp_start_method`, `seed`, `num_threads`
            and `efficient_attention_backend`.
    """
    import numpy as np
    import torch
    import random
    from audiocraft.modules.transformer import set_efficient_attention_backend

    start_method = cfg.mp_start_method
    multiprocessing.set_start_method(start_method)
    logger.debug('Setting mp start method to %s', cfg.mp_start_method)

    seed = cfg.seed
    random.seed(seed)
    np.random.seed(seed)
    # torch also seeds CUDA when it is available.
    torch.manual_seed(seed)

    torch.set_num_threads(cfg.num_threads)
    for env_var in ('MKL_NUM_THREADS', 'OMP_NUM_THREADS'):
        os.environ[env_var] = str(cfg.num_threads)
    logger.debug('Setting num threads to %d', cfg.num_threads)

    set_efficient_attention_backend(cfg.efficient_attention_backend)
    logger.debug('Setting efficient attention backend to %s', cfg.efficient_attention_backend)
@hydra_main(config_path='../config', config_name='config', version_base='1.1')
def main(cfg):
    """Entry point: build the solver for `cfg` and show, run one stage, or train.

    - `cfg.show`: only display the solver and exit.
    - `cfg.execute_only`: restore the checkpoint and run that single stage;
      requires `cfg.execute_inplace` or `cfg.continue_from` to be set.
    - otherwise: return the result of the full `solver.run()`.
    """
    init_seed_and_system(cfg)

    # Log both to the XP specific folder and to stderr.
    if cfg.execute_only:
        log_name = '%s.log.{rank}' % cfg.execute_only
    else:
        log_name = 'solver.log.{rank}'
    flashy.setup_logging(level=str(cfg.logging.level).upper(), log_name=log_name)
    # Distributed training init; nothing to specify when using Dora.
    flashy.distrib.init()
    solver = get_solver(cfg)

    if cfg.show:
        solver.show()
        return

    if not cfg.execute_only:
        return solver.run()

    assert cfg.execute_inplace or cfg.continue_from is not None, (
        "Please explicitly specify the checkpoint to continue from with continue_from= "
        "when running with execute_only or set execute_inplace to True.")
    solver.restore(replay_metrics=False)  # load checkpoint
    solver.run_one_stage(cfg.execute_only)
    return
# Import-time configuration of the Dora-decorated `main`: point Dora at the
# directory provided by the AudioCraft environment and adapt the base SLURM
# settings through `get_slurm_parameters`.
main.dora.dir = AudioCraftEnvironment.get_dora_dir()
main._base_cfg.slurm = get_slurm_parameters(main._base_cfg.slurm)

# A configured shared folder that cannot be read is dropped with a warning on
# stderr instead of being kept in the Dora settings.
if main.dora.shared is not None and not os.access(main.dora.shared, os.R_OK):
    print("No read permission on dora.shared folder, ignoring it.", file=sys.stderr)
    main.dora.shared = None

# Run the entry point only when this module is executed as a script.
if __name__ == '__main__':
    main()
================================================
FILE: audiocraft/utils/__init__.py
================================================
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
"""Utilities."""
================================================
FILE: audiocraft/utils/autocast.py
================================================
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import torch
class TorchAutocast:
    """Context manager that optionally wraps `torch.autocast`.

    Makes it possible to switch autocast on or off from a single flag, which
    is handy across architectures and clusters with different levels of support.

    Args:
        enabled (bool): Whether to enable torch.autocast or not.
        args: Additional args for torch.autocast.
        kwargs: Additional kwargs for torch.autocast
    """
    def __init__(self, enabled: bool, *args, **kwargs):
        # `None` marks the disabled state; entering/exiting is then a no-op.
        if enabled:
            self.autocast = torch.autocast(*args, **kwargs)
        else:
            self.autocast = None

    def __enter__(self):
        if self.autocast is not None:
            try:
                self.autocast.__enter__()
            except RuntimeError:
                # Re-raise with the dtype/device that were requested.
                device = self.autocast.device
                dtype = self.autocast.fast_dtype
                raise RuntimeError(
                    f"There was an error autocasting with dtype={dtype} device={device}\n"
                    "If you are on the FAIR Cluster, you might need to use autocast_dtype=float16"
                )

    def __exit__(self, *args, **kwargs):
        if self.autocast is not None:
            self.autocast.__exit__(*args, **kwargs)
================================================
FILE: audiocraft/utils/best_state.py
================================================
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
from collections import defaultdict
import logging
import typing as tp
import flashy
import torch
from ..optim import ModuleDictEMA
from .utils import copy_state
logger = logging.getLogger(__name__)
class BestStateDictManager(flashy.state.StateDictSource):
    """BestStateDictManager maintains a copy of best state_dict() for registered sources.

    BestStateDictManager has two main attributes:
        states (dict): State dict of the registered StateDictSource.
        param_ids (dict): Dict of parameter ids for registered states from ModuleDictEMA and other sources.

    When registering new sources, the BestStateDictManager will ensure two conflicting sources between
    ModuleDictEMA and original modules are not both registered as it would otherwise create ambiguity about
    what to consider for best state.

    Args:
        device (torch.device or str): Device on which we keep the copy.
        dtype (torch.dtype): Data type for the state parameters.
    """
    def __init__(self, device: tp.Union[torch.device, str] = 'cpu',
                 dtype: tp.Optional[torch.dtype] = None):
        self.device = device
        self.states: dict = {}
        # Maps a group name ('base' for non-EMA sources, the source name for
        # ModuleDictEMA) to a dict {id(tensor): entry name in the state dict}.
        self.param_ids: dict = defaultdict(dict)
        self.dtype = dtype

    def _get_parameter_ids(self, state_dict):
        """Return `{id(tensor): name}` for every tensor entry of `state_dict`."""
        return {id(p): name for name, p in state_dict.items() if isinstance(p, torch.Tensor)}

    def _validate_no_parameter_ids_overlap(self, name: str, param_ids: dict):
        """Assert that `param_ids` shares no tensor with any other registered group.

        Args:
            name (str): Group the new ids belong to; that group is not checked against itself.
            param_ids (dict): `{id(tensor): name}` for the source being registered.
        Raises:
            AssertionError: If a tensor is already registered under a different group.
        """
        for registered_name, registered_param_ids in self.param_ids.items():
            if registered_name == name:
                continue
            # dict views support `&` directly; `set.intersection(view, view)` raises
            # TypeError because a dict view is not a `set` instance.
            overlap = registered_param_ids.keys() & param_ids.keys()
            # Ids are ints: report the entry names instead so the message can be built.
            overlap_names = ' '.join(str(param_ids[pid]) for pid in overlap)
            assert len(overlap) == 0, (
                f"Found {len(overlap)} / {len(param_ids.keys())} overlapping parameters"
                f" in {name} and already registered {registered_name}: {overlap_names}"
            )

    def update(self, name: str, source: flashy.state.StateDictSource):
        """Replace the stored copy for the already registered `name` with the current state of `source`.

        Raises:
            ValueError: If `name` was never registered.
        """
        if name not in self.states:
            raise ValueError(f"{name} missing from registered states.")
        self.states[name] = copy_state(source.state_dict(), device=self.device, dtype=self.dtype)

    def register(self, name: str, source: flashy.state.StateDictSource):
        """Register a new source under `name` and store a copy of its state.

        Raises:
            ValueError: If `name` is already registered.
            AssertionError: If the source shares tensors with a conflicting registered group.
        """
        if name in self.states:
            raise ValueError(f"{name} already present in states.")
        # Registering parameter ids for EMA and non-EMA states allows us to check that
        # there is no overlap that would create ambiguity about how to handle the best state
        param_ids = self._get_parameter_ids(source.state_dict())
        if isinstance(source, ModuleDictEMA):
            logger.debug(f"Registering to best state: ModuleDictEMA '{name}' with {len(param_ids)} params")
            self._validate_no_parameter_ids_overlap(name, param_ids)
            self.param_ids[name] = param_ids
        else:
            logger.debug(f"Registering to best state: StateDictSource '{name}' with {len(param_ids)} params")
            # All non-EMA sources are tracked together under the 'base' group.
            self._validate_no_parameter_ids_overlap('base', param_ids)
            self.param_ids['base'].update(param_ids)
        # Register state
        self.states[name] = copy_state(source.state_dict(), device=self.device, dtype=self.dtype)

    def state_dict(self) -> flashy.state.StateDict:
        """Return the stored states (the internal dict itself, not a copy)."""
        return self.states

    def load_state_dict(self, state: flashy.state.StateDict):
        """Copy the values of `state` in place into the stored tensors.

        Every name and key of `state` must already exist in the stored states.
        """
        for name, sub_state in state.items():
            for k, v in sub_state.items():
                self.states[name][k].copy_(v)
================================================
FILE: audiocraft/utils/cache.py
================================================
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
from concurrent.futures import ThreadPoolExecutor
from collections import deque
from functools import partial
from hashlib import sha1
import logging
from pathlib import Path
import sys
import typing as tp
import zipfile
import flashy
import torch
logger = logging.getLogger(__name__)
def get_full_embed(full_embed: torch.Tensor, x: tp.Any, idx: int, device: tp.Union[str, torch.device]) -> torch.Tensor:
    """Default extraction function for the EmbeddingCache: keep the whole embedding.

    Use it when no chunk needs to be cut out of the full embedding read from
    the cache. `x` and `idx` are accepted only to match the extraction
    function signature and are not used.

    Args:
        full_embed (torch.Tensor): The full embedding.
        x (any): Batch object from which the full embedding is derived (unused).
        idx (int): Index of the object to consider in the batch object (unused).
        device (str or torch.device): Device the embedding is moved to.
    Returns:
        torch.Tensor: The full embedding, on `device`.
    """
    on_device = full_embed.to(device)
    return on_device
class EmbeddingCache:
    """Cache around embeddings computation for faster execution.

    The EmbeddingCache is storing pre-computed embeddings on disk and provides a simple API
    to retrieve the pre-computed embeddings on full inputs and extract only a given chunk
    using a user-provided function. When the cache is warm (all embeddings are pre-computed),
    the EmbeddingCache allows for faster training as it removes the need of computing the embeddings.
    Additionally, it provides in-memory cache around the loaded embeddings to limit IO footprint
    and synchronization points in the forward calls.

    Args:
        cache_path (Path): Path to folder where all pre-computed embeddings are saved on disk.
        device (str or torch.device): Device on which the embedding is returned.
        compute_embed_fn (callable[[Path, any, int], torch.Tensor], optional): Function to compute
            the embedding from a given object and path. This user provided function can compute the
            embedding from the provided object or using the provided path as entry point. The last parameter
            specify the index corresponding to the current embedding in the object that can represent batch metadata.
        extract_embed_fn (callable[[torch.Tensor, any, int], torch.Tensor], optional): Function to extract
            the desired embedding chunk from the full embedding loaded from the cache. The last parameter
            specify the index corresponding to the current embedding in the object that can represent batch metadata.
            If not specified, will return the full embedding unmodified.
    """
    def __init__(self, cache_path: tp.Union[Path, str], device: tp.Union[str, torch.device],
                 compute_embed_fn: tp.Callable[[Path, tp.Any, int], torch.Tensor],
                 extract_embed_fn: tp.Optional[tp.Callable[[torch.Tensor, tp.Any, int], torch.Tensor]] = None):
        self.cache_path = Path(cache_path)
        self.device = device
        self._compute_embed_fn = compute_embed_fn
        self._extract_embed_fn: tp.Callable[[torch.Tensor, tp.Any, int], torch.Tensor]
        if extract_embed_fn is not None:
            self._extract_embed_fn = extract_embed_fn
        else:
            # Default: return the full embedding, moved to `device`.
            self._extract_embed_fn = partial(get_full_embed, device=device)
        if self.cache_path is not None:
            self.cache_path.mkdir(exist_ok=True, parents=True)
            logger.info(f"Cache instantiated at: {self.cache_path}")
            # Thread pool used to read cache files from disk in parallel.
            self.pool = ThreadPoolExecutor(8)
            self.pool.__enter__()
        # Extracted chunks for the batch last passed to `populate_embed_cache`.
        self._current_batch_cache: dict = {}
        # Full embeddings already read from disk, keyed by cache file path.
        self._memory_cache: dict = {}

    def _get_cache_path(self, path: tp.Union[Path, str]):
        """Get cache path for the given file path (SHA-1 hex digest of `str(path)`)."""
        sig = sha1(str(path).encode()).hexdigest()
        return self.cache_path / sig

    @staticmethod
    def _get_full_embed_from_cache(cache: Path):
        """Loads full pre-computed embedding from the cache.

        Returns `None`, after logging the error, when the file cannot be loaded.
        """
        try:
            embed = torch.load(cache, 'cpu')
        except Exception as exc:
            logger.error("Error loading %s: %r", cache, exc)
            embed = None
        return embed

    def get_embed_from_cache(self, paths: tp.List[Path], x: tp.Any) -> torch.Tensor:
        """Get embedding from cache, computing and storing it to cache if not already cached.

        The EmbeddingCache first tries to load the embedding from the in-memory cache
        containing the pre-computed chunks populated through `populate_embed_cache`.
        If not found, the full embedding is computed and stored on disk to be later accessed
        to populate the in-memory cache, and the desired embedding chunk is extracted and returned.
        Writing to disk is best effort: a failed write is logged and the freshly
        computed embedding is still used for the current call.

        Args:
            paths (list[Path or str]): List of paths from where the embeddings can be loaded.
            x (any): Object from which the embedding is extracted.
        Returns:
            torch.Tensor: The per-path embeddings stacked along a new first dimension.
        """
        embeds = []
        for idx, path in enumerate(paths):
            cache = self._get_cache_path(path)
            if cache in self._current_batch_cache:
                embed = self._current_batch_cache[cache]
            else:
                full_embed = self._compute_embed_fn(path, x, idx)
                try:
                    with flashy.utils.write_and_rename(cache, pid=True) as f:
                        torch.save(full_embed.cpu(), f)
                except Exception as exc:
                    logger.error('Error saving embed %s (%s): %r', cache, full_embed.shape, exc)
                else:
                    logger.info('New embed cache saved: %s (%s)', cache, full_embed.shape)
                # Extract regardless of whether the write succeeded, so a failed
                # write never leaves `embed` unbound or pointing at a previous item.
                embed = self._extract_embed_fn(full_embed, x, idx)
            embeds.append(embed)
        embed = torch.stack(embeds, dim=0)
        return embed

    def populate_embed_cache(self, paths: tp.List[Path], x: tp.Any) -> None:
        """Populate in-memory caches for embeddings reading from the embeddings stored on disk.

        The in-memory caches consist in a cache for the full embedding and another cache for the
        final embedding chunk. Such caches are used to limit the IO access when computing the actual embeddings
        and reduce the IO footprint and synchronization points during forward passes.

        Args:
            paths (list[Path]): List of paths from where the embeddings can be loaded.
            x (any): Object from which the embedding is extracted.
        """
        self._current_batch_cache.clear()
        if self.cache_path is not None:
            # First pass: schedule a disk read for every file that exists on disk
            # and is not yet in memory. `None` means "nothing to read".
            futures: list = []
            for path in paths:
                assert path is not None, "Path is required for computation from cache"
                cache = self._get_cache_path(path)
                if cache in self._memory_cache or not cache.exists():
                    futures.append(None)
                else:
                    futures.append(self.pool.submit(EmbeddingCache._get_full_embed_from_cache, cache))
            # Second pass: collect results and extract the chunk for each item.
            for idx, (path, future) in enumerate(zip(paths, futures)):
                assert path is not None
                cache = self._get_cache_path(path)
                full_embed = None
                if future is None:
                    if cache in self._memory_cache:
                        full_embed = self._memory_cache[cache]
                else:
                    full_embed = future.result()
                    if full_embed is not None:
                        # Keep the loaded (CPU) copy in memory, work on a device copy.
                        self._memory_cache[cache] = full_embed
                        full_embed = full_embed.to(self.device)
                if full_embed is not None:
                    embed = self._extract_embed_fn(full_embed, x, idx)
                    self._current_batch_cache[cache] = embed
class CachedBatchWriter:
    """Writer for pre-computed mini-batch caches, which can make loading
    much more efficient depending on the filesystem.

    Args:
        cache_folder (Path): folder in which the cached minibatches
            will be stored.

    Layout inside the cache folder:
        `epoch_number / update_number.zip`
    Each zip file holds one entry per batch item.
    The cache may be read with a batch size smaller than the one it was
    created with, but not larger. Call `start_epoch(epoch)` whenever the
    epoch changes.

    See the grid `audiocraft/grids/musicgen/musicgen_warmup_cache.py`
    for an example of how to warmup the cache.
    """
    def __init__(self, cache_folder: Path):
        self.cache_folder = cache_folder
        self._current_epoch: tp.Optional[int] = None
        self._current_index = 0

    def start_epoch(self, epoch: int):
        """Call at the beginning of each epoch.

        Resets the batch index and creates the folder for that epoch.
        """
        self._current_epoch = epoch
        self._current_index = 0
        epoch_folder = self._zip_path.parent
        epoch_folder.mkdir(exist_ok=True, parents=True)

    @staticmethod
    def _get_zip_path(cache_folder: Path, epoch: int, index: int):
        """Path of the zip for batch `index` of `epoch` (zero padded to 5 and 6 digits)."""
        epoch_dir = f"{epoch:05d}"
        zip_name = f"{index:06d}.zip"
        return cache_folder / epoch_dir / zip_name

    @property
    def _zip_path(self):
        """Zip path for the current epoch and batch index; needs `start_epoch` first."""
        assert self._current_epoch is not None
        return CachedBatchWriter._get_zip_path(self.cache_folder, self._current_epoch, self._current_index)

    def save(self, *content):
        """Save one mini batch. This function is distributed-aware
        and will automatically merge all the items from the different
        workers.
        """
        # Every rank broadcasts its own content in turn, so each rank ends up
        # holding the content of all ranks, ordered by rank.
        all_contents = [
            flashy.distrib.broadcast_object(content, src=rank)
            for rank in range(flashy.distrib.world_size())
        ]
        if flashy.distrib.is_rank_zero():
            with flashy.utils.write_and_rename(self._zip_path) as tmp:
                with zipfile.ZipFile(tmp, 'w') as zf:
                    # One zip entry per batch item, numbered across all ranks.
                    per_item = (vals for rank_content in all_contents for vals in zip(*rank_content))
                    for idx, vals in enumerate(per_item):
                        with zf.open(f'{idx}', 'w') as f:  # type: ignore
                            torch.save(vals, f)
        flashy.distrib.barrier()
        self._current_index += 1
class CachedBatchLoader:
    """Loader for cached mini-batches dumped with `CachedBatchWriter`.

    Args:
        cache_folder (Path): folder in which the cached minibatches are stored.
        batch_size (int): batch size (per GPU) expected.
        num_workers (int): number of workers to use for loading.
        min_length (int): minimum expected length for each epoch. An error
            is raised when a mini-batch below that index is missing.

    This is iterable just like a regular DataLoader.
    """

    def __init__(self, cache_folder: Path, batch_size: int,
                 num_workers: int = 10, min_length: int = 1):
        self.cache_folder = cache_folder
        self.batch_size = batch_size
        self.num_workers = num_workers
        self.min_length = min_length
        self._current_epoch: tp.Optional[int] = None
        self.sampler = None  # for compatibility with the regular DataLoader

    def __len__(self):
        """Number of `.zip` files in the folder of the current epoch (epoch 0 if unset)."""
        epoch = self._current_epoch or 0
        epoch_folder = CachedBatchWriter._get_zip_path(self.cache_folder, epoch, 0).parent
        zip_files = [entry for entry in epoch_folder.iterdir() if entry.suffix == ".zip"]
        return len(zip_files)

    def start_epoch(self, epoch: int):
        """Call at the beginning of each epoch.
        """
        self._current_epoch = epoch

    def _zip_path(self, index: int):
        """Zip path of batch `index` for the current epoch; needs `start_epoch` first."""
        assert self._current_epoch is not None
        return CachedBatchWriter._get_zip_path(self.cache_folder, self._current_epoch, index)

    def _load_one(self, index: int):
        """Load the slice of batch `index` that belongs to the current rank.

        Returns:
            A list with one element per saved field (tensor fields stacked,
            other fields kept as tuples), or `None` when the zip file does
            not exist and `index >= min_length`.
        Raises:
            RuntimeError: If a batch below `min_length` is missing, or the zip
                holds fewer items than `batch_size * world_size`.
        """
        zip_path = self._zip_path(index)
        if not zip_path.exists():
            if index < self.min_length:
                raise RuntimeError(f"Cache should have at least {self.min_length} batches, but {index} doesn't exist")
            return None
        if sys.version_info >= (3, 9):
            mode = "rb"
        else:
            mode = "r"
        try:
            with zipfile.ZipFile(zip_path, 'r') as zf:
                rank = flashy.distrib.rank()
                world_size = flashy.distrib.world_size()
                items = list(zipfile.Path(zf).iterdir())
                total_batch_size = self.batch_size * world_size
                if len(items) < total_batch_size:
                    raise RuntimeError(
                        f"The cache can handle a max batch size of {len(items)}, "
                        f"but {total_batch_size} is needed.")
                # Each rank reads its own contiguous window of items.
                start = rank * self.batch_size
                items = items[start: start + self.batch_size]
                assert len(items) == self.batch_size
                entries = [torch.load(item.open(mode), 'cpu') for item in items]  # type: ignore
                # Transpose from per-item tuples to per-field tuples.
                out = []
                for part in zip(*entries):
                    assert len(part) > 0
                    out.append(torch.stack(part) if isinstance(part[0], torch.Tensor) else part)
                return out
        except Exception:
            logger.error("Error when reading zip path %s", zip_path)
            raise

    def __iter__(self):
        """This will yields tuples, exactly as provided to the
        `CachedBatchWriter.save` method.
        """
        executor = ThreadPoolExecutor(self.num_workers)
        pending = deque()
        cursor = 0

        def _schedule_next():
            # Queue the load of the next batch index on the pool.
            nonlocal cursor
            pending.append(executor.submit(self._load_one, cursor))
            cursor += 1

        with executor:
            # fill the buffer of fetching jobs.
            for _ in range(2 * self.num_workers):
                _schedule_next()
            while True:
                batch = pending.popleft().result()
                if batch is None:
                    # First missing batch marks the end of the epoch.
                    return
                # Keep the buffer full before handing out the batch.
                _schedule_next()
                yield batch
================================================
FILE: audiocraft/utils/checkpoint.py
================================================
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
from enum import Enum
import logging
from pathlib import Path
import re
import typing as tp
import flashy
import torch
from ..environment import AudioCraftEnvironment
logger = logging.getLogger(__name__)
class CheckpointSource(Enum):
    # Enumeration with three string-valued members. Presumably identifies
    # where a checkpoint originates from; confirm against the call sites,
    # which are not visible in this module.
    CURRENT_XP = "current_xp"
    PRETRAINED = "pretrained"
    OTHER = "other"
def checkpoint_name(name: tp.Optional[str] = None, rank: tp.Optional[int] = None, use_fsdp: bool = False) -> str:
    """Build the checkpoint file name used across the AudioCraft codebase.

    The result is `checkpoint`, followed by `_` and the name when a name is
    given, followed by `.th`, followed by `.` and the rank when FSDP is used
    on a non-zero rank. By convention, the name is left empty for the last
    checkpoint, and is 'best' for the best checkpoint or the epoch number.

    Args:
        name (str, optional): Name suffix for the checkpoint file stem.
        rank (optional, int): Rank for distributed processing, retrieved with flashy if not provided.
        use_fsdp (bool): Whether the calling solver relies on FSDP.
    Returns:
        str: The checkpoint name.
    """
    if rank is None:
        rank = flashy.distrib.rank()
    # Only FSDP shards of ranks other than zero carry a rank suffix.
    rank_suffix = '.' + str(rank) if (rank > 0 and use_fsdp) else ''
    stem_suffix = '' if name is None else f'_{name}'
    return f'checkpoint{stem_suffix}.th{rank_suffix}'
def is_sharded_checkpoint(path: Path) -> bool:
    """Tell whether the file name at `path` ends with `.th.` followed by digits,
    i.e. corresponds to a checkpoint sharded across ranks."""
    shard_match = re.search(r'\.th\.\d+$', path.name)
    return shard_match is not None
def resolve_checkpoint_path(sig_or_path: tp.Union[Path, str], name: tp.Optional[str] = None,
                            use_fsdp: bool = False) -> tp.Optional[Path]:
    """Resolve a given checkpoint path for a provided dora sig or path.

    A value starting with `//sig/` is read as a Dora signature and mapped to
    the matching folder under the `xps` directory of the Dora dir. Any other
    value is treated as a path and passed through
    `AudioCraftEnvironment.resolve_reference_path`. When the result is a
    directory, the checkpoint file name is appended to it.

    Args:
        sig_or_path (Path or str): Checkpoint path or dora signature.
        name (str, optional): Name suffix for the checkpoint file stem.
        use_fsdp (bool): Whether the calling solver relies on FSDP.
    Returns:
        Path, optional: Resolved checkpoint path, if it exists.
    """
    from audiocraft import train
    sig_prefix = '//sig/'
    xps_root = train.main.dora.dir / 'xps'
    reference = str(sig_or_path)
    if reference.startswith(sig_prefix):
        path = xps_root / reference[len(sig_prefix):]
    else:
        path = AudioCraftEnvironment.resolve_reference_path(Path(reference))

    if path.is_dir():
        path = path / checkpoint_name(name, use_fsdp=use_fsdp)

    return path if path.exists() else None
def load_checkpoint(checkpoint_path: Path, is_sharded: bool = False) -> tp.Any:
    """Load state from checkpoints at the specified checkpoint path.

    For a sharded checkpoint, the shard is first checked against the rank
    zero checkpoint of the same folder when that file exists.

    Args:
        checkpoint_path (Path): File to load.
        is_sharded (bool): Whether the checkpoint is sharded across ranks.
    Returns:
        The object stored in the file, loaded on CPU.
    """
    if is_sharded:
        rank0_name = checkpoint_name(use_fsdp=False)
        rank0_checkpoint_path = checkpoint_path.parent / rank0_name
        if rank0_checkpoint_path.exists():
            check_sharded_checkpoint(checkpoint_path, rank0_checkpoint_path)
    loaded = torch.load(checkpoint_path, 'cpu')
    logger.info("Checkpoint loaded from %s", checkpoint_path)
    return loaded
def save_checkpoint(state: tp.Any, checkpoint_path: Path, is_sharded: bool = False) -> None:
    """Save state to disk to the specified checkpoint_path.

    The write is delegated to `_safe_save_checkpoint`; the destination is
    logged once it returns.
    """
    _safe_save_checkpoint(state, checkpoint_path, is_sharded=is_sharded)
    logger.info("Checkpoint saved to %s", checkpoint_path)
def flush_stale_checkpoints(checkpoint_path: Path, keep_last: tp.Optional[int] = None) -> None:
    """Flush checkpoints to only keep last N checkpoints.

    Only files of the folder of `checkpoint_path` named `checkpoint_` followed
    by a numeric epoch are considered; the ones with the lowest epochs are
    removed. Ranks above zero only look at files carrying their own rank suffix.

    Args:
        checkpoint_path (Path): Any checkpoint path inside the folder to clean.
        keep_last (int, optional): Number of epoch checkpoints to keep. Nothing
            is removed when it is `None` or not positive.
    """
    if keep_last is None or keep_last <= 0:
        return
    rank = flashy.distrib.rank()
    suffix = f'.{rank}' if rank > 0 else ''
    pattern = f'checkpoint_*.th{suffix}'

    by_epoch = []
    for candidate in Path(checkpoint_path.parent).glob(pattern):
        # 'checkpoint_12.th' -> 'checkpoint_12' -> '12'
        stem = candidate.name.split('.', 1)[0]
        epoch_part = stem.split('_', 1)[1]
        if epoch_part.isdigit():
            by_epoch.append((candidate, int(epoch_part)))
    by_epoch.sort(key=lambda pair: pair[1])

    checkpoint_files = [candidate for candidate, _ in by_epoch]
    total_to_flush = max(0, len(checkpoint_files) - keep_last)
    for stale in checkpoint_files[:total_to_flush]:
        logger.debug("Removing checkpoint: %s", str(stale))
        stale.unlink(missing_ok=True)
def check_sharded_checkpoint(checkpoint_path: Path, rank0_checkpoint_path: Path) -> None:
    """Check sharded checkpoint state, ensuring the checkpoints are not corrupted.

    If the rank zero `.tmp.done` token exists and this shard still has a
    `.tmp` file, the temporary file is renamed to the final checkpoint path.
    After a barrier across ranks, rank zero removes the token.

    Args:
        checkpoint_path (Path): Checkpoint file of the current rank.
        rank0_checkpoint_path (Path): Checkpoint file of rank zero.
    Raises:
        RuntimeError: If a `.old` file exists next to `checkpoint_path`.
    """
    legacy_path = Path(f'{checkpoint_path}.old')
    if legacy_path.exists():
        raise RuntimeError(
            f"Old checkpoint {legacy_path} from previous version of this code exist, cannot safely proceed.")
    done_token = Path(f'{rank0_checkpoint_path}.tmp.done')
    pending_path = Path(f'{checkpoint_path}.tmp')
    # Finish the work of a previous run that got interrupted while dumping.
    if done_token.exists() and pending_path.exists():
        pending_path.rename(checkpoint_path)
    flashy.distrib.barrier()
    if flashy.distrib.is_rank_zero() and done_token.exists():
        done_token.unlink()
def _safe_save_checkpoint(state: tp.Any, checkpoint_path: Path, is_sharded: bool = False) -> None:
    """Save checkpoints in a safe manner even with when sharded checkpoints across nodes.

    Rank zero maintains a `<checkpoint_path>.tmp.done` token file around the
    write; `check_sharded_checkpoint` looks for that token to finish an
    interrupted dump. Barriers are only issued when `is_sharded` is True.

    Args:
        state: Object passed to `torch.save`.
        checkpoint_path (Path): Destination file for the current rank.
        is_sharded (bool): Whether every rank writes its own shard.
    """
    def _barrier_if_sharded():
        # Synchronize the ranks only for sharded checkpoints.
        if is_sharded:
            flashy.distrib.barrier()

    if flashy.distrib.is_rank_zero():
        # Remove a token possibly left over by a previous save.
        token = Path(str(checkpoint_path) + '.tmp.done')
        if token.exists():
            token.unlink()
    _barrier_if_sharded()
    # NOTE(review): `write_and_rename` presumably writes to a temporary file and
    # renames it when the block exits -- confirm against flashy.
    with flashy.utils.write_and_rename(checkpoint_path) as f:
        torch.save(state, f)
        _barrier_if_sharded()
        # `token` is only bound on rank zero, matching the guard above.
        if flashy.distrib.is_rank_zero():
            token.touch()
        _barrier_if_sharded()
    _barrier_if_sharded()
    if flashy.distrib.rank() == 0:
        # The save went through on this code path: drop the token.
        token.unlink()
================================================
FILE: audiocraft/utils/cluster.py
================================================
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
"""
Utility functions for SLURM configuration and cluster settings.
"""
from enum import Enum
import os
import socket
import typing as tp
import omegaconf
class ClusterType(Enum):
    # Kinds of cluster distinguished by this module when adapting SLURM settings.
    AWS = "aws"
    FAIR = "fair"
    RSC = "rsc"
    LOCAL_DARWIN = "darwin"
    DEFAULT = "default"  # used for any other cluster.
def _guess_cluster_type() -> ClusterType:
    """Infer the cluster type from `os.uname()` and the fully qualified domain name.

    Checks are applied in order: AWS (Linux with a release ending in `-aws`
    or `.ec2` in the FQDN), FAIR (FQDN ending in `.fair`), RSC (FQDN ending
    in `.facebook.com`), local macOS (`Darwin`), and the default otherwise.
    """
    uname = os.uname()
    fqdn = socket.getfqdn()
    on_linux = uname.sysname == "Linux"
    looks_like_aws = uname.release.endswith("-aws") or ".ec2" in fqdn

    if on_linux and looks_like_aws:
        return ClusterType.AWS
    if fqdn.endswith(".fair"):
        return ClusterType.FAIR
    if fqdn.endswith(".facebook.com"):
        return ClusterType.RSC
    if uname.sysname == "Darwin":
        return ClusterType.LOCAL_DARWIN
    return ClusterType.DEFAULT
def get_cluster_type(
    cluster_type: tp.Optional[ClusterType] = None,
) -> tp.Optional[ClusterType]:
    """Return `cluster_type` when given, otherwise the automatically guessed one."""
    return _guess_cluster_type() if cluster_type is None else cluster_type
def get_slurm_parameters(
    cfg: omegaconf.DictConfig, cluster_type: tp.Optional[ClusterType] = None
) -> omegaconf.DictConfig:
    """Update SLURM parameters in configuration based on cluster type.

    When the cluster type is not specified, it is inferred automatically.
    AWS and RSC clusters get `mem_per_gpu`, `constraint` and `setup` reset
    (RSC also gets the `learn` partition). When the AudioCraft environment
    provides a SLURM exclude value, it is stored under `exclude`.

    Args:
        cfg (omegaconf.DictConfig): SLURM configuration, modified in place.
        cluster_type (ClusterType, optional): Cluster type to use instead of guessing.
    Returns:
        omegaconf.DictConfig: The same `cfg` object.
    """
    from ..environment import AudioCraftEnvironment

    cluster_type = get_cluster_type(cluster_type)
    # apply cluster-specific adjustments
    adjustments_by_cluster = {
        ClusterType.AWS: {
            "mem_per_gpu": None,
            "constraint": None,
            "setup": [],
        },
        ClusterType.RSC: {
            "mem_per_gpu": None,
            "setup": [],
            "constraint": None,
            "partition": "learn",
        },
    }
    for option, value in adjustments_by_cluster.get(cluster_type, {}).items():
        cfg[option] = value

    slurm_exclude = AudioCraftEnvironment.get_slurm_exclude()
    if slurm_exclude is not None:
        cfg["exclude"] = slurm_exclude
    return cfg
================================================
FILE: audiocraft/utils/deadlock.py
================================================
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import logging
import os
from queue import Queue, Empty
import signal
import sys
import threading
import traceback
logger = logging.getLogger(__name__)
class DeadlockDetect:
    """Watchdog that kills the process when no stage update arrives in time.

    Used as a context manager. While active, a background thread waits for
    stage names pushed with `update`. If none arrives within `timeout`
    seconds, the stack of every thread is printed to stderr and the process
    sends itself SIGKILL.

    Args:
        use (bool): When False, every method is a no-op.
        timeout (float): Maximum wait, in seconds, between two updates.
    """
    def __init__(self, use: bool = False, timeout: float = 120.):
        self.use = use
        self.timeout = timeout
        self._queue: Queue = Queue()

    def update(self, stage: str):
        """Report that `stage` was reached, resetting the timeout."""
        if not self.use:
            return
        self._queue.put(stage)

    def __enter__(self):
        if not self.use:
            return
        self._thread = threading.Thread(target=self._detector_thread)
        self._thread.start()

    def __exit__(self, exc_type, exc_val, exc_tb):
        if not self.use:
            return
        # `None` asks the detector thread to stop.
        self._queue.put(None)
        self._thread.join()

    def _detector_thread(self):
        """Body of the watchdog thread."""
        logger.debug("Deadlock detector started")
        last_stage = "init"
        timed_out = False
        while not timed_out:
            try:
                stage = self._queue.get(timeout=self.timeout)
            except Empty:
                timed_out = True
                continue
            if stage is None:
                logger.debug("Exiting deadlock detector thread")
                return
            last_stage = stage

        logger.error("Deadlock detector timed out, last stage was %s", last_stage)
        # Dump where every thread currently is before killing the process.
        for thread in threading.enumerate():
            print(thread, file=sys.stderr)
            traceback.print_stack(sys._current_frames()[thread.ident])
            print(file=sys.stderr)
        sys.stdout.flush()
        sys.stderr.flush()
        os.kill(os.getpid(), signal.SIGKILL)
================================================
FILE: audiocraft/utils/export.py
================================================
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
"""
Utility to export a training checkpoint to a lightweight release checkpoint.
"""
from pathlib import Path
import typing as tp
from omegaconf import OmegaConf
import torch
from audiocraft import __version__
def export_encodec(checkpoint_path: tp.Union[Path, str], out_file: tp.Union[Path, str]):
    """Export only the best state from the given EnCodec checkpoint. Meant for
    EnCodec models you trained yourself.

    Args:
        checkpoint_path (Path or str): Training checkpoint to read (loaded on CPU).
        out_file (Path or str): Destination file; missing parent folders are created.
    Returns:
        The `out_file` argument, unchanged.
    """
    checkpoint = torch.load(checkpoint_path, 'cpu')
    model_state = checkpoint['best_state']['model']
    cfg_yaml = OmegaConf.to_yaml(checkpoint['xp.cfg'])
    exported = {
        'best_state': model_state,
        'xp.cfg': cfg_yaml,
        'version': __version__,
        'exported': True,
    }
    destination_dir = Path(out_file).parent
    destination_dir.mkdir(exist_ok=True, parents=True)
    torch.save(exported, out_file)
    return out_file
def export_pretrained_compression_model(pretrained_encodec: str, out_file: tp.Union[Path, str]):
    """Export a compression model (potentially EnCodec) from a pretrained model.
    This is required for packaging the audio tokenizer along a MusicGen or AudioGen model.
    Do not include the //pretrained/ prefix. For instance if you trained a model
    with `facebook/encodec_32khz`, just put that as a name. Same for `dac_44khz`.

    In that case, this will not actually include a copy of the model, simply the reference
    to the model used.

    Args:
        pretrained_encodec (str): Either the path of an existing, already exported
            checkpoint file, which is copied over as is, or the name of a
            pretrained model, for which only a reference is stored.
        out_file (Path or str): Destination file; missing parent folders are created.
    Raises:
        AssertionError: If an existing file lacks one of the `best_state`,
            `xp.cfg`, `version` or `exported` entries.
    """
    if Path(pretrained_encodec).exists():
        # Load on CPU, like the other export helpers, so that a file saved from
        # a GPU can be repackaged on a machine without one.
        pkg = torch.load(pretrained_encodec, 'cpu')
        assert 'best_state' in pkg
        assert 'xp.cfg' in pkg
        assert 'version' in pkg
        assert 'exported' in pkg
    else:
        pkg = {
            'pretrained': pretrained_encodec,
            'exported': True,
            'version': __version__,
        }
    Path(out_file).parent.mkdir(exist_ok=True, parents=True)
    torch.save(pkg, out_file)
def export_lm(checkpoint_path: tp.Union[Path, str], out_file: tp.Union[Path, str]):
    """Export only the best state from the given MusicGen or AudioGen checkpoint.

    The model weights come from `fsdp_best_state` when that entry is
    non-empty, and from `best_state` otherwise.

    Args:
        checkpoint_path (Path or str): Training checkpoint to read (loaded on CPU).
        out_file (Path or str): Destination file; missing parent folders are created.
    Returns:
        The `out_file` argument, unchanged.
    """
    checkpoint = torch.load(checkpoint_path, 'cpu')
    fsdp_state = checkpoint['fsdp_best_state']
    if fsdp_state:
        best_state = fsdp_state['model']
    else:
        regular_state = checkpoint['best_state']
        assert regular_state
        best_state = regular_state['model']
    exported = {
        'best_state': best_state,
        'xp.cfg': OmegaConf.to_yaml(checkpoint['xp.cfg']),
        'version': __version__,
        'exported': True,
    }
    destination_dir = Path(out_file).parent
    destination_dir.mkdir(exist_ok=True, parents=True)
    torch.save(exported, out_file)
    return out_file
================================================
FILE: audiocraft/utils/export_legacy.py
================================================
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
"""
Legacy functions used at the time of the first release, kept for reference.
"""
from pathlib import Path
import typing as tp
from omegaconf import OmegaConf, DictConfig
import torch
def _clean_lm_cfg(cfg: DictConfig):
    """Patch a legacy LM config in place and return it.

    Forces `transformer_lm.card` to 2048 and `transformer_lm.n_q` to 4, and
    deletes experimental `transformer_lm` parameters that are no longer
    supported. Struct mode is disabled during the edit and enabled afterwards.
    """
    OmegaConf.set_struct(cfg, False)
    lm_cfg = cfg['transformer_lm']
    # These used to be set automatically in the LM solver; a more robust
    # solution is needed for the future.
    lm_cfg['card'] = 2048
    lm_cfg['n_q'] = 4
    # Experimental params no longer supported.
    unsupported = (
        'spectral_norm_attn_iters',
        'spectral_norm_ff_iters',
        'residual_balancer_attn',
        'residual_balancer_ff',
        'layer_drop',
    )
    for param in unsupported:
        del cfg['transformer_lm'][param]
    OmegaConf.set_struct(cfg, True)
    return cfg
def export_encodec(checkpoint_path: tp.Union[Path, str], out_folder: tp.Union[Path, str]):
    """Export the EMA model state of a legacy EnCodec checkpoint to `<out_folder>/<sig>.th`.

    The signature is the name of the checkpoint's parent folder and must be
    8 characters long. Returns the path of the written file.
    """
    signature = Path(checkpoint_path).parent.name
    assert len(signature) == 8, "Not a valid Dora signature"
    checkpoint = torch.load(checkpoint_path, 'cpu')
    ema_model_state = checkpoint['ema']['state']['model']
    config_yaml = OmegaConf.to_yaml(checkpoint['xp.cfg'])
    exported = {
        'best_state': ema_model_state,
        'xp.cfg': config_yaml,
    }
    destination = Path(out_folder) / f'{signature}.th'
    torch.save(exported, destination)
    return destination
def export_lm(checkpoint_path: tp.Union[Path, str], out_folder: tp.Union[Path, str]):
    """Export the FSDP best state of a legacy LM checkpoint to `<out_folder>/<sig>.th`.

    The signature is the name of the checkpoint's parent folder and must be
    8 characters long. The stored config is cleaned with `_clean_lm_cfg`.
    Returns the path of the written file.
    """
    signature = Path(checkpoint_path).parent.name
    assert len(signature) == 8, "Not a valid Dora signature"
    checkpoint = torch.load(checkpoint_path, 'cpu')
    model_state = checkpoint['fsdp_best_state']['model']
    config_yaml = OmegaConf.to_yaml(_clean_lm_cfg(checkpoint['xp.cfg']))
    exported = {
        'best_state': model_state,
        'xp.cfg': config_yaml,
    }
    destination = Path(out_folder) / f'{signature}.th'
    torch.save(exported, destination)
    return destination
================================================
FILE: audiocraft/utils/notebook.py
================================================
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
try:
import IPython.display as ipd # type: ignore
except ImportError:
# Note in a notebook...
pass
import torch
def display_audio(samples: torch.Tensor, sample_rate: int):
    """Renders an audio player for the given audio samples.

    Args:
        samples (torch.Tensor): a Tensor of decoded audio samples
            with shapes [B, C, T] or [C, T]
        sample_rate (int): sample rate audio should be displayed with.
    """
    n_dims = samples.dim()
    assert n_dims == 2 or n_dims == 3
    batch = samples.detach().cpu()
    if batch.dim() == 2:
        # Promote a single [C, T] clip to a batch of one.
        batch = batch.unsqueeze(0)
    for clip in batch:
        player = ipd.Audio(clip, rate=sample_rate)
        ipd.display(player)
================================================
FILE: audiocraft/utils/profiler.py
================================================
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import logging
import typing as tp
import dora
import torch
logger = logging.getLogger(__name__)
class Profiler:
    """Context manager wrapper for xformers profiler.

    When `enabled` is False every method is a no-op and `self.profiler` stays None.
    """
    def __init__(self, module: torch.nn.Module, enabled: bool = False):
        self.profiler: tp.Optional[tp.Any] = None
        if not enabled:
            return
        # Imported lazily so xformers is only required when profiling is requested.
        from xformers.profiler import profile
        output_dir = dora.get_xp().folder / 'profiler_data'
        logger.info("Profiling activated, results with be saved to %s", output_dir)
        self.profiler = profile(output_dir=output_dir, module=module)

    def step(self):
        """Forward a step notification to the wrapped profiler, if any."""
        if self.profiler is None:
            return
        self.profiler.step()  # type: ignore

    def __enter__(self):
        if self.profiler is None:
            return None
        return self.profiler.__enter__()  # type: ignore

    def __exit__(self, exc_type, exc_value, exc_tb):
        if self.profiler is None:
            return None
        return self.profiler.__exit__(exc_type, exc_value, exc_tb)  # type: ignore
================================================
FILE: audiocraft/utils/samples/__init__.py
================================================
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
================================================
FILE: audiocraft/utils/samples/manager.py
================================================
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
"""
API that can manage the storage and retrieval of generated samples produced by experiments.
It offers the following benefits:
* Samples are stored in a consistent way across epoch
* Metadata about the samples can be stored and retrieved
* Can retrieve audio
* Identifiers are reliable and deterministic for prompted and conditioned samples
* Can request the samples for multiple XPs, grouped by sample identifier
* For no-input samples (not prompt and no conditions), samples across XPs are matched
by sorting their identifiers
"""
from concurrent.futures import ThreadPoolExecutor
from dataclasses import asdict, dataclass
from functools import lru_cache
import hashlib
import json
import logging
from pathlib import Path
import re
import typing as tp
import unicodedata
import uuid
import dora
import torch
from ...data.audio import audio_read, audio_write
logger = logging.getLogger(__name__)
@dataclass
class ReferenceSample:
    """Metadata record for a stored prompt or ground-truth audio file."""
    # Identifier of the reference audio.
    id: str
    # Location of the stored audio file, as a string.
    path: str
    # Duration of the audio (SampleManager.add_sample computes it as samples / sample_rate).
    duration: float
@dataclass
class Sample:
    """Metadata record for one generated audio sample and its optional prompt/reference."""
    id: str
    path: str
    epoch: int
    duration: float
    conditioning: tp.Optional[tp.Dict[str, tp.Any]]
    prompt: tp.Optional[ReferenceSample]
    reference: tp.Optional[ReferenceSample]
    generation_args: tp.Optional[tp.Dict[str, tp.Any]]

    def __hash__(self):
        # Hash on the id only, which lets samples be stored in sets.
        return hash(self.id)

    def audio(self) -> tp.Tuple[torch.Tensor, int]:
        """Read the generated audio from `self.path`."""
        return audio_read(self.path)

    def audio_prompt(self) -> tp.Optional[tp.Tuple[torch.Tensor, int]]:
        """Read the prompt audio, or return None when the sample has no prompt."""
        if self.prompt is None:
            return None
        return audio_read(self.prompt.path)

    def audio_reference(self) -> tp.Optional[tp.Tuple[torch.Tensor, int]]:
        """Read the reference audio, or return None when the sample has no reference."""
        if self.reference is None:
            return None
        return audio_read(self.reference.path)
class SampleManager:
    """Audio samples IO handling within a given dora xp.

    The sample manager handles the dumping and loading logic for generated and
    references samples across epochs for a given xp, providing a simple API to
    store, retrieve and compare audio samples.

    Args:
        xp (dora.XP): Dora experiment object. The XP contains information on the XP folder
            where all outputs are stored and the configuration of the experiment,
            which is useful to retrieve audio-related parameters.
        map_reference_to_sample_id (bool): Whether to use the sample_id for all reference samples
            instead of generating a dedicated hash id. This is useful to allow easier comparison
            with ground truth sample from the files directly without having to read the JSON metadata
            to do the mapping (at the cost of potentially dumping duplicate prompts/references
            depending on the task).
    """
    def __init__(self, xp: dora.XP, map_reference_to_sample_id: bool = False):
        self.xp = xp
        self.base_folder: Path = xp.folder / xp.cfg.generate.path
        self.reference_folder = self.base_folder / 'reference'
        self.map_reference_to_sample_id = map_reference_to_sample_id
        self.samples: tp.List[Sample] = []
        self._load_samples()

    @property
    def latest_epoch(self):
        """Latest epoch across all samples (0 when there are no samples)."""
        return max(self.samples, key=lambda x: x.epoch).epoch if self.samples else 0

    def _load_samples(self):
        """Scan the sample folder and load existing samples."""
        jsons = self.base_folder.glob('**/*.json')
        with ThreadPoolExecutor(6) as pool:
            self.samples = list(pool.map(self._load_sample, jsons))

    @staticmethod
    @lru_cache(2**26)
    def _load_sample(json_file: Path) -> Sample:
        """Build a Sample from its JSON metadata file (results are cached per path)."""
        with open(json_file, 'r') as f:
            data: tp.Dict[str, tp.Any] = json.load(f)
        # fetch prompt data
        prompt_data = data.get('prompt')
        prompt = ReferenceSample(id=prompt_data['id'], path=prompt_data['path'],
                                 duration=prompt_data['duration']) if prompt_data else None
        # fetch reference data
        reference_data = data.get('reference')
        reference = ReferenceSample(id=reference_data['id'], path=reference_data['path'],
                                    duration=reference_data['duration']) if reference_data else None
        # build sample object
        return Sample(id=data['id'], path=data['path'], epoch=data['epoch'], duration=data['duration'],
                      prompt=prompt, conditioning=data.get('conditioning'), reference=reference,
                      generation_args=data.get('generation_args'))

    def _init_hash(self):
        return hashlib.sha1()

    def _get_tensor_id(self, tensor: torch.Tensor) -> str:
        """Return the hex digest of the hash of the tensor's raw data."""
        hash_id = self._init_hash()
        hash_id.update(tensor.numpy().data)
        return hash_id.hexdigest()

    def _get_sample_id(self, index: int, prompt_wav: tp.Optional[torch.Tensor],
                       conditions: tp.Optional[tp.Dict[str, str]]) -> str:
        """Computes an id for a sample given its input data.

        This id is deterministic if prompt and/or conditions are provided by using a sha1 hash on the input.
        Otherwise, a random id of the form "noinput_{uuid4().hex}" is returned.

        Args:
            index (int): Batch index, Helpful to differentiate samples from the same batch.
            prompt_wav (torch.Tensor): Prompt used during generation.
            conditions (dict[str, str]): Conditioning used during generation.
        """
        # For totally unconditioned generations we will just use a random UUID.
        # The function get_samples_for_xps will do a simple ordered match with a custom key.
        if prompt_wav is None and not conditions:
            return f"noinput_{uuid.uuid4().hex}"

        # Human readable portion
        hr_label = ""
        # Create a deterministic id using hashing
        hash_id = self._init_hash()
        hash_id.update(f"{index}".encode())
        if prompt_wav is not None:
            hash_id.update(prompt_wav.numpy().data)
            hr_label += "_prompted"
        else:
            hr_label += "_unprompted"
        if conditions:
            encoded_json = json.dumps(conditions, sort_keys=True).encode()
            hash_id.update(encoded_json)
            cond_str = "-".join([f"{key}={slugify(value)}"
                                 for key, value in sorted(conditions.items())])
            cond_str = cond_str[:100]  # some raw text might be too long to be a valid filename
            cond_str = cond_str if len(cond_str) > 0 else "unconditioned"
            hr_label += f"_{cond_str}"
        else:
            hr_label += "_unconditioned"

        return hash_id.hexdigest() + hr_label

    def _store_audio(self, wav: torch.Tensor, stem_path: Path, overwrite: bool = False) -> Path:
        """Stores the audio with the given stem path using the XP's configuration.

        Args:
            wav (torch.Tensor): Audio to store.
            stem_path (Path): Path in sample output directory with file stem to use.
            overwrite (bool): When False (default), skips storing an existing audio file.
        Returns:
            Path: The path at which the audio is stored.
        """
        # Any file sharing the stem counts as the stored audio, except the JSON metadata.
        existing_paths = [
            path for path in stem_path.parent.glob(stem_path.stem + '.*')
            if path.suffix != '.json'
        ]
        exists = len(existing_paths) > 0
        if exists and overwrite:
            logger.warning(f"Overwriting existing audio file with stem path {stem_path}")
        elif exists:
            return existing_paths[0]

        audio_path = audio_write(stem_path, wav, **self.xp.cfg.generate.audio)
        return audio_path

    def add_sample(self, sample_wav: torch.Tensor, epoch: int, index: int = 0,
                   conditions: tp.Optional[tp.Dict[str, str]] = None, prompt_wav: tp.Optional[torch.Tensor] = None,
                   ground_truth_wav: tp.Optional[torch.Tensor] = None,
                   generation_args: tp.Optional[tp.Dict[str, tp.Any]] = None) -> Sample:
        """Adds a single sample.

        The sample is stored in the XP's sample output directory, under a corresponding epoch folder.
        Each sample is assigned an id which is computed using the input data. In addition to the
        sample itself, a json file containing associated metadata is stored next to it.

        Args:
            sample_wav (torch.Tensor): sample audio to store. Tensor of shape [channels, shape].
            epoch (int): current training epoch.
            index (int): helpful to differentiate samples from the same batch.
            conditions (dict[str, str], optional): conditioning used during generation.
            prompt_wav (torch.Tensor, optional): prompt used during generation. Tensor of shape [channels, shape].
            ground_truth_wav (torch.Tensor, optional): reference audio where prompt was extracted from.
                Tensor of shape [channels, shape].
            generation_args (dict[str, any], optional): dictionary of other arguments used during generation.
        Returns:
            Sample: The saved sample.
        """
        sample_id = self._get_sample_id(index, prompt_wav, conditions)
        reuse_id = self.map_reference_to_sample_id
        prompt, ground_truth = None, None
        if prompt_wav is not None:
            prompt_id = sample_id if reuse_id else self._get_tensor_id(prompt_wav.sum(0, keepdim=True))
            prompt_duration = prompt_wav.shape[-1] / self.xp.cfg.sample_rate
            prompt_path = self._store_audio(prompt_wav, self.base_folder / str(epoch) / 'prompt' / prompt_id)
            prompt = ReferenceSample(prompt_id, str(prompt_path), prompt_duration)
        if ground_truth_wav is not None:
            ground_truth_id = sample_id if reuse_id else self._get_tensor_id(ground_truth_wav.sum(0, keepdim=True))
            ground_truth_duration = ground_truth_wav.shape[-1] / self.xp.cfg.sample_rate
            # References are shared across epochs, hence no epoch sub-folder here.
            ground_truth_path = self._store_audio(ground_truth_wav, self.reference_folder / ground_truth_id)
            ground_truth = ReferenceSample(ground_truth_id, str(ground_truth_path), ground_truth_duration)
        sample_path = self._store_audio(sample_wav, self.base_folder / str(epoch) / sample_id, overwrite=True)
        duration = sample_wav.shape[-1] / self.xp.cfg.sample_rate
        sample = Sample(sample_id, str(sample_path), epoch, duration, conditions, prompt, ground_truth, generation_args)
        self.samples.append(sample)
        with open(sample_path.with_suffix('.json'), 'w') as f:
            json.dump(asdict(sample), f, indent=2)
        return sample

    def add_samples(self, samples_wavs: torch.Tensor, epoch: int,
                    conditioning: tp.Optional[tp.List[tp.Dict[str, tp.Any]]] = None,
                    prompt_wavs: tp.Optional[torch.Tensor] = None,
                    ground_truth_wavs: tp.Optional[torch.Tensor] = None,
                    generation_args: tp.Optional[tp.Dict[str, tp.Any]] = None) -> tp.List[Sample]:
        """Adds a batch of samples.

        The samples are stored in the XP's sample output directory, under a corresponding
        epoch folder. Each sample is assigned an id which is computed using the input data and their batch index.
        In addition to the sample itself, a json file containing associated metadata is stored next to it.

        Args:
            samples_wavs (torch.Tensor): Batch of audio wavs to store. Tensor of shape [batch_size, channels, shape].
            epoch (int): Current training epoch.
            conditioning (list of dict[str, str], optional): List of conditions used during generation,
                one per sample in the batch.
            prompt_wavs (torch.Tensor, optional): Prompts used during generation. Tensor of shape
                [batch_size, channels, shape].
            ground_truth_wavs (torch.Tensor, optional): Reference audio where prompts were extracted from.
                Tensor of shape [batch_size, channels, shape].
            generation_args (dict[str, Any], optional): Dictionary of other arguments used during generation.
        Returns:
            samples (list of Sample): The saved audio samples with prompts, ground truth and metadata.
        """
        samples = []
        for idx, wav in enumerate(samples_wavs):
            prompt_wav = prompt_wavs[idx] if prompt_wavs is not None else None
            gt_wav = ground_truth_wavs[idx] if ground_truth_wavs is not None else None
            conditions = conditioning[idx] if conditioning is not None else None
            samples.append(self.add_sample(wav, epoch, idx, conditions, prompt_wav, gt_wav, generation_args))
        return samples

    def get_samples(self, epoch: int = -1, max_epoch: int = -1, exclude_prompted: bool = False,
                    exclude_unprompted: bool = False, exclude_conditioned: bool = False,
                    exclude_unconditioned: bool = False) -> tp.Set[Sample]:
        """Returns a set of samples for this XP. Optionally, you can filter which samples to obtain.

        Please note that existing samples are loaded during the manager's initialization, and added samples through this
        manager are also tracked. Any other external changes are not tracked automatically, so creating a new manager
        is the only way to detect them.

        Args:
            epoch (int): If provided, only return samples corresponding to this epoch.
            max_epoch (int): If provided, only return samples corresponding to the latest epoch that is <= max_epoch.
                Takes precedence over `epoch`. When no sample is at or below `max_epoch`, an empty set is returned.
            exclude_prompted (bool): If True, does not include samples that used a prompt.
            exclude_unprompted (bool): If True, does not include samples that did not use a prompt.
            exclude_conditioned (bool): If True, excludes samples that used conditioning.
            exclude_unconditioned (bool): If True, excludes samples that did not use conditioning.
        Returns:
            Samples (set of Sample): The retrieved samples matching the provided filters.
        """
        if max_epoch >= 0:
            # `default` avoids a ValueError when no sample qualifies. Falling back to
            # `max_epoch` itself is safe: in that case no sample has that epoch either,
            # so the filter below yields an empty set.
            samples_epoch = max(
                (sample.epoch for sample in self.samples if sample.epoch <= max_epoch),
                default=max_epoch,
            )
        else:
            samples_epoch = self.latest_epoch if epoch < 0 else epoch
        samples = {
            sample
            for sample in self.samples
            if (
                (sample.epoch == samples_epoch) and
                (not exclude_prompted or sample.prompt is None) and
                (not exclude_unprompted or sample.prompt is not None) and
                (not exclude_conditioned or not sample.conditioning) and
                (not exclude_unconditioned or sample.conditioning)
            )
        }
        return samples
def slugify(value: tp.Any, allow_unicode: bool = False):
    """Turn an arbitrary value into a string that is safer to use in file names.

    Adapted from https://github.com/django/django/blob/master/django/utils/text.py

    The value is stringified, normalized (folded to ASCII unless 'allow_unicode'
    is True) and lowercased. Anything that is not alphanumeric, an underscore,
    whitespace or a hyphen is dropped; runs of whitespace/hyphens collapse to one
    hyphen; leading and trailing hyphens and underscores are stripped.
    """
    text = str(value)
    if allow_unicode:
        text = unicodedata.normalize("NFKC", text)
    else:
        decomposed = unicodedata.normalize("NFKD", text)
        text = decomposed.encode("ascii", "ignore").decode("ascii")
    cleaned = re.sub(r"[^\w\s-]", "", text.lower())
    hyphenated = re.sub(r"[-\s]+", "-", cleaned)
    return hyphenated.strip("-_")
def _match_stable_samples(samples_per_xp: tp.List[tp.Set[Sample]]) -> tp.Dict[str, tp.List[Sample]]:
    """Group samples with a prompt or conditioning by id, keeping ids present in every XP."""
    # One lookup table (stable id -> sample) per XP.
    lookup_per_xp: tp.List[tp.Dict[str, Sample]] = []
    for xp_samples in samples_per_xp:
        lookup: tp.Dict[str, Sample] = {}
        for sample in xp_samples:
            if sample.prompt is not None or sample.conditioning:
                lookup[sample.id] = sample
        lookup_per_xp.append(lookup)

    # Every stable id seen in at least one XP.
    all_ids = {sample_id for lookup in lookup_per_xp for sample_id in lookup}

    matched: tp.Dict[str, tp.List[Sample]] = {}
    for sample_id in all_ids:
        # An XP that lacks the id contributes None, which disqualifies the id.
        candidates = [lookup.get(sample_id) for lookup in lookup_per_xp]
        if None not in candidates:
            # cast is necessary to avoid mypy linter errors.
            matched[sample_id] = tp.cast(tp.List[Sample], candidates)
    return matched
def _match_unstable_samples(samples_per_xp: tp.List[tp.Set[Sample]]) -> tp.Dict[str, tp.List[Sample]]:
    """Match samples without prompt or conditioning across XPs by sorted position.

    Args:
        samples_per_xp: One set of samples per XP.
    Returns:
        Mapping of `noinput_{i}` to the i-th such sample of every XP (sorted by id).
        Empty when `samples_per_xp` is empty or any XP has no such sample.
    """
    # For unstable ids, we use a sorted list since we'll match them in order
    unstable_samples_per_xp = [[
        sample for sample in sorted(samples, key=lambda x: x.id)
        if sample.prompt is None and not sample.conditioning
    ] for samples in samples_per_xp]
    # Trim samples per xp so all samples can have a match.
    # `default=0` keeps an empty list of XPs from raising in `min`.
    min_len = min((len(samples) for samples in unstable_samples_per_xp), default=0)
    unstable_samples_per_xp = [samples[:min_len] for samples in unstable_samples_per_xp]
    # Dictionary of index -> list of matched samples
    return {
        f'noinput_{i}': [samples[i] for samples in unstable_samples_per_xp] for i in range(min_len)
    }
def get_samples_for_xps(xps: tp.List[dora.XP], **kwargs) -> tp.Dict[str, tp.List[Sample]]:
    """Gets a dictionary of matched samples across the given XPs.

    Each dictionary entry maps a sample id to a list of samples for that id. The number of samples per id
    will always match the number of XPs provided and will correspond to each XP in the same order given.
    In other words, only samples that can be matched across all provided XPs will be returned
    in order to satisfy this rule.

    There are two types of ids that can be returned: stable and unstable.
    * Stable IDs are deterministic ids that were computed by the SampleManager given a sample's inputs
      (prompts/conditioning). This is why we can match them across XPs.
    * Unstable IDs are of the form "noinput_{idx}" and are generated on-the-fly, in order to map samples
      that used non-deterministic, random ids. This is the case for samples that did not use prompts or
      conditioning for their generation. This function will sort these samples by their id and match them
      by their index.

    Args:
        xps: a list of XPs to match samples from.
        **kwargs: forwarded unchanged to `SampleManager.get_samples` for every XP, i.e.:
            epoch (int): If provided, only return samples corresponding to this epoch.
            max_epoch (int): If provided, only return samples corresponding to the latest epoch
                that is <= max_epoch.
            exclude_prompted (bool): If True, does not include samples that used a prompt.
            exclude_unprompted (bool): If True, does not include samples that did not use a prompt.
            exclude_conditioned (bool): If True, excludes samples that used conditioning.
            exclude_unconditioned (bool): If True, excludes samples that did not use conditioning.
    """
    managers = [SampleManager(xp) for xp in xps]
    samples_per_xp = [manager.get_samples(**kwargs) for manager in managers]
    stable_samples = _match_stable_samples(samples_per_xp)
    unstable_samples = _match_unstable_samples(samples_per_xp)
    return dict(stable_samples, **unstable_samples)
================================================
FILE: audiocraft/utils/ui.py
================================================
from pathlib import Path
import gradio as gr
import torch
refresh_symbol = '\U0001f504' # 🔄
class ToolButton(gr.Button, gr.components.IOComponent):
    """Small button with single emoji as text, fits inside gradio forms"""
    def __init__(self, **kwargs):
        # All keyword arguments are passed straight through to the parent constructor.
        super().__init__(**kwargs)

    def get_block_name(self):
        # Report the block name "button" rather than a name specific to this subclass.
        return "button"
def create_refresh_button(refresh_component, refresh_method, refreshed_args, elem_class):
    """Create a small button that refreshes `refresh_component` when clicked.

    Args:
        refresh_component: Component that is updated and used as the click output.
        refresh_method: Zero-argument callable invoked first on every click.
        refreshed_args: Dict of attribute updates, or a zero-argument callable returning one.
            None or an empty dict means there is nothing to update.
        elem_class: Value passed as `elem_classes` to the button.
    Returns:
        ToolButton: The created button, with its click handler attached.
    """
    def refresh():
        refresh_method()
        args = refreshed_args() if callable(refreshed_args) else refreshed_args
        # Normalise before iterating: a None result used to crash on `.items()`
        # even though the final update already tolerated it.
        args = args or {}
        for k, v in args.items():
            setattr(refresh_component, k, v)
        return gr.update(**args)

    refresh_button = ToolButton(value=refresh_symbol, elem_classes=elem_class, scale=1, size="sm", container=False)
    refresh_button.click(
        fn=refresh,
        inputs=[],
        outputs=[refresh_component]
    )
    return refresh_button
================================================
FILE: audiocraft/utils/utils.py
================================================
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
from concurrent.futures import ProcessPoolExecutor
from contextlib import contextmanager
from functools import wraps, lru_cache
import hashlib
import json
import logging
from pathlib import Path
import typing as tp
import flashy
import flashy.distrib
import omegaconf
import torch
from torch.nn.utils.rnn import pad_sequence
logger = logging.getLogger(__name__)
def model_hash(model: torch.nn.Module) -> str:
    """Return a SHA-1 hex digest computed over the raw bytes of all model parameters.

    This should allow us to track regressions in model init from the logs
    of past experiments.
    """
    digest = hashlib.sha1()
    for param in model.parameters():
        raw_bytes = param.data.cpu().numpy().tobytes()
        digest.update(raw_bytes)
    return digest.hexdigest()
def dict_from_config(cfg: omegaconf.DictConfig) -> dict:
    """Convert an omegaconf configuration into a plain dictionary, resolving interpolations.

    Args:
        cfg (omegaconf.DictConfig): Original configuration to map to dict.
    Returns:
        dict: Config as dictionary object.
    """
    container = omegaconf.OmegaConf.to_container(cfg, resolve=True)
    # `to_container` may return other container types; only a dict is accepted here.
    assert isinstance(container, dict)
    return container
def random_subset(dataset, max_samples: int, seed: int = 42) -> torch.utils.data.Subset:
    """Return a seeded random subset of at most `max_samples` items.

    The dataset itself is returned, unwrapped, when it already holds no more
    than `max_samples` items.
    """
    total = len(dataset)
    if max_samples >= total:
        return dataset

    rng = torch.Generator()
    rng.manual_seed(seed)
    shuffled = torch.randperm(total, generator=rng)
    kept_indices = shuffled[:max_samples].tolist()
    return torch.utils.data.Subset(dataset, kept_indices)
def get_loader(dataset, num_samples: tp.Optional[int], batch_size: int,
               num_workers: int, seed: int, **kwargs) -> torch.utils.data.DataLoader:
    """Wrap a dataset in a dataloader, optionally restricted to a random subset first.

    Args:
        dataset: Dataset to load.
        num_samples (Optional[int]): Number of samples to limit subset size.
        batch_size (int): Batch size.
        num_workers (int): Number of workers for data loading.
        seed (int): Random seed.
        **kwargs: Extra arguments forwarded to `flashy.distrib.loader`.
    """
    source = dataset if num_samples is None else random_subset(dataset, num_samples, seed)
    return flashy.distrib.loader(
        source,
        batch_size=batch_size,
        num_workers=num_workers,
        **kwargs
    )
def get_dataset_from_loader(dataloader):
    """Return the dataset behind a dataloader, unwrapping one level of `Subset`."""
    inner = dataloader.dataset
    is_subset = isinstance(inner, torch.utils.data.Subset)
    return inner.dataset if is_subset else inner
def multinomial(input: torch.Tensor, num_samples: int, replacement=False, *, generator=None):
    """torch.multinomial with arbitrary number of dimensions, and number of candidates on the last dimension.

    Args:
        input (torch.Tensor): The input tensor containing probabilities.
        num_samples (int): Number of samples to draw.
        replacement (bool): Whether to draw with replacement or not.
    Keywords args:
        generator (torch.Generator): A pseudorandom number generator for sampling.
    Returns:
        torch.Tensor: Last dimension contains num_samples indices
            sampled from the multinomial probability distribution
            located in the last dimension of tensor input.
    """
    leading_shape = input.shape[:-1]
    n_candidates = input.shape[-1]
    # Collapse all leading dimensions into a single batch dimension before sampling.
    flat_probs = input.reshape(-1, n_candidates)
    drawn = torch.multinomial(flat_probs, num_samples=num_samples, replacement=replacement, generator=generator)
    return drawn.reshape(*leading_shape, -1)
def sample_top_k(probs: torch.Tensor, k: int) -> torch.Tensor:
    """Sample next token from top K values along the last dimension of the input probs tensor.

    Note that `probs` is modified in place: entries below the k-th largest value
    are zeroed and the remainder is renormalized.

    Args:
        probs (torch.Tensor): Input probabilities with token candidates on the last dimension.
        k (int): The k in “top-k”.
    Returns:
        torch.Tensor: Sampled tokens.
    """
    top_values, _ = torch.topk(probs, k, dim=-1)
    # Smallest of the k largest values, kept as a trailing dim of size 1 for broadcasting.
    threshold = top_values[..., [-1]]
    keep = (probs >= threshold).float()
    probs.mul_(keep)
    probs.div_(probs.sum(dim=-1, keepdim=True))
    return multinomial(probs, num_samples=1)
def sample_top_p(probs: torch.Tensor, p: float) -> torch.Tensor:
    """Sample next token from top P probabilities along the last dimension of the input probs tensor.

    Args:
        probs (torch.Tensor): Input probabilities with token candidates on the last dimension.
        p (float): The p in “top-p”.
    Returns:
        torch.Tensor: Sampled tokens.
    """
    sorted_probs, sorted_idx = torch.sort(probs, dim=-1, descending=True)
    cumulative = torch.cumsum(sorted_probs, dim=-1)
    # Drop a candidate once the mass accumulated before it already exceeds p.
    drop = (cumulative - sorted_probs) > p
    sorted_probs.mul_((~drop).float())
    sorted_probs.div_(sorted_probs.sum(dim=-1, keepdim=True))
    picked = multinomial(sorted_probs, num_samples=1)
    # Map positions in the sorted order back to the original token indices.
    return torch.gather(sorted_idx, -1, picked)
class DummyPoolExecutor:
    """Dummy pool executor to use when we actually have only 1 worker.
    (e.g. instead of ProcessPoolExecutor).

    Submitted work is not run at submission time: it runs in the calling thread,
    each time `result()` is called on the returned object.
    """
    class DummyResult:
        """Deferred call that executes when `result()` is invoked."""
        def __init__(self, func, *args, **kwargs):
            self.func, self.args, self.kwargs = func, args, kwargs

        def result(self):
            call = self.func
            return call(*self.args, **self.kwargs)

    def __init__(self, workers, mp_context=None):
        # Arguments are accepted for signature compatibility only.
        pass

    def submit(self, func, *args, **kwargs):
        deferred = DummyPoolExecutor.DummyResult(func, *args, **kwargs)
        return deferred

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, exc_tb):
        return None
def get_pool_executor(num_workers: int, mp_context=None):
    """Return a ProcessPoolExecutor for more than one worker, else a DummyPoolExecutor."""
    if num_workers > 1:
        return ProcessPoolExecutor(num_workers, mp_context)
    return DummyPoolExecutor(1)
def length_to_mask(lengths: torch.Tensor, max_len: tp.Optional[int] = None) -> torch.Tensor:
    """Utility function to convert a tensor of sequence lengths to a mask (useful when working on padded sequences).
    For example: [3, 5] => [[1, 1, 1, 0, 0], [1, 1, 1, 1, 1]]

    Args:
        lengths (torch.Tensor): tensor with lengths
        max_len (int): can set the max length manually. Defaults to None.
    Returns:
        torch.Tensor: mask with 0s where there is pad tokens else 1s
    """
    assert len(lengths.shape) == 1, "Length shape should be 1 dimensional."
    if max_len:
        width = max_len
    else:
        width = lengths.max().item()
    # if all seqs are of len zero we don't want a zero-size tensor
    width = max(width, 1)
    positions = torch.arange(width).unsqueeze(0).to(lengths.device)
    return positions < lengths.unsqueeze(1)
def hash_trick(word: str, vocab_size: int) -> int:
    """Hash trick to pair each word with an index

    Args:
        word (str): word we wish to convert to an index
        vocab_size (int): size of the vocabulary
    Returns:
        int: index of the word in the embedding LUT
    """
    digest = hashlib.sha256(word.encode("utf-8")).digest()
    # Interpret the SHA-256 digest as one big-endian integer, then fold it into the vocabulary.
    as_integer = int.from_bytes(digest, "big")
    return as_integer % vocab_size
def with_rank_rng(base_seed: int = 1234):
    """Decorator for a function so that the function will use a Random Number Generator
    whose state depend on the GPU rank. The original RNG state is restored upon returning.

    Args:
        base_seed (int): Random seed.
    """
    def _wrap(fun: tp.Callable):
        @wraps(fun)
        def _seeded_call(*args, **kwargs):
            saved_state = torch.get_rng_state()
            # XOR with the rank gives every process its own seed derived from `base_seed`.
            rank_seed = base_seed ^ flashy.distrib.rank()
            torch.manual_seed(rank_seed)
            logger.debug('Rank dependent seed set to %d', rank_seed)
            try:
                return fun(*args, **kwargs)
            finally:
                # Restore the RNG state even when `fun` raises.
                torch.set_rng_state(saved_state)
                logger.debug('RNG state restored.')
        return _seeded_call
    return _wrap
def collate(tensors: tp.List[torch.Tensor], dim: int = 0) -> tp.Tuple[torch.Tensor, torch.Tensor]:
    """Get a list of tensors and collate them to a single tensor. according to the following logic:
    - `dim` specifies the time dimension which will be stacked and padded.
    - The output will contain 1 new dimension (dimension index 0) which will be the size of
    of the original list.

    Args:
        tensors (tp.List[torch.Tensor]): List of tensors to collate.
        dim (int): Dimension which will be stacked and padded.
    Returns:
        tp.Tuple[torch.Tensor, torch.Tensor]:
            torch.Tensor: Stacked and padded tensor. The output will contain 1 new dimension
                (dimension index 0) which will be the size of the original list.
            torch.Tensor: Tensor containing length of original tensor sizes (without padding).
    """
    # Bring the time dimension to the front so pad_sequence pads along it.
    time_first = [item.transpose(0, dim) for item in tensors]
    lengths = torch.LongTensor([len(item) for item in time_first])
    # pad_sequence yields [time, batch, ...]; make it batch-first, then move the
    # time axis back to its original position (shifted by the new batch dimension).
    batch_first = pad_sequence(time_first).transpose(0, 1)
    collated = batch_first.transpose(1, dim + 1)
    return collated, lengths
# TODO: Move to flashy?
def copy_state(state: tp.Any, device: tp.Union[torch.device, str] = 'cpu',
               dtype: tp.Optional[torch.dtype] = None) -> tp.Any:
    """Recursively copy a state (tensor, dict or list) to the given device.

    Args:
        state: A tensor, or a dict/list nesting of tensors and other values.
        device (torch.device or str): Target device for the tensor copies.
        dtype (torch.dtype, optional): Target dtype, applied to floating point tensors only;
            other tensors keep their dtype.
    Returns:
        A structure of the same shape in which every tensor reached through dicts and
        lists is a detached copy. Any other value (int, str, None, tuple, ...) is
        returned as is, rather than being replaced by None.
    """
    if isinstance(state, torch.Tensor):
        if dtype is None or not state.is_floating_point():
            dtype = state.dtype
        return state.detach().to(device=device, dtype=dtype, copy=True)
    elif isinstance(state, dict):
        return {k: copy_state(v, device, dtype) for k, v in state.items()}
    elif isinstance(state, list):
        return [copy_state(v, device, dtype) for v in state]
    # Non-tensor leaves are passed through unchanged so they are not silently lost.
    return state
# TODO: Move to flashy?
@contextmanager
def swap_state(model, state, **kwargs):
    """Temporarily load `state` into `model`, restoring the previous state on exit.

    Args:
        model: Object exposing `state_dict()` and `load_state_dict()`.
        state: State to load for the duration of the `with` block.
        **kwargs: Extra arguments forwarded to the first `load_state_dict` call.
    """
    old_state = copy_state(model.state_dict())
    try:
        # Loading happens inside the `try` so that a load that fails part-way
        # still leaves the model restored to its previous state.
        model.load_state_dict(state, **kwargs)
        yield
    finally:
        model.load_state_dict(old_state)
@lru_cache(maxsize=None)
def warn_once(logger, msg):
    """Emit `msg` as a warning on `logger`, at most once per (logger, msg) pair.

    The unbounded cache remembers every pair already seen, so repeats are skipped.
    """
    logger.warning(msg)
def is_jsonable(x: tp.Any):
    """Tell whether `json.dumps` accepts the object.

    Returns False when serialization raises TypeError or OverflowError, True otherwise.
    """
    try:
        json.dumps(x)
    except (TypeError, OverflowError):
        return False
    return True
def load_clap_state_dict(clap_model, path: tp.Union[str, Path]):
    """Wrapper around state dict loading of CLAP model
    addressing compatibility issues between CLAP and AudioCraft
    HuggingFace transformer version.
    See: https://github.com/LAION-AI/CLAP/issues/118
    """
    from clap_module.factory import load_state_dict  # type: ignore
    state = load_state_dict(path)
    # Remove this entry when present before loading (see the issue linked above).
    state.pop('text_branch.embeddings.position_ids', None)
    clap_model.model.load_state_dict(state)
================================================
FILE: config/conditioner/chroma2music.yaml
================================================
# @package __global__
classifier_free_guidance:
training_dropout: 0.2
inference_coef: 3.0
attribute_dropout:
args:
active_on_eval: false
text: {}
wav:
self_wav: 0.5
fuser:
cross_attention_pos_emb: false
cross_attention_pos_emb_scale: 1
sum: []
prepend: [self_wav, description]
cross: []
input_interpolate: []
conditioners:
self_wav:
model: chroma_stem
chroma_stem:
sample_rate: ${sample_rate}
n_chroma: 12
radix2_exp: 14
argmax: true
match_len_on_eval: false
eval_wavs: null
n_eval_wavs: 100
cache_path: null
description:
model: t5
t5:
name: t5-base
finetune: false
word_dropout: 0.2
normalize_text: false
dataset:
train:
merge_text_p: 0.25
drop_desc_p: 0.5
drop_other_p: 0.5
================================================
FILE: config/conditioner/clapemb2music.yaml
================================================
# @package __global__
classifier_free_guidance:
training_dropout: 0.3
inference_coef: 3.0
attribute_dropout:
text: {}
wav: {}
fuser:
cross_attention_pos_emb: false
cross_attention_pos_emb_scale: 1
sum: []
prepend: []
cross: [description]
input_interpolate: []
conditioners:
description:
model: clap
clap:
checkpoint: //reference/clap/music_audioset_epoch_15_esc_90.14.pt
model_arch: 'HTSAT-base'
enable_fusion: false
sample_rate: 44100
max_audio_length: 10
audio_stride: 1
dim: 512
attribute: description
normalize: true
quantize: true # use RVQ quantization
n_q: 12
bins: 1024
kmeans_iters: 50
text_p: 0. # probability of using text embed at train time
cache_path: null
dataset:
joint_embed_attributes: [description]
train:
merge_text_p: 0.25
drop_desc_p: 0.5
drop_other_p: 0.5
================================================
FILE: config/conditioner/none.yaml
================================================
# @package __global__
# No conditioning
classifier_free_guidance:
training_dropout: 0
inference_coef: 1
attribute_dropout:
text: {}
wav: {}
fuser:
sum: []
prepend: []
cross: []
input_interpolate: []
conditioners: null
================================================
FILE: config/conditioner/text2music.yaml
================================================
# @package __global__
classifier_free_guidance:
training_dropout: 0.3
inference_coef: 3.0
attribute_dropout: {}
fuser:
cross_attention_pos_emb: false
cross_attention_pos_emb_scale: 1
sum: []
prepend: []
cross: [description]
input_interpolate: []
conditioners:
description:
model: t5
t5:
name: t5-base
finetune: false
word_dropout: 0.3
normalize_text: false
dataset:
train:
merge_text_p: 0.25
drop_desc_p: 0.5
drop_other_p: 0.5
================================================
FILE: config/conditioner/text2sound.yaml
================================================
# @package __global__
classifier_free_guidance:
training_dropout: 0.1
inference_coef: 3.0
attribute_dropout: {}
fuser:
cross_attention_pos_emb: false
cross_attention_pos_emb_scale: 1
sum: []
prepend: []
cross: [description]
input_interpolate: []
conditioners:
description:
model: t5
t5:
name: t5-large
finetune: false
word_dropout: 0.
normalize_text: false
================================================
FILE: config/config.yaml
================================================
# WARNING: This is the base configuration file shared across ALL solvers in AudioCraft
# Please don't update this file directly. Instead use distinct configuration files
# to override the below configuration.
defaults:
- _self_
- dset: default
- solver: default
device: cuda
dtype: float32
autocast: false
autocast_dtype: bfloat16
seed: 2036
show: false # just show the model and its size and exit
continue_from: # continue from a given sig or path
execute_only: # can be set to generate/evaluate/valid to run that stage
execute_inplace: false # don't enforce continue_from to be set
# to enable inplace execution of the stage. This assumes
# that you know what you are doing and execute stage
# preserving the original xp sig.
benchmark_no_load: false # if set to true, will repeat the same batch instead of loading them
efficient_attention_backend: torch # can be torch or xformers.
num_threads: 1 # called with torch.set_num_threads.
mp_start_method: forkserver # multiprocessing method (spawn, fork or forkserver).
label: # use this if you want twice the same exp, with a name.
# logging parameters
logging:
level: INFO
log_updates: 10
log_tensorboard: false
log_wandb: false
tensorboard:
with_media_logging: false
name: # optional name for the experiment
sub_dir: # optional sub directory to store tensorboard data
wandb:
with_media_logging: true
project: # project name
name: # optional name for the experiment
group: # optional group
# SLURM launcher configuration.
slurm:
gpus: 4 # convenience parameter, number of GPUs to use.
mem_per_gpu: 40 # in GB, total mem is automatically scaled with `gpus`.
time: 3600
constraint:
partition:
comment:
setup: []
exclude: ''
# dora parameters
dora:
# Output folder for all artifacts of an experiment.
dir: /checkpoint/${oc.env:USER}/experiments/audiocraft/outputs
# The following entries will be ignored by dora when computing the unique XP signature.
# Note that slurm.* and dora.* are automatically ignored.
exclude: [
'device', 'wandb.*', 'tensorboard.*', 'logging.*',
'dataset.num_workers', 'eval.num_workers', 'special.*',
'metrics.visqol.bin', 'metrics.fad.bin',
'execute_only', 'execute_best', 'generate.every',
'optim.eager_sync', 'profiler.*', 'deadlock.*',
'efficient_attention_backend', 'num_threads', 'mp_start_method',
]
use_rendezvous: false
# for grids, always run from a clean repo, allowing reliable runs and storing
# the exact commit. Your repo must be absolutely pristine clean.
# Local `dora run` are not impacted for easier debugging.
git_save: true
================================================
FILE: config/dset/audio/audiocaps_16khz.yaml
================================================
# @package __global__
# AudioCaps dataset
datasource:
max_sample_rate: 16000
max_channels: 1
train: null # only evaluation set
valid: null # only evaluation set
evaluate: egs/audiocaps/audiocaps_16khz
generate: egs/audiocaps/audiocaps_16khz # identical to evaluate
================================================
FILE: config/dset/audio/default.yaml
================================================
# @package __global__
datasource:
max_sample_rate: ???
max_channels: ???
train: ???
valid: ???
evaluate: ???
generate: null
================================================
FILE: config/dset/audio/example.yaml
================================================
# @package __global__
datasource:
max_sample_rate: 44100
max_channels: 2
train: egs/example
valid: egs/example
evaluate: egs/example
generate: egs/example
================================================
FILE: config/dset/audio/musiccaps_32khz.yaml
================================================
# @package __global__
# total samples obtained from MusicCaps = 5469
# (out of 5521 due to AudioSet corrupted samples)
datasource:
max_sample_rate: 32000
max_channels: 2
train: null # only evaluation set
valid: null # only evaluation set
evaluate: egs/musiccaps/musiccaps_32khz
generate: egs/musiccaps/musiccaps_32khz # identical to evaluate
================================================
FILE: config/dset/default.yaml
================================================
# @package __global__
# WARNING: This is a base configuration file shared across ALL solvers in AudioCraft
# Please don't update this file directly. Instead use distinct configuration files
# to override the below configuration.
datasource:
train: ???
valid: ???
evaluate: ???
generate: ???
================================================
FILE: config/dset/internal/music_10k_32khz.yaml
================================================
# @package __global__
# high quality music dataset with no artist overlap between splits
datasource:
max_sample_rate: 32000
max_channels: 1
train: egs/music/music_10k_32khz/train
valid: egs/music/music_10k_32khz/valid
evaluate: egs/music/music_10k_32khz/test
generate: egs/music/music_10k_32khz/test # identical to evaluate
================================================
FILE: config/dset/internal/music_400k_32khz.yaml
================================================
# @package __global__
datasource:
max_sample_rate: 32000
max_channels: 1
train: egs/music/music_400k_32khz/train
valid: egs/music/music_400k_32khz/valid
evaluate: egs/music/music_400k_32khz/test
generate: egs/music/music_400k_32khz/test # identical to evaluate
================================================
FILE: config/dset/internal/sounds_16khz.yaml
================================================
# @package __global__
# environmental sounds dataset compiling all datasets
# with applied filters on tags
datasource:
max_sample_rate: 16000
max_channels: 1
train: egs/sound/sounds_16khz/train
valid: egs/sound/sounds_16khz/valid
evaluate: egs/sound/sounds_16khz/test
generate: egs/sound/sounds_16khz/test # identical to evaluate
================================================
FILE: config/model/encodec/default.yaml
================================================
# @package __global__
compression_model: encodec
encodec:
autoencoder: seanet
quantizer: rvq
sample_rate: ${sample_rate}
channels: ${channels}
causal: false
renormalize: false
seanet:
dimension: 128
channels: ${channels}
causal: ${encodec.causal}
n_filters: 32
n_residual_layers: 1
ratios: [8, 5, 4, 2]
activation: ELU
activation_params: {"alpha": 1.}
norm: weight_norm
norm_params: {}
kernel_size: 7
residual_kernel_size: 3
last_kernel_size: 7
dilation_base: 2
pad_mode: constant
true_skip: true
compress: 2
lstm: 2
disable_norm_outer_blocks: 0
# Specific encoder or decoder params.
# You can also override any param for the encoder or decoder only
# by using Hydra `+param=` syntax, e.g.
# `+seanet.decoder.n_filters=64`.
decoder:
trim_right_ratio: 1.0
final_activation: null
final_activation_params: null
encoder: {}
rvq:
n_q: 8
q_dropout: false
bins: 1024
decay: 0.99
kmeans_init: true
kmeans_iters: 50
threshold_ema_dead_code: 2
orthogonal_reg_weight: 0.0
orthogonal_reg_active_codes_only: false
no_quant: {}
================================================
FILE: config/model/encodec/encodec_base_causal.yaml
================================================
# @package __global__
defaults:
- encodec/default
encodec:
causal: true
rvq:
n_q: 32
q_dropout: true
================================================
FILE: config/model/encodec/encodec_large_nq4_s320.yaml
================================================
# @package __global__
defaults:
- encodec/default
seanet:
# default ratios are [8, 5, 4, 2]
n_filters: 64
rvq:
bins: 2048
n_q: 4
q_dropout: false
================================================
FILE: config/model/encodec/encodec_large_nq4_s640.yaml
================================================
# @package __global__
defaults:
- encodec/default
seanet:
ratios: [8, 5, 4, 4]
n_filters: 64
rvq:
bins: 2048
n_q: 4
q_dropout: false
================================================
FILE: config/model/lm/audiogen_lm.yaml
================================================
# @package __global__
defaults:
- lm/default
- override /conditioner: text2sound
- override /model/lm/model_scale: small # prefer this group to set model scale instead of transformer_lm keys directly
lm_model: transformer_lm
codebooks_pattern:
modeling: delay
delay:
delays: [0, 1, 2, 3]
flatten_first: 0
empty_initial: 0
unroll:
flattening: [0, 1, 2, 3]
delays: [0, 0, 0, 0]
music_lm:
group_by: 2
valle:
delays: [0, 0, 0]
transformer_lm:
n_q: 4
card: 2048
memory_efficient: true
bias_proj: false
bias_ff: false
bias_attn: false
norm_first: true
layer_scale: null
weight_init: gaussian
depthwise_init: current
zero_bias_init: true
attention_as_float32: false
================================================
FILE: config/model/lm/default.yaml
================================================
# @package __global__
defaults:
- _self_
- /model/lm/model_scale: base # prefer this group to set model scale instead of transformer_lm keys directly
lm_model: transformer_lm
codebooks_pattern:
modeling: parallel
transformer_lm:
dim: 512
num_heads: 8
num_layers: 8
hidden_scale: 4
n_q: 8 # number of streams to model
card: 1024
dropout: 0.
emb_lr: null
activation: gelu
norm_first: false # use pre-norm instead of post-norm
bias_ff: true # use bias for the feedforward
bias_attn: true # use bias for the attention
bias_proj: true # use bias for the output projections
past_context: null
causal: true
custom: false # use custom MHA implementation
memory_efficient: false # use flash attention
attention_as_float32: false # use float32 for the attention part,
# recommended at the moment when memory_efficient is True.
layer_scale: null
positional_embedding: sin # positional embedding strategy (sin, rope, or sin_rope).
xpos: false # apply xpos decay (rope only).
checkpointing: none # layer checkpointing method, can be none, torch, xformers_default.
# torch is the slowest but uses the least memory,
# xformers_default is somewhere in between.
weight_init: null # weight initialization (null, gaussian or uniform)
depthwise_init: null # perform depthwise initialization (null, current, global)
zero_bias_init: false # initialize bias to zero if bias in linears and
# if a weight_init method is used.
norm: layer_norm # normalization method to use in transformer.
cross_attention: false
qk_layer_norm: false
qk_layer_norm_cross: false
attention_dropout: null
kv_repeat: 1
two_step_cfg: false # whether to do true 2 steps CFG, potentially resolving some padding issues or not...
================================================
FILE: config/model/lm/model_scale/base.yaml
================================================
# @package __global__
# overrides nothing because default is already transformer base (~ 60M params)
================================================
FILE: config/model/lm/model_scale/large.yaml
================================================
# @package _global_
# gpt2 inspired, even bigger (~3.3B params)
transformer_lm:
dim: 2048
num_heads: 32
num_layers: 48
================================================
FILE: config/model/lm/model_scale/medium.yaml
================================================
# @package _global_
# gpt2 like (~1.5B params)
transformer_lm:
dim: 1536
num_heads: 24
num_layers: 48
================================================
FILE: config/model/lm/model_scale/small.yaml
================================================
# @package _global_
# 300M Param.
transformer_lm:
dim: 1024
num_heads: 16
num_layers: 24
================================================
FILE: config/model/lm/model_scale/xsmall.yaml
================================================
# @package _global_
# just used for debugging or when we just want to populate the cache
# and do not care about training.
transformer_lm:
dim: 64
num_heads: 2
num_layers: 2
================================================
FILE: config/model/lm/musicgen_lm.yaml
================================================
# @package __global__
defaults:
- lm/default
- override /conditioner: text2music
- override /model/lm/model_scale: small # prefer this group to set model scale instead of transformer_lm keys directly
lm_model: transformer_lm
codebooks_pattern:
modeling: delay
delay:
delays: [0, 1, 2, 3]
flatten_first: 0
empty_initial: 0
unroll:
flattening: [0, 1, 2, 3]
delays: [0, 0, 0, 0]
music_lm:
group_by: 2
valle:
delays: [0, 0, 0]
transformer_lm:
n_q: 4
card: 2048
memory_efficient: true
bias_proj: false
bias_ff: false
bias_attn: false
norm_first: true
layer_scale: null
weight_init: gaussian
depthwise_init: current
zero_bias_init: true
attention_as_float32: false
================================================
FILE: config/model/none.yaml
================================================
# @package __global__
# This file exists so that `model` is recognized as a config group
# by Hydra and Dora. This is a bit awkward; we might need a better fix someday.
================================================
FILE: config/model/score/basic.yaml
================================================
# @package _global_
diffusion_unet:
hidden: 48
depth: 4
res_blocks: 1
norm_groups: 4
kernel: 8
stride: 4
growth: 4
max_channels: 10_000
dropout: 0.
emb_all_layers: true
bilstm: false
codec_dim: null
transformer: false
cross_attention: false
================================================
FILE: config/solver/audiogen/audiogen_base_16khz.yaml
================================================
# @package __global__
# This is the training loop solver
# for the base AudioGen model (text-to-sound)
# on monophonic audio sampled at 16 kHz
# using a similar EnCodec+LM setup to MusicGen
defaults:
- audiogen/default
- /model: lm/audiogen_lm
- override /dset: audio/default
- _self_
autocast: true
autocast_dtype: float16
# EnCodec large trained on mono-channel music audio sampled at 16khz
# with a total stride of 320 leading to 50 frames/s.
# rvq.n_q=4, rvq.bins=2048, no quantization dropout
# (transformer_lm card and n_q must be compatible)
compression_model_checkpoint: //reference/bd44a852/checkpoint.th
channels: 1
sample_rate: 16000
deadlock:
use: true # deadlock detection
dataset:
batch_size: 128 # matching AudioGen paper setup (256 * mix_p=0.5 = 128)
num_workers: 10
segment_duration: 10
min_segment_ratio: 1.0
sample_on_weight: false # Uniform sampling all the way
sample_on_duration: false # Uniform sampling all the way
external_metadata_source: null
# sample mixing augmentation at train time
train:
batch_size: 256 # matching AudioGen paper setup
aug_p: 0.5 # perform audio mixing 50% of the time
mix_p: 0.5 # proportion of batch items mixed together
# important: note that this will reduce the
# actual batch size used at train time
# which will be equal to mix_p * batch_size
mix_snr_low: -5
mix_snr_high: 5
mix_min_overlap: 0.5
generate:
lm:
use_sampling: true
top_k: 250
top_p: 0.0
optim:
epochs: 100
optimizer: adamw
lr: 5e-4
ema:
use: true
updates: 10
device: cuda
logging:
log_tensorboard: true
schedule:
lr_scheduler: inverse_sqrt
inverse_sqrt:
warmup: 3000
warmup_init_lr: 0.0
================================================
FILE: config/solver/audiogen/debug.yaml
================================================
# @package __global__
# This is a minimal debugging configuration
# for the AudioGen training solver
defaults:
- audiogen/default
- /model: lm/audiogen_lm
- override /model/lm/model_scale: xsmall
- override /dset: audio/example
- _self_
autocast: false
compression_model_checkpoint: null
codebooks_pattern:
modeling: parallel
channels: 1
sample_rate: 16000
deadlock:
use: false # deadlock detection
dataset:
batch_size: 4
segment_duration: 5
sample_on_weight: false # Uniform sampling all the way
sample_on_duration: false # Uniform sampling all the way
generate:
audio:
strategy: peak
lm:
use_sampling: false
top_k: 0
top_p: 0.0
checkpoint:
save_every: 0
keep_last: 0
optim:
epochs: 2
updates_per_epoch: 10
optimizer: adamw
lr: 1e-4
logging:
log_tensorboard: true
schedule:
lr_scheduler: null
================================================
FILE: config/solver/audiogen/default.yaml
================================================
# @package __global__
defaults:
- /solver/musicgen/default
- _self_
- /solver/audiogen/evaluation: none
- override /dset: audio/default
# See config/solver/musicgen/default.yaml for a list of possible values.
# We only keep the most important here.
autocast: true
autocast_dtype: float16
solver: audiogen
sample_rate: ???
channels: ???
compression_model_checkpoint: ???
tokens:
padding_with_special_token: false
dataset:
batch_size: 128
segment_duration: 10
min_segment_ratio: 1.0 # lower values such as 0.5 result in generations with a lot of silence.
optim:
epochs: 100
updates_per_epoch: 2000
lr: 1e-4
optimizer: adamw
max_norm: 1.0
adam:
betas: [0.9, 0.95]
weight_decay: 0.1
eps: 1e-8
schedule:
lr_scheduler: null
================================================
FILE: config/solver/audiogen/evaluation/none.yaml
================================================
# @package __global__
dataset:
evaluate:
num_samples: 10000
================================================
FILE: config/solver/audiogen/evaluation/objective_eval.yaml
================================================
# @package __global__
# Setup for execute only on audiocaps for audio generation
# evaluation with objective metrics
# execute_only=evaluate
dataset:
max_audio_duration: null
# ensure the proper values are broadcasted here for evaluate
evaluate:
min_audio_duration: 1. # some metrics require a minimum audio length
max_audio_duration: null # all samples from audiocaps should be ~10s
num_samples: null
segment_duration: null
generate:
min_audio_duration: 1.
max_audio_duration: null
num_samples: 500
evaluate:
metrics:
fad: true
kld: true
text_consistency: true
metrics:
kld:
passt:
pretrained_length: 10 # similarly to reported results in AudioGen paper
================================================
FILE: config/solver/compression/debug.yaml
================================================
# @package __global__
defaults:
- compression/default
- /model: encodec/encodec_base_causal
- override /dset: audio/example
- _self_
channels: 1
sample_rate: 16000
# debug config uses just L1
losses:
adv: 0.
feat: 0.
l1: 1.
mel: 0.
msspec: 0.
# no balancer
balancer:
balance_grads: false
ema_decay: 1.
total_norm: 1.
per_batch_item: false
# no adversaries
adversarial:
adversaries: []
adv_loss: hinge
feat_loss: l1
# faster model for local dev
seanet:
dimension: 16
n_filters: 4
# very small dataset
dataset:
batch_size: 8
num_workers: 10
num_samples: 100
segment_duration: 1
evaluate:
batch_size: 32
generate:
batch_size: 1
num_samples: 5
segment_duration: 10
# limited training
evaluate:
every: 5
generate:
every: 5
optim:
epochs: 50
================================================
FILE: config/solver/compression/default.yaml
================================================
# @package __global__
defaults:
- ../default
- override /dset: audio/default
- _self_
solver: compression
sample_rate: ???
channels: ???
# loss balancing
losses:
adv: 4.
feat: 4.
l1: 0.1
mel: 0.
msspec: 2.
sisnr: 0.
balancer:
balance_grads: true
ema_decay: 0.999
per_batch_item: true
total_norm: 1.
adversarial:
every: 1
adversaries: [msstftd]
adv_loss: hinge
feat_loss: l1
# losses hyperparameters
l1: {}
l2: {}
mrstft:
factor_sc: .5
factor_mag: .5
normalized: false
mel:
sample_rate: ${sample_rate}
n_fft: 1024
hop_length: 256
win_length: 1024
n_mels: 64
f_min: 64
f_max: null
normalized: false
floor_level: 1e-5
sisnr:
sample_rate: ${sample_rate}
segment: 5.
msspec:
sample_rate: ${sample_rate}
range_start: 6
range_end: 11
n_mels: 64
f_min: 64
f_max: null
normalized: true
alphas: false
floor_level: 1e-5
# metrics
metrics:
visqol:
mode: audio
bin: null # path to visqol install
model: tcdaudio14_aacvopus_coresv_svrnsim_n.68_g.01_c1.model # visqol v3
# adversaries hyperparameters
msstftd:
in_channels: 1
out_channels: 1
filters: 32
norm: weight_norm
n_ffts: [1024, 2048, 512, 256, 128]
hop_lengths: [256, 512, 128, 64, 32]
win_lengths: [1024, 2048, 512, 256, 128]
activation: LeakyReLU
activation_params: {negative_slope: 0.3}
msd:
in_channels: 1
out_channels: 1
scale_norms: [spectral_norm, weight_norm, weight_norm]
kernel_sizes: [5, 3]
filters: 16
max_filters: 1024
downsample_scales: [4, 4, 4, 4]
inner_kernel_sizes: null
groups: [4, 4, 4, 4]
strides: null
paddings: null
activation: LeakyReLU
activation_params: {negative_slope: 0.3}
mpd:
in_channels: 1
out_channels: 1
periods: [2, 3, 5, 7, 11]
n_layers: 5
kernel_size: 5
stride: 3
filters: 8
filter_scales: 4
max_filters: 1024
activation: LeakyReLU
activation_params: {negative_slope: 0.3}
norm: weight_norm
# data hyperparameters
dataset:
batch_size: 64
num_workers: 10
segment_duration: 1
train:
num_samples: 500000
valid:
num_samples: 10000
evaluate:
batch_size: 32
num_samples: 10000
generate:
batch_size: 32
num_samples: 50
segment_duration: 10
# solver hyperparameters
evaluate:
every: 25
num_workers: 5
metrics:
visqol: false
sisnr: true
generate:
every: 25
num_workers: 5
audio:
sample_rate: ${sample_rate}
# checkpointing schedule
checkpoint:
save_last: true
save_every: 25
keep_last: 10
keep_every_states: null
# optimization hyperparameters
optim:
epochs: 200
updates_per_epoch: 2000
lr: 3e-4
max_norm: 0.
optimizer: adam
adam:
betas: [0.5, 0.9]
weight_decay: 0.
ema:
use: true # whether to use EMA or not
updates: 1 # update at every step
device: ${device} # device for EMA, can be put on GPU if more frequent updates
decay: 0.99 # EMA decay value, if null, no EMA is used
================================================
FILE: config/solver/compression/encodec_audiogen_16khz.yaml
================================================
# @package __global__
defaults:
- compression/default
- /model: encodec/encodec_large_nq4_s320
- override /dset: audio/default
- _self_
channels: 1
sample_rate: 16000
================================================
FILE: config/solver/compression/encodec_base_24khz.yaml
================================================
# @package __global__
defaults:
- compression/default
- /model: encodec/encodec_base_causal
- override /dset: audio/default
- _self_
channels: 1
sample_rate: 24000
================================================
FILE: config/solver/compression/encodec_musicgen_32khz.yaml
================================================
# @package __global__
defaults:
- compression/default
- /model: encodec/encodec_large_nq4_s640
- override /dset: audio/default
- _self_
channels: 1
sample_rate: 32000
================================================
FILE: config/solver/default.yaml
================================================
# @package __global__
# WARNING: This is a base configuration file shared across ALL solvers in AudioCraft
# Please don't update this file directly. Instead use distinct configuration files
# to override the below configuration.
solver: ???
fsdp:
use: false # should we use FSDP.
param_dtype: float16 # equivalent to autocast_dtype for FSDP.
reduce_dtype: float32 # gradient averaging dtype, float32 will give max stability.
buffer_dtype: float32 # dtype used for buffers, we don't have many buffers, so let's leave it.
sharding_strategy: shard_grad_op # can be shard_grad_op or full_shard.
# full_shard will use less memory but slower ??
per_block: true # If True, uses nested FSDP.
profiler:
enabled: false
deadlock:
use: false
timeout: 600
dataset:
batch_size: ???
num_workers: 10
segment_duration: null
num_samples: null
return_info: false
shuffle: false
sample_on_duration: true
sample_on_weight: true
min_segment_ratio: 0.5
train:
num_samples: null
shuffle: true
shuffle_seed: 0 # if you want to sample the data differently.
permutation_on_files: false
valid:
num_samples: null
evaluate:
num_samples: null
generate:
num_samples: null
return_info: true
checkpoint:
save_last: true
save_every: null
keep_last: null
keep_every_states: null
generate:
every: null
path: 'samples'
audio:
format: 'mp3'
strategy: 'clip'
sample_rate: null
lm:
use_sampling: false
temp: 1.0
top_k: 0
top_p: 0.0
evaluate:
every: null
num_workers: 5
truncate_audio: null
fixed_generation_duration: null # in secs
metrics:
base: true # run default evaluation (e.g. like train/valid stage)
optim:
epochs: ???
updates_per_epoch: null
lr: ???
optimizer: ???
adam:
betas: [0.9, 0.999]
weight_decay: 0.
ema:
use: false # whether to use EMA or not
updates: ${optim.updates_per_epoch} # frequency of updates of the EMA
device: cpu # device for EMA, can be put on GPU if more frequent updates
decay: 0.99 # EMA decay value, if null, no EMA is used
schedule:
lr_scheduler: null
step:
step_size: null
gamma: null
exponential:
lr_decay: null
cosine:
warmup: null
lr_min_ratio: 0.0
cycle_length: 1.0
polynomial_decay:
warmup: null
zero_lr_warmup_steps: 0
end_lr: 0.0
power: 1
inverse_sqrt:
warmup: null
warmup_init_lr: 0.0
linear_warmup:
warmup: null
warmup_init_lr: 0.0
================================================
FILE: config/solver/diffusion/debug.yaml
================================================
# @package __global__
defaults:
- /solver/default
- /model: score/basic
- override /dset: audio/default
- _self_
solver: diffusion
sample_rate: 16000
channels: 1
compression_model_checkpoint: //sig/5091833e
n_q: 2 # number of codebooks to keep
dataset:
batch_size: 8
num_workers: 10
segment_duration: 1
train:
num_samples: 100
valid:
num_samples: 100
evaluate:
batch_size: 8
num_samples: 10
generate:
batch_size: 8
num_samples: 10
segment_duration: 10
loss:
kind: mse
norm_power: 0.
valid:
every: 1
evaluate:
every: 5
num_workers: 5
metrics:
visqol: false
sisnr: false
rvm: true
generate:
every: 5
num_workers: 5
audio:
sample_rate: ${sample_rate}
checkpoint:
save_last: true
save_every: 25
keep_last: 10
keep_every_states: null
optim:
epochs: 50
updates_per_epoch: 2000
lr: 2e-4
max_norm: 0
optimizer: adam
adam:
betas: [0.9, 0.999]
weight_decay: 0.
ema:
use: true # whether to use EMA or not
updates: 1 # update at every step
device: ${device} # device for EMA, can be put on GPU if more frequent updates
decay: 0.99 # EMA decay value, if null, no EMA is used
processor:
name: multi_band_processor
use: false
n_bands: 8
num_samples: 10_000
power_std: 1.
resampling:
use: false
target_sr: 16000
filter:
use: false
n_bands: 4
idx_band: 0
cutoffs: null
schedule:
repartition: "power"
variable_step_batch: true
beta_t0: 1.0e-5
beta_t1: 2.9e-2
beta_exp: 7.5
num_steps: 1000
variance: 'beta'
clip: 5.
rescale: 1.
n_bands: null
noise_scale: 1.0
metrics:
num_stage: 4
================================================
FILE: config/solver/diffusion/default.yaml
================================================
# @package __global__
defaults:
- /solver/default
- /model: score/basic
- override /dset: audio/default
- _self_
solver: diffusion
sample_rate: ???
channels: ???
compression_model_checkpoint: ???
n_q: ??? # number of codebooks to keep
dataset:
batch_size: 128
num_workers: 10
segment_duration: 1
train:
num_samples: 500000
valid:
num_samples: 10000
evaluate:
batch_size: 16
num_samples: 10000
generate:
batch_size: 32
num_samples: 50
segment_duration: 10
audio:
sample_rate: ${sample_rate}
loss:
kind: mse
norm_power: 0.
valid:
every: 1
evaluate:
every: 20
num_workers: 5
metrics:
visqol: false
sisnr: false
rvm: true
generate:
every: 25
num_workers: 5
checkpoint:
save_last: true
save_every: 25
keep_last: 10
keep_every_states: null
optim:
epochs: 20000
updates_per_epoch: 2000
lr: 2e-4
max_norm: 0
optimizer: adam
adam:
betas: [0.9, 0.999]
weight_decay: 0.
ema:
use: true # whether to use EMA or not
updates: 1 # update at every step
device: ${device} # device for EMA, can be put on GPU if more frequent updates
decay: 0.99 # EMA decay value, if null, no EMA is used
processor:
name: multi_band_processor
use: false
n_bands: 8
num_samples: 10_000
power_std: 1.
resampling:
use: false
target_sr: 16000
filter:
use: false
n_bands: 4
idx_band: 0
cutoffs: null
schedule:
repartition: "power"
variable_step_batch: true
beta_t0: 1.0e-5
beta_t1: 2.9e-2
beta_exp: 7.5
num_steps: 1000
variance: 'beta'
clip: 5.
rescale: 1.
n_bands: null
noise_scale: 1.0
metrics:
num_stage: 4
================================================
FILE: config/solver/diffusion/encodec_24khz.yaml
================================================
# @package __global__
defaults:
- diffusion/default
- _self_
sample_rate: 24000
channels: 1
compression_model_checkpoint: //pretrained/facebook/encodec_24khz
n_q: 4 # num quantizers, 3kbps
================================================
FILE: config/solver/musicgen/debug.yaml
================================================
# @package __global__
# This is a minimal debugging configuration
# for MusicGen training solver
defaults:
- musicgen/default
- /model: lm/musicgen_lm
- override /model/lm/model_scale: xsmall
- override /dset: audio/example
- _self_
autocast: false
compression_model_checkpoint: //pretrained/debug_compression_model
transformer_lm:
n_q: 4
card: 400
codebooks_pattern:
modeling: parallel
channels: 1
sample_rate: 32000
deadlock:
use: false # deadlock detection
dataset:
batch_size: 4
segment_duration: 5
sample_on_weight: false # Uniform sampling all the way
sample_on_duration: false # Uniform sampling all the way
generate:
audio:
strategy: peak
lm:
use_sampling: false
top_k: 0
top_p: 0.0
checkpoint:
save_every: 0
keep_last: 0
optim:
epochs: 2
updates_per_epoch: 10
optimizer: adamw
lr: 1e-4
logging:
log_tensorboard: true
schedule:
lr_scheduler: null
================================================
FILE: config/solver/musicgen/default.yaml
================================================
# @package __global__
defaults:
- /solver/default
- /conditioner: none
- _self_
- /solver/musicgen/evaluation: none
- override /dset: audio/default
autocast: true
autocast_dtype: float16
solver: musicgen
sample_rate: ???
channels: ???
compression_model_checkpoint: ???
tokens:
padding_with_special_token: false
cache:
path:
write: false
write_shard: 0
write_num_shards: 1
dataset:
batch_size: 128
num_workers: 10
segment_duration: 30
min_segment_ratio: 0.8 # lower values such as 0.5 result in generations with a lot of silence.
return_info: true
train:
num_samples: 1000000 # need an arbitrarily large number here for AudioDataset
valid:
num_samples: 10000
generate:
num_samples: 50
metrics:
fad:
use_gt: false
model: tf
tf:
bin: null # path to local frechet_audio_distance code
model_path: //reference/fad/vggish_model.ckpt
kld:
use_gt: false
model: passt
passt:
pretrained_length: 20
text_consistency:
use_gt: false
model: clap
clap:
model_path: //reference/clap/music_audioset_epoch_15_esc_90.14.pt
model_arch: 'HTSAT-base'
enable_fusion: false
chroma_cosine:
use_gt: false
model: chroma_base
chroma_base:
sample_rate: ${sample_rate}
n_chroma: 12
radix2_exp: 14
argmax: true
generate:
every: 25
num_workers: 5
path: samples
audio:
format: wav
strategy: loudness
sample_rate: ${sample_rate}
loudness_headroom_db: 14
lm:
prompted_samples: true
unprompted_samples: true
gen_gt_samples: false
prompt_duration: null # if not set, will use dataset.generate.segment_duration / 4
gen_duration: null # if not set, will use dataset.generate.segment_duration
remove_prompts: false
# generation params
use_sampling: false
temp: 1.0
top_k: 0
top_p: 0.0
evaluate:
every: 25
num_workers: 5
metrics:
base: false
fad: false
kld: false
text_consistency: false
chroma_cosine: false
checkpoint:
save_last: true
save_every: 50
keep_last: 10
keep_every_states: null
optim:
epochs: 200
updates_per_epoch: 2000
lr: 1e-4
optimizer: adamw
max_norm: 1.0
eager_sync: true
adam:
betas: [0.9, 0.95]
weight_decay: 0.1
eps: 1e-8
schedule:
lr_scheduler: null
================================================
FILE: config/solver/musicgen/evaluation/none.yaml
================================================
# @package __global__
dataset:
evaluate:
num_samples: 10000
================================================
FILE: config/solver/musicgen/evaluation/objective_eval.yaml
================================================
# @package __global__
# Setup for execute-only runs on MusicCaps, evaluating audio generation
# with objective metrics
# execute_only=evaluate
dataset:
max_audio_duration: null
# ensure the proper values are broadcast here for evaluate
evaluate:
min_audio_duration: 1. # some metrics require a minimum audio length
max_audio_duration: null # all samples from musiccaps should be < 20s
num_samples: null
segment_duration: null
generate:
min_audio_duration: 1.
max_audio_duration: null
num_samples: 500
evaluate:
metrics:
fad: true
kld: true
text_consistency: true
================================================
FILE: config/solver/musicgen/musicgen_base_32khz.yaml
================================================
# @package __global__
# This is the training loop solver
# for the base MusicGen model (text-to-music)
# on monophonic audio sampled at 32 kHz
defaults:
- musicgen/default
- /model: lm/musicgen_lm
- override /dset: audio/default
- _self_
autocast: true
autocast_dtype: float16
# EnCodec large trained on mono-channel music audio sampled at 32khz
# with a total stride of 640 leading to 50 frames/s.
# rvq.n_q=4, rvq.bins=2048, no quantization dropout
# (transformer_lm card and n_q must be compatible)
compression_model_checkpoint: //pretrained/facebook/encodec_32khz
channels: 1
sample_rate: 32000
deadlock:
use: true # deadlock detection
dataset:
batch_size: 192 # 32 GPUs
sample_on_weight: false # Uniform sampling all the way
sample_on_duration: false # Uniform sampling all the way
generate:
lm:
use_sampling: true
top_k: 250
top_p: 0.0
optim:
epochs: 500
optimizer: dadam
lr: 1
ema:
use: true
updates: 10
device: cuda
logging:
log_tensorboard: true
schedule:
lr_scheduler: cosine
cosine:
warmup: 4000
lr_min_ratio: 0.0
cycle_length: 1.0
================================================
FILE: config/solver/musicgen/musicgen_melody_32khz.yaml
================================================
# @package __global__
# This is the training loop solver
# for the melody MusicGen model (text+chroma to music)
# on monophonic audio sampled at 32 kHz
defaults:
- musicgen/default
- /model: lm/musicgen_lm
- override /conditioner: chroma2music
- override /dset: audio/default
- _self_
autocast: true
autocast_dtype: float16
# EnCodec large trained on mono-channel music audio sampled at 32khz
# with a total stride of 640 leading to 50 frames/s.
# rvq.n_q=4, rvq.bins=2048, no quantization dropout
# (transformer_lm card and n_q must be compatible)
compression_model_checkpoint: //pretrained/facebook/encodec_32khz
channels: 1
sample_rate: 32000
deadlock:
use: true # deadlock detection
dataset:
batch_size: 192 # 32 GPUs
sample_on_weight: false # Uniform sampling all the way
sample_on_duration: false # Uniform sampling all the way
generate:
lm:
use_sampling: true
top_k: 250
top_p: 0.0
optim:
epochs: 500
optimizer: dadam
lr: 1
ema:
use: true
updates: 10
device: cuda
logging:
log_tensorboard: true
schedule:
lr_scheduler: cosine
cosine:
warmup: 4000
lr_min_ratio: 0.0
cycle_length: 1.0
================================================
FILE: config/teams/default.yaml
================================================
default:
dora_dir: /tmp/audiocraft_${oc.env:USER}
partitions:
global: debug
team: debug
reference_dir: /tmp
darwin: # if we detect we are on a Mac, then most likely we are doing unit testing etc.
dora_dir: /tmp/audiocraft_${oc.env:USER}
partitions:
global: debug
team: debug
reference_dir: /tmp
================================================
FILE: config/teams/labs.yaml
================================================
aws:
dora_dir: /fsx-audio-craft-llm/${oc.env:USER}/experiments/audiocraft/outputs
partitions:
global: learnlab
team: learnlab
reference_dir: /fsx-audio-craft-llm/shared/audiocraft/reference
dataset_mappers:
"^/checkpoint/[a-z]+": "/fsx-audio-craft-llm"
fair:
dora_dir: /checkpoint/${oc.env:USER}/experiments/audiocraft/outputs
partitions:
global: learnlab
team: learnlab
reference_dir: /large_experiments/audiocraft/reference
dataset_mappers:
"^/datasets01/datasets01": "/datasets01"
darwin:
dora_dir: /tmp/audiocraft_${oc.env:USER}
partitions:
global: debug
team: debug
reference_dir: /tmp
rsc:
dora_dir: /checkpoint/audiocraft/${oc.env:USER}/experiments/audiocraft/outputs
partitions:
global: learn
team: learn
reference_dir: /checkpoint/audiocraft/shared/reference
================================================
FILE: dataset/example/electro_1.json
================================================
{"key": "", "artist": "Voyager I", "sample_rate": 48000, "file_extension": "mp3", "description": "A cool song from Voyager.", "keywords": "bright, pulsing, cool", "duration": 15.0, "bpm": "", "genre": "electronic", "title": "Enracinement", "name": "electro_1", "instrument": "Mix", "moods": ["uplifting", "motivational"]}
================================================
FILE: dataset/example/electro_2.json
================================================
{"key": "", "artist": "Voyager I", "sample_rate": 44100, "file_extension": "mp3", "description": "This is an electronic song sending positive vibes.", "keywords": "", "duration": 20.0, "bpm": "", "genre": "electronic", "title": "Untitled song", "name": "electro_2", "instrument": "Mix", "moods": []}
================================================
FILE: demos/audiogen_demo.ipynb
================================================
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# AudioGen\n",
"Welcome to AudioGen's demo jupyter notebook. Here you will find a series of self-contained examples of how to use AudioGen in different settings.\n",
"\n",
"First, we start by initializing AudioGen. For now, we provide only a medium sized model for AudioGen: `facebook/audiogen-medium` - 1.5B transformer decoder. \n",
"\n",
"**Important note:** This variant is different from the original AudioGen model presented at [\"AudioGen: Textually-guided audio generation\"](https://arxiv.org/abs/2209.15352) as the model architecture is similar to MusicGen with a smaller frame rate and multiple streams of tokens, allowing to reduce generation time."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from audiocraft.models import AudioGen\n",
"\n",
"model = AudioGen.get_pretrained('facebook/audiogen-medium')"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Next, let us configure the generation parameters. Specifically, you can control the following:\n",
"* `use_sampling` (bool, optional): use sampling if True, else do argmax decoding. Defaults to True.\n",
"* `top_k` (int, optional): top_k used for sampling. Defaults to 250.\n",
"* `top_p` (float, optional): top_p used for sampling, when set to 0 top_k is used. Defaults to 0.0.\n",
"* `temperature` (float, optional): softmax temperature parameter. Defaults to 1.0.\n",
"* `duration` (float, optional): duration of the generated waveform. Defaults to 10.0.\n",
"* `cfg_coef` (float, optional): coefficient used for classifier free guidance. Defaults to 3.0.\n",
"\n",
"When left unchanged, AudioGen will revert to its default parameters."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"model.set_generation_params(\n",
" use_sampling=True,\n",
" top_k=250,\n",
" duration=5\n",
")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Next, we can go ahead and start generating sound using one of the following modes:\n",
"* Audio continuation using `model.generate_continuation`\n",
"* Text-conditional samples using `model.generate`"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Audio Continuation"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import math\n",
"import torchaudio\n",
"import torch\n",
"from audiocraft.utils.notebook import display_audio\n",
"\n",
"def get_bip_bip(bip_duration=0.125, frequency=440,\n",
" duration=0.5, sample_rate=16000, device=\"cuda\"):\n",
" \"\"\"Generates a series of bip bip at the given frequency.\"\"\"\n",
" t = torch.arange(\n",
" int(duration * sample_rate), device=\"cuda\", dtype=torch.float) / sample_rate\n",
" wav = torch.cos(2 * math.pi * 440 * t)[None]\n",
" tp = (t % (2 * bip_duration)) / (2 * bip_duration)\n",
" envelope = (tp >= 0.5).float()\n",
" return wav * envelope"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Here we use a synthetic signal to prompt the generated audio.\n",
"res = model.generate_continuation(\n",
" get_bip_bip(0.125).expand(2, -1, -1), \n",
" 16000, ['Whistling with wind blowing', \n",
" 'Typing on a typewriter'], \n",
" progress=True)\n",
"display_audio(res, 16000)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# You can also use any audio from a file. Make sure to trim the file if it is too long!\n",
"prompt_waveform, prompt_sr = torchaudio.load(\"../assets/sirens_and_a_humming_engine_approach_and_pass.mp3\")\n",
"prompt_duration = 2\n",
"prompt_waveform = prompt_waveform[..., :int(prompt_duration * prompt_sr)]\n",
"output = model.generate_continuation(prompt_waveform, prompt_sample_rate=prompt_sr, progress=True)\n",
"display_audio(output, sample_rate=16000)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Text-conditional Generation"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from audiocraft.utils.notebook import display_audio\n",
"\n",
"output = model.generate(\n",
" descriptions=[\n",
" 'Subway train blowing its horn',\n",
" 'A cat meowing',\n",
" ],\n",
" progress=True\n",
")\n",
"display_audio(output, sample_rate=16000)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.7"
}
},
"nbformat": 4,
"nbformat_minor": 2
}
================================================
FILE: demos/musicgen_app.py
================================================
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
# Updated to account for UI changes from https://github.com/rkfg/audiocraft/blob/long/app.py
# also released under the MIT license.
import argparse
from concurrent.futures import ProcessPoolExecutor
import os
from pathlib import Path
import subprocess as sp
from tempfile import NamedTemporaryFile
import time
import typing as tp
import warnings
import torch
import gradio as gr
from audiocraft.data.audio_utils import convert_audio
from audiocraft.data.audio import audio_write
from audiocraft.models import MusicGen, MultiBandDiffusion
# Module-level state shared by the functions below.
# MODEL is replaced by load_model() whenever a different version is requested.
MODEL = None # Last used model
# True when the SPACE_ID environment variable contains "facebook/MusicGen";
# selects ui_batched() instead of ui_full() in the __main__ block.
IS_BATCHED = "facebook/MusicGen" in os.environ.get('SPACE_ID', '')
print(IS_BATCHED)
# Passed as max_batch_size to the batched submit handler.
MAX_BATCH_SIZE = 12
# Duration handed to _do_predictions() by predict_batched().
BATCHED_DURATION = 15
# Set by interrupt(), reset by predict_full(), polled by its progress callback.
INTERRUPTING = False
# MultiBandDiffusion decoder, created on first use by load_diffusion().
MBD = None
# We have to wrap subprocess call to clean a bit the log when using gr.make_waveform
_old_call = sp.call


def _call_nostderr(*args, **kwargs):
    """Run the original ``subprocess.call`` with stdout and stderr discarded.

    Any ``stdout``/``stderr`` supplied by the caller is overridden with
    ``subprocess.DEVNULL``; every other argument is forwarded untouched.

    Returns:
        Whatever the original ``subprocess.call`` returns (the exit code),
        so callers that inspect it keep working after the monkeypatch.
    """
    # Avoid ffmpeg vomiting on the logs.
    kwargs['stderr'] = sp.DEVNULL
    kwargs['stdout'] = sp.DEVNULL
    return _old_call(*args, **kwargs)


sp.call = _call_nostderr
# Preallocating the pool of processes.
# _do_predictions() submits make_waveform jobs to this pool.
# __enter__ is called directly instead of using a `with` block; no matching
# __exit__/shutdown call is visible in this module.
pool = ProcessPoolExecutor(4)
pool.__enter__()
def interrupt():
    """Raise the module-level ``INTERRUPTING`` flag.

    The progress callback installed by ``predict_full`` checks this flag and
    aborts the running generation with a ``gr.Error`` once it is set.
    """
    global INTERRUPTING
    INTERRUPTING = True
class FileCleaner:
    """Remember files on disk and delete them once they outlive a lifetime.

    Entries are stored in ``self.files`` as ``(time_added, Path)`` tuples in
    insertion order, so the oldest entry is always at index 0.
    """

    def __init__(self, file_lifetime: float = 3600):
        # Lifetime is compared against time.time() differences (seconds).
        self.file_lifetime = file_lifetime
        self.files = []

    def add(self, path: tp.Union[str, Path]):
        """Purge expired entries, then start tracking ``path``."""
        self._cleanup()
        entry = (time.time(), Path(path))
        self.files.append(entry)

    def _cleanup(self):
        """Delete and forget every leading entry older than the lifetime."""
        now = time.time()
        while self.files:
            time_added, path = self.files[0]
            expired = now - time_added > self.file_lifetime
            if not expired:
                # Entries are in insertion order: the rest are newer still.
                break
            if path.exists():
                path.unlink()
            self.files.pop(0)


file_cleaner = FileCleaner()
def make_waveform(*args, **kwargs):
    """Call ``gr.make_waveform`` with warnings silenced and report its timing.

    All arguments are forwarded unchanged; the value returned by
    ``gr.make_waveform`` is returned as is.
    """
    # Further remove some warnings.
    started = time.time()
    with warnings.catch_warnings():
        warnings.simplefilter('ignore')
        video = gr.make_waveform(*args, **kwargs)
    elapsed = time.time() - started
    print("Make a video took", elapsed)
    return video
def load_model(version='facebook/musicgen-melody'):
    """Make ``MODEL`` hold the requested MusicGen version.

    The currently loaded model is kept when its ``name`` already equals
    ``version``; otherwise it is replaced via ``MusicGen.get_pretrained``.
    """
    global MODEL
    print("Loading model", version)
    already_loaded = MODEL is not None and MODEL.name == version
    if already_loaded:
        return
    MODEL = MusicGen.get_pretrained(version)
def load_diffusion():
    """Create the MultiBandDiffusion decoder in ``MBD`` on first use."""
    global MBD
    if MBD is not None:
        # Already loaded by an earlier request.
        return
    print("loading MBD")
    MBD = MultiBandDiffusion.get_mbd_musicgen()
def _do_predictions(texts, melodies, duration, progress=False, **gen_kwargs):
    """Generate audio for a batch of prompts and write wav and video files.

    Args:
        texts: list of text descriptions, one per sample.
        melodies: list aligned with ``texts``; each item is ``None`` or a
            ``(sample_rate, numpy_array)`` pair used as melody conditioning.
        duration: forwarded to ``MODEL.set_generation_params`` and also used
            to trim each melody to ``sample_rate * duration`` samples.
        progress: forwarded to the model's generate call.
        **gen_kwargs: extra generation parameters for
            ``MODEL.set_generation_params``.

    Returns:
        ``(out_videos, out_wavs)``: the results of the ``make_waveform`` jobs
        and the wav file paths, in output order. When ``USE_DIFFUSION`` is
        set, the MultiBand Diffusion outputs follow the default ones.
    """
    MODEL.set_generation_params(duration=duration, **gen_kwargs)
    melody_summary = [None if m is None else (m[0], m[1].shape) for m in melodies]
    print("new batch", len(texts), texts, melody_summary)
    started = time.time()
    target_sr = 32000
    target_ac = 1

    processed_melodies = []
    for entry in melodies:
        if entry is None:
            processed_melodies.append(None)
            continue
        sr = entry[0]
        wav = torch.from_numpy(entry[1]).to(MODEL.device).float().t()
        if wav.dim() == 1:
            # Add a leading dimension to 1-D input.
            wav = wav[None]
        wav = wav[..., :int(sr * duration)]
        processed_melodies.append(convert_audio(wav, sr, target_sr, target_ac))

    has_melody = any(m is not None for m in processed_melodies)
    if has_melody:
        outputs = MODEL.generate_with_chroma(
            descriptions=texts,
            melody_wavs=processed_melodies,
            melody_sample_rate=target_sr,
            progress=progress,
            return_tokens=USE_DIFFUSION
        )
    else:
        outputs = MODEL.generate(texts, progress=progress, return_tokens=USE_DIFFUSION)

    if USE_DIFFUSION:
        # With return_tokens set, outputs[0] is the audio and outputs[1] the tokens.
        outputs_diffusion = MBD.tokens_to_wav(outputs[1])
        outputs = torch.cat([outputs[0], outputs_diffusion], dim=0)
    outputs = outputs.detach().cpu().float()

    pending_videos = []
    out_wavs = []
    for wav in outputs:
        with NamedTemporaryFile("wb", suffix=".wav", delete=False) as tmp:
            wav_path = tmp.name
            audio_write(
                wav_path, wav, MODEL.sample_rate, strategy="loudness",
                loudness_headroom_db=16, loudness_compressor=True, add_suffix=False)
            # Waveform videos are rendered in the process pool.
            pending_videos.append(pool.submit(make_waveform, wav_path))
            out_wavs.append(wav_path)
            file_cleaner.add(wav_path)

    out_videos = []
    for future in pending_videos:
        out_videos.append(future.result())
    for video in out_videos:
        file_cleaner.add(video)

    print("batch finished", len(texts), time.time() - started)
    print("Tempfiles currently stored: ", len(file_cleaner.files))
    return out_videos, out_wavs
def predict_batched(texts, melodies):
    """Batched handler: generate with the melody model for a fixed duration.

    Each text is cut to at most 512 characters, the
    ``facebook/musicgen-melody`` model is loaded, and the result of
    ``_do_predictions`` for ``BATCHED_DURATION`` is returned.
    """
    max_text_length = 512
    clipped = []
    for text in texts:
        clipped.append(text[:max_text_length])
    load_model('facebook/musicgen-melody')
    return _do_predictions(clipped, melodies, BATCHED_DURATION)
def predict_full(model, decoder, text, melody, duration, topk, topp, temperature, cfg_coef, progress=gr.Progress()):
    """Full-UI handler: validate inputs, load models and run one generation.

    Args:
        model: model version passed to ``load_model``.
        decoder: ``"MultiBand_Diffusion"`` enables the diffusion decoder;
            any other value disables it.
        text, melody: the single prompt and its optional melody.
        duration, topk, topp, temperature, cfg_coef: generation parameters.
        progress: Gradio progress tracker updated from the model callback.

    Returns:
        ``(video, wav, diffusion_video, diffusion_wav)``; the last two are
        ``None`` when the diffusion decoder is not used.

    Raises:
        gr.Error: on a negative temperature/top-k/top-p, or when the
            generation is interrupted.
    """
    global INTERRUPTING, USE_DIFFUSION
    INTERRUPTING = False

    # Checked in this order so the first offending value decides the message.
    checks = (
        (temperature, "Temperature must be >= 0."),
        (topk, "Topk must be non-negative."),
        (topp, "Topp must be non-negative."),
    )
    for value, message in checks:
        if value < 0:
            raise gr.Error(message)

    topk = int(topk)
    USE_DIFFUSION = decoder == "MultiBand_Diffusion"
    if USE_DIFFUSION:
        load_diffusion()
    load_model(model)

    def _progress(generated, to_generate):
        progress((min(generated, to_generate), to_generate))
        if INTERRUPTING:
            raise gr.Error("Interrupted.")

    MODEL.set_custom_progress_callback(_progress)
    videos, wavs = _do_predictions(
        [text], [melody], duration, progress=True,
        top_k=topk, top_p=topp, temperature=temperature, cfg_coef=cfg_coef)
    if not USE_DIFFUSION:
        return videos[0], wavs[0], None, None
    return videos[0], wavs[0], videos[1], wavs[1]
def toggle_audio_src(choice):
    """Return a Gradio update switching the melody input source.

    ``"mic"`` selects the microphone; any other choice selects file upload.
    The current value is cleared in both cases.
    """
    if choice == "mic":
        source, label = "microphone", "Microphone"
    else:
        source, label = "upload", "File"
    return gr.update(source=source, value=None, label=label)
def toggle_diffusion(choice):
    """Return two visibility updates for the diffusion output widgets.

    Both are visible only when ``choice`` is ``"MultiBand_Diffusion"``.
    """
    show = choice == "MultiBand_Diffusion"
    return [gr.update(visible=show)] * 2
def ui_full(launch_kwargs):
    """Build and launch the full (non-batched) Gradio interface.

    The layout has an input column (text, melody source, model, decoder and
    sampling controls) and an output column (video and wav for the default
    decoder and for the MultiBand Diffusion decoder). Submitting first runs
    ``toggle_diffusion`` and then ``predict_full``.

    Args:
        launch_kwargs: keyword arguments forwarded to ``launch``.
    """
    with gr.Blocks() as interface:
        gr.Markdown(
            """
# MusicGen
This is your private demo for [MusicGen](https://github.com/facebookresearch/audiocraft),
a simple and controllable model for music generation
presented at: ["Simple and Controllable Music Generation"](https://huggingface.co/papers/2306.05284)
"""
        )
        with gr.Row():
            with gr.Column():
                with gr.Row():
                    text = gr.Text(label="Input Text", interactive=True)
                    with gr.Column():
                        radio = gr.Radio(["file", "mic"], value="file",
                                         label="Condition on a melody (optional) File or Mic")
                        melody = gr.Audio(source="upload", type="numpy", label="File",
                                          interactive=True, elem_id="melody-input")
                with gr.Row():
                    submit = gr.Button("Submit")
                    # Adapted from https://github.com/rkfg/audiocraft/blob/long/app.py, MIT license.
                    _ = gr.Button("Interrupt").click(fn=interrupt, queue=False)
                with gr.Row():
                    model = gr.Radio(["facebook/musicgen-melody", "facebook/musicgen-medium", "facebook/musicgen-small",
                                      "facebook/musicgen-large"],
                                     label="Model", value="facebook/musicgen-melody", interactive=True)
                with gr.Row():
                    decoder = gr.Radio(["Default", "MultiBand_Diffusion"],
                                       label="Decoder", value="Default", interactive=True)
                with gr.Row():
                    duration = gr.Slider(minimum=1, maximum=120, value=10, label="Duration", interactive=True)
                with gr.Row():
                    topk = gr.Number(label="Top-k", value=250, interactive=True)
                    topp = gr.Number(label="Top-p", value=0, interactive=True)
                    temperature = gr.Number(label="Temperature", value=1.0, interactive=True)
                    cfg_coef = gr.Number(label="Classifier Free Guidance", value=3.0, interactive=True)
            with gr.Column():
                output = gr.Video(label="Generated Music")
                audio_output = gr.Audio(label="Generated Music (wav)", type='filepath')
                diffusion_output = gr.Video(label="MultiBand Diffusion Decoder")
                audio_diffusion = gr.Audio(label="MultiBand Diffusion Decoder (wav)", type='filepath')
        # Toggle the diffusion widgets first, then run the generation.
        submit.click(toggle_diffusion, decoder, [diffusion_output, audio_diffusion], queue=False,
                     show_progress=False).then(predict_full, inputs=[model, decoder, text, melody, duration, topk, topp,
                                                                     temperature, cfg_coef],
                                               outputs=[output, audio_output, diffusion_output, audio_diffusion])
        radio.change(toggle_audio_src, radio, [melody], queue=False, show_progress=False)
        # NOTE(review): predict_full takes (model, decoder, text, melody, ...) and
        # returns four values, while the examples list inputs=[text, melody, model,
        # decoder] and outputs=[output] -- confirm fn is never actually invoked here.
        gr.Examples(
            fn=predict_full,
            examples=[
                [
                    "An 80s driving pop song with heavy drums and synth pads in the background",
                    "./assets/bach.mp3",
                    "facebook/musicgen-melody",
                    "Default"
                ],
                [
                    "A cheerful country song with acoustic guitars",
                    "./assets/bolero_ravel.mp3",
                    "facebook/musicgen-melody",
                    "Default"
                ],
                [
                    "90s rock song with electric guitar and heavy drums",
                    None,
                    "facebook/musicgen-medium",
                    "Default"
                ],
                [
                    "a light and cheerly EDM track, with syncopated drums, aery pads, and strong emotions",
                    "./assets/bach.mp3",
                    "facebook/musicgen-melody",
                    "Default"
                ],
                [
                    "lofi slow bpm electro chill with organic samples",
                    None,
                    "facebook/musicgen-medium",
                    "Default"
                ],
                [
                    "Punk rock with loud drum and power guitar",
                    None,
                    "facebook/musicgen-medium",
                    "MultiBand_Diffusion"
                ],
            ],
            inputs=[text, melody, model, decoder],
            outputs=[output]
        )
        gr.Markdown(
            """
### More details
The model will generate a short music extract based on the description you provided.
The model can generate up to 30 seconds of audio in one pass. It is now possible
to extend the generation by feeding back the end of the previous chunk of audio.
This can take a long time, and the model might lose consistency. The model might also
decide at arbitrary positions that the song ends.
**WARNING:** Choosing long durations will take a long time to generate (2min might take ~10min).
An overlap of 12 seconds is kept with the previously generated chunk, and 18 "new" seconds
are generated each time.
We present 4 model variations:
1. facebook/musicgen-melody -- a music generation model capable of generating music condition
on text and melody inputs. **Note**, you can also use text only.
2. facebook/musicgen-small -- a 300M transformer decoder conditioned on text only.
3. facebook/musicgen-medium -- a 1.5B transformer decoder conditioned on text only.
4. facebook/musicgen-large -- a 3.3B transformer decoder conditioned on text only.
We also present two way of decoding the audio tokens
1. Use the default GAN based compression model
2. Use the MultiBand Diffusion decoder
When using `facebook/musicgen-melody`, you can optionally provide a reference audio from
which a broad melody will be extracted. The model will then try to follow both
the description and melody provided.
You can also use your own GPU or a Google Colab by following the instructions on our repo.
See [github.com/facebookresearch/audiocraft](https://github.com/facebookresearch/audiocraft)
for more details.
"""
        )
        interface.queue().launch(**launch_kwargs)
def ui_batched(launch_kwargs):
    """Build and launch the batched Gradio interface.

    Requests are handled by ``predict_batched`` in batches of up to
    ``MAX_BATCH_SIZE``. The help text takes the generated duration from
    ``BATCHED_DURATION`` so it cannot drift from what ``predict_batched``
    actually requests.

    Args:
        launch_kwargs: keyword arguments forwarded to ``launch``.
    """
    with gr.Blocks() as demo:
        gr.Markdown(
            """
# MusicGen
This is the demo for [MusicGen](https://github.com/facebookresearch/audiocraft),
a simple and controllable model for music generation
presented at: ["Simple and Controllable Music Generation"](https://huggingface.co/papers/2306.05284).
for longer sequences, more control and no queue.
"""
        )
        with gr.Row():
            with gr.Column():
                with gr.Row():
                    text = gr.Text(label="Describe your music", lines=2, interactive=True)
                    with gr.Column():
                        radio = gr.Radio(["file", "mic"], value="file",
                                         label="Condition on a melody (optional) File or Mic")
                        melody = gr.Audio(source="upload", type="numpy", label="File",
                                          interactive=True, elem_id="melody-input")
                with gr.Row():
                    submit = gr.Button("Generate")
            with gr.Column():
                output = gr.Video(label="Generated Music")
                audio_output = gr.Audio(label="Generated Music (wav)", type='filepath')
        submit.click(predict_batched, inputs=[text, melody],
                     outputs=[output, audio_output], batch=True, max_batch_size=MAX_BATCH_SIZE)
        radio.change(toggle_audio_src, radio, [melody], queue=False, show_progress=False)
        gr.Examples(
            fn=predict_batched,
            examples=[
                [
                    "An 80s driving pop song with heavy drums and synth pads in the background",
                    "./assets/bach.mp3",
                ],
                [
                    "A cheerful country song with acoustic guitars",
                    "./assets/bolero_ravel.mp3",
                ],
                [
                    "90s rock song with electric guitar and heavy drums",
                    None,
                ],
                [
                    "a light and cheerly EDM track, with syncopated drums, aery pads, and strong emotions bpm: 130",
                    "./assets/bach.mp3",
                ],
                [
                    "lofi slow bpm electro chill with organic samples",
                    None,
                ],
            ],
            inputs=[text, melody],
            outputs=[output]
        )
        # The duration is interpolated from BATCHED_DURATION (the text has no
        # literal braces, so an f-string is safe here).
        gr.Markdown(f"""
### More details
The model will generate {BATCHED_DURATION} seconds of audio based on the description you provided.
You can optionally provide a reference audio from which a broad melody will be extracted.
The model will then try to follow both the description and melody provided.
All samples are generated with the `melody` model.
You can also use your own GPU or a Google Colab by following the instructions on our repo.
See [github.com/facebookresearch/audiocraft](https://github.com/facebookresearch/audiocraft)
for more details.
""")
        demo.queue(max_size=8 * 4).launch(**launch_kwargs)
if __name__ == "__main__":
    # Command-line options controlling how the Gradio server is launched.
    parser = argparse.ArgumentParser()
    parser.add_argument(
        '--listen',
        type=str,
        # Listen on all interfaces when SPACE_ID is set, loopback otherwise.
        default='0.0.0.0' if 'SPACE_ID' in os.environ else '127.0.0.1',
        help='IP to listen on for connections to Gradio',
    )
    parser.add_argument(
        '--username', type=str, default='', help='Username for authentication'
    )
    parser.add_argument(
        '--password', type=str, default='', help='Password for authentication'
    )
    parser.add_argument(
        '--server_port',
        type=int,
        default=0,
        help='Port to run the server listener on',
    )
    parser.add_argument(
        '--inbrowser', action='store_true', help='Open in browser'
    )
    parser.add_argument(
        '--share', action='store_true', help='Share the gradio UI'
    )

    args = parser.parse_args()

    # Only options that were actually provided are forwarded to launch().
    launch_kwargs = {'server_name': args.listen}
    if args.username and args.password:
        launch_kwargs['auth'] = (args.username, args.password)
    if args.server_port:
        launch_kwargs['server_port'] = args.server_port
    if args.inbrowser:
        launch_kwargs['inbrowser'] = args.inbrowser
    if args.share:
        launch_kwargs['share'] = args.share

    # Show the interface
    if IS_BATCHED:
        # The batched UI never uses the diffusion decoder. A `global`
        # statement is unnecessary here: this already is module scope.
        USE_DIFFUSION = False
        ui_batched(launch_kwargs)
    else:
        ui_full(launch_kwargs)
================================================
FILE: demos/musicgen_demo.ipynb
================================================
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# MusicGen\n",
"Welcome to MusicGen's demo jupyter notebook. Here you will find a series of self-contained examples of how to use MusicGen in different settings.\n",
"\n",
"First, we start by initializing MusicGen, you can choose a model from the following selection:\n",
"1. `facebook/musicgen-small` - 300M transformer decoder.\n",
"2. `facebook/musicgen-medium` - 1.5B transformer decoder.\n",
"3. `facebook/musicgen-melody` - 1.5B transformer decoder also supporting melody conditioning.\n",
"4. `facebook/musicgen-large` - 3.3B transformer decoder.\n",
"\n",
"We will use the `facebook/musicgen-small` variant for the purpose of this demonstration."
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"from audiocraft.models import MusicGen\n",
"from audiocraft.models import MultiBandDiffusion\n",
"import torch\n",
"USE_DIFFUSION_DECODER = False\n",
"# Using small model, better results would be obtained with `medium` or `large`.\n",
"model = MusicGen.get_pretrained('facebook/musicgen-small')\n",
"if USE_DIFFUSION_DECODER:\n",
" mbd = MultiBandDiffusion.get_mbd_musicgen()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Next, let us configure the generation parameters. Specifically, you can control the following:\n",
"* `use_sampling` (bool, optional): use sampling if True, else do argmax decoding. Defaults to True.\n",
"* `top_k` (int, optional): top_k used for sampling. Defaults to 250.\n",
"* `top_p` (float, optional): top_p used for sampling, when set to 0 top_k is used. Defaults to 0.0.\n",
"* `temperature` (float, optional): softmax temperature parameter. Defaults to 1.0.\n",
"* `duration` (float, optional): duration of the generated waveform. Defaults to 30.0.\n",
"* `cfg_coef` (float, optional): coefficient used for classifier free guidance. Defaults to 3.0.\n",
"\n",
"When left unchanged, MusicGen will revert to its default parameters."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"model.set_generation_params(\n",
" use_sampling=True,\n",
" top_k=250,\n",
" duration=30\n",
")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Next, we can go ahead and start generating music using one of the following modes:\n",
"* Unconditional samples using `model.generate_unconditional`\n",
"* Music continuation using `model.generate_continuation`\n",
"* Text-conditional samples using `model.generate`\n",
"* Melody-conditional samples using `model.generate_with_chroma`"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Music Continuation"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import math\n",
"import torchaudio\n",
"import torch\n",
"from audiocraft.utils.notebook import display_audio\n",
"\n",
"def get_bip_bip(bip_duration=0.125, frequency=440,\n",
" duration=0.5, sample_rate=32000, device=\"cuda\"):\n",
" \"\"\"Generates a series of bip bip at the given frequency.\"\"\"\n",
" t = torch.arange(\n",
" int(duration * sample_rate), device=\"cuda\", dtype=torch.float) / sample_rate\n",
" wav = torch.cos(2 * math.pi * 440 * t)[None]\n",
" tp = (t % (2 * bip_duration)) / (2 * bip_duration)\n",
" envelope = (tp >= 0.5).float()\n",
" return wav * envelope"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Here we use a synthetic signal to prompt both the tonality and the BPM\n",
"# of the generated audio.\n",
"res = model.generate_continuation(\n",
" get_bip_bip(0.125).expand(2, -1, -1), \n",
" 32000, ['Jazz jazz and only jazz', \n",
" 'Heartful EDM with beautiful synths and chords'], \n",
" progress=True)\n",
"display_audio(res, 32000)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# You can also use any audio from a file. Make sure to trim the file if it is too long!\n",
"prompt_waveform, prompt_sr = torchaudio.load(\"../assets/bach.mp3\")\n",
"prompt_duration = 2\n",
"prompt_waveform = prompt_waveform[..., :int(prompt_duration * prompt_sr)]\n",
"output = model.generate_continuation(prompt_waveform, prompt_sample_rate=prompt_sr, progress=True, return_tokens=True)\n",
"display_audio(output[0], sample_rate=32000)\n",
"if USE_DIFFUSION_DECODER:\n",
" out_diffusion = mbd.tokens_to_wav(output[1])\n",
" display_audio(out_diffusion, sample_rate=32000)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Text-conditional Generation"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from audiocraft.utils.notebook import display_audio\n",
"\n",
"output = model.generate(\n",
" descriptions=[\n",
" #'80s pop track with bassy drums and synth',\n",
" #'90s rock song with loud guitars and heavy drums',\n",
" #'Progressive rock drum and bass solo',\n",
" #'Punk Rock song with loud drum and power guitar',\n",
" #'Bluesy guitar instrumental with soulful licks and a driving rhythm section',\n",
" #'Jazz Funk song with slap bass and powerful saxophone',\n",
" 'drum and bass beat with intense percussions'\n",
" ],\n",
" progress=True, return_tokens=True\n",
")\n",
"display_audio(output[0], sample_rate=32000)\n",
"if USE_DIFFUSION_DECODER:\n",
" out_diffusion = mbd.tokens_to_wav(output[1])\n",
" display_audio(out_diffusion, sample_rate=32000)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Melody-conditional Generation"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import torchaudio\n",
"from audiocraft.utils.notebook import display_audio\n",
"\n",
"model = MusicGen.get_pretrained('facebook/musicgen-melody')\n",
"model.set_generation_params(duration=8)\n",
"\n",
"melody_waveform, sr = torchaudio.load(\"../assets/bach.mp3\")\n",
"melody_waveform = melody_waveform.unsqueeze(0).repeat(2, 1, 1)\n",
"output = model.generate_with_chroma(\n",
" descriptions=[\n",
" '80s pop track with bassy drums and synth',\n",
" '90s rock song with loud guitars and heavy drums',\n",
" ],\n",
" melody_wavs=melody_waveform,\n",
" melody_sample_rate=sr,\n",
" progress=True, return_tokens=True\n",
")\n",
"display_audio(output[0], sample_rate=32000)\n",
"if USE_DIFFUSION_DECODER:\n",
" out_diffusion = mbd.tokens_to_wav(output[1])\n",
" display_audio(out_diffusion, sample_rate=32000)"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.16"
},
"vscode": {
"interpreter": {
"hash": "b02c911f9b3627d505ea4a19966a915ef21f28afb50dbf6b2115072d27c69103"
}
}
},
"nbformat": 4,
"nbformat_minor": 2
}
================================================
FILE: dockerignore
================================================
cache/
================================================
FILE: docs/AUDIOGEN.md
================================================
# AudioGen: Textually-guided audio generation
AudioCraft provides the code and a model re-implementing AudioGen, a [textually-guided audio generation][audiogen_arxiv]
model that performs text-to-sound generation.
The provided AudioGen reimplementation follows the LM model architecture introduced in [MusicGen][musicgen_arxiv]
and is a single stage auto-regressive Transformer model trained over a 16kHz
EnCodec tokenizer with 4 codebooks sampled at 50 Hz.
This model variant reaches audio quality similar to that of the original implementation introduced in the AudioGen publication
while providing faster generation speed given the smaller frame rate.
**Important note:** The provided models are NOT the original models used to report numbers in the
[AudioGen publication][audiogen_arxiv]. Refer to the model card to learn more about architectural changes.
Listen to samples from the **original AudioGen implementation** in our [sample page][audiogen_samples].
## Model Card
See [the model card](../model_cards/AUDIOGEN_MODEL_CARD.md).
## Installation
Please follow the AudioCraft installation instructions from the [README](../README.md).
AudioCraft requires a GPU with at least 16 GB of memory for running inference with the medium-sized models (~1.5B parameters).
## API and usage
We provide a simple API and one pre-trained model for AudioGen:
`facebook/audiogen-medium`: 1.5B model, text to sound - [🤗 Hub](https://huggingface.co/facebook/audiogen-medium)
You can play with AudioGen by running the jupyter notebook at [`demos/audiogen_demo.ipynb`](../demos/audiogen_demo.ipynb) locally (if you have a GPU).
See below for a quick example of using the API.
```python
import torchaudio
from audiocraft.models import AudioGen
from audiocraft.data.audio import audio_write
model = AudioGen.get_pretrained('facebook/audiogen-medium')
model.set_generation_params(duration=5) # generate 5 seconds.
descriptions = ['dog barking', 'sirene of an emergency vehicle', 'footsteps in a corridor']
wav = model.generate(descriptions) # generates 3 samples.
for idx, one_wav in enumerate(wav):
# Will save under {idx}.wav, with loudness normalization at -14 db LUFS.
audio_write(f'{idx}', one_wav.cpu(), model.sample_rate, strategy="loudness", loudness_compressor=True)
```
## Training
The [AudioGenSolver](../audiocraft/solvers/audiogen.py) implements AudioGen's training pipeline
used to develop the released model. Note that this may not fully reproduce the results presented in the paper.
Similarly to MusicGen, it defines an autoregressive language modeling task over multiple streams of
discrete tokens extracted from a pre-trained EnCodec model (see [EnCodec documentation](./ENCODEC.md)
for more details on how to train such model) with dataset-specific changes for environmental sound
processing.
Note that **we do NOT provide any of the datasets** used for training AudioGen.
### Example configurations and grids
We provide configurations to reproduce the released models and our research.
AudioGen solver configurations are available in [config/solver/audiogen](../config/solver/audiogen).
The base training configuration used for the released models is the following:
[`solver=audiogen/audiogen_base_16khz`](../config/solver/audiogen/audiogen_base_16khz.yaml)
Please find some example grids to train AudioGen at
[audiocraft/grids/audiogen](../audiocraft/grids/audiogen/).
```shell
# text-to-sound
dora grid audiogen.audiogen_base_16khz
```
### Sound dataset and metadata
AudioGen's underlying dataset is an AudioDataset augmented with description metadata.
The AudioGen dataset implementation expects the metadata to be available as `.json` files
at the same location as the audio files or through a specified external folder.
Learn more in the [datasets section](./DATASETS.md).
### Evaluation stage
By default, the evaluation stage also computes the cross-entropy and the perplexity over the
evaluation dataset. Indeed, the objective metrics used for evaluation can be costly to run
or require some extra dependencies. Please refer to the [metrics documentation](./METRICS.md)
for more details on the requirements for each metric.
We provide an off-the-shelf configuration to enable running the objective metrics
for audio generation in
[config/solver/audiogen/evaluation/objective_eval](../config/solver/audiogen/evaluation/objective_eval.yaml).
One can then activate evaluation the following way:
```shell
# using the configuration
dora run solver=audiogen/debug solver/audiogen/evaluation=objective_eval
# specifying each of the fields, e.g. to activate KL computation
dora run solver=audiogen/debug evaluate.metrics.kld=true
```
See [an example evaluation grid](../audiocraft/grids/audiogen/audiogen_pretrained_16khz_eval.py).
### Generation stage
The generation stage allows generating samples conditionally and/or unconditionally and performing
audio continuation (from a prompt). We currently support greedy sampling (argmax), sampling
from softmax with a given temperature, top-K and top-P (nucleus) sampling. The number of samples
generated and the batch size used are controlled by the `dataset.generate` configuration
while the other generation parameters are defined in `generate.lm`.
```shell
# control sampling parameters
dora run solver=audiogen/debug generate.lm.gen_duration=5 generate.lm.use_sampling=true generate.lm.top_k=15
```
## More information
Refer to [MusicGen's instructions](./MUSICGEN.md).
### Learn more
Learn more about AudioCraft training pipelines in the [dedicated section](./TRAINING.md).
## Citation
AudioGen
```
@article{kreuk2022audiogen,
title={Audiogen: Textually guided audio generation},
author={Kreuk, Felix and Synnaeve, Gabriel and Polyak, Adam and Singer, Uriel and D{\'e}fossez, Alexandre and Copet, Jade and Parikh, Devi and Taigman, Yaniv and Adi, Yossi},
journal={arXiv preprint arXiv:2209.15352},
year={2022}
}
```
MusicGen
```
@article{copet2023simple,
title={Simple and Controllable Music Generation},
author={Jade Copet and Felix Kreuk and Itai Gat and Tal Remez and David Kant and Gabriel Synnaeve and Yossi Adi and Alexandre Défossez},
year={2023},
journal={arXiv preprint arXiv:2306.05284},
}
```
## License
See license information in the [model card](../model_cards/AUDIOGEN_MODEL_CARD.md).
[audiogen_arxiv]: https://arxiv.org/abs/2209.15352
[musicgen_arxiv]: https://arxiv.org/abs/2306.05284
[audiogen_samples]: https://felixkreuk.github.io/audiogen/
================================================
FILE: docs/CONDITIONING.md
================================================
# AudioCraft conditioning modules
AudioCraft provides a
[modular implementation of conditioning modules](../audiocraft/modules/conditioners.py)
that can be used with the language model to condition the generation.
The codebase was designed so that the set of currently supported modules
can be easily extended to develop new ways of controlling the generation.
## Conditioning methods
For now, we support 3 main types of conditioning within AudioCraft:
* Text-based conditioning methods
* Waveform-based conditioning methods
* Joint embedding conditioning methods for text and audio projected in a shared latent space.
The Language Model relies on 2 core components that handle processing information:
* The `ConditionProvider` class, that maps metadata to processed conditions leveraging
all the defined conditioners for the given task.
* The `ConditionFuser` class, that takes preprocessed conditions and properly fuses the
conditioning embedding to the language model inputs following a given fusing strategy.
Different conditioners (for text, waveform, joint embeddings...) are provided as torch
modules in AudioCraft and are used internally in the language model to process the
conditioning signals and feed them to the language model.
## Core concepts
### Conditioners
The `BaseConditioner` torch module is the base implementation for all conditioners in audiocraft.
Each conditioner is expected to implement 2 methods:
* The `tokenize` method that is used as a preprocessing method that contains all processing
that can lead to synchronization points (e.g. BPE tokenization with transfer to the GPU).
The output of the tokenize method will then be used to feed the forward method.
* The `forward` method that takes the output of the tokenize method and contains the core computation
to obtain the conditioning embedding along with a mask indicating valid indices (e.g. padding tokens).
### ConditionProvider
The ConditionProvider prepares and provides conditions given a dictionary of conditioners.
Conditioners are specified as a dictionary of attributes and the corresponding conditioner
providing the processing logic for the given attribute.
Similarly to the conditioners, the condition provider works in two steps to avoid synchronization points:
* A `tokenize` method that takes a list of conditioning attributes for the batch,
and runs all tokenize steps for the set of conditioners.
* A `forward` method that takes the output of the tokenize step and runs all the forward steps
for the set of conditioners.
The list of conditioning attributes is passed as a list of `ConditioningAttributes`
that is presented just below.
### ConditionFuser
Once all conditioning signals have been extracted and processed by the `ConditionProvider`
as dense embeddings, they remain to be passed to the language model along with the original
language model inputs.
The `ConditionFuser` handles specifically the logic to combine the different conditions
to the actual model input, supporting different strategies to combine them.
One can therefore define different strategies to combine or fuse the condition to the input, in particular:
* Prepending the conditioning signal to the input with the `prepend` strategy,
* Summing the conditioning signal to the input with the `sum` strategy,
* Combining the conditioning relying on a cross-attention mechanism with the `cross` strategy,
* Using input interpolation with the `input_interpolate` strategy.
### SegmentWithAttributes and ConditioningAttributes: From metadata to conditions
The `ConditioningAttributes` dataclass is the base class for metadata
containing all attributes used for conditioning the language model.
It currently supports the following types of attributes:
* Text conditioning attributes: Dictionary of textual attributes used for text-conditioning.
* Wav conditioning attributes: Dictionary of waveform attributes used for waveform-based
conditioning such as the chroma conditioning.
* JointEmbed conditioning attributes: Dictionary of text and waveform attributes
that are expected to be represented in a shared latent space.
These different types of attributes are the attributes that are processed
by the different conditioners.
`ConditioningAttributes` are extracted from metadata loaded along the audio in the datasets,
provided that the metadata used by the dataset implements the `SegmentWithAttributes` abstraction.
All metadata-enabled datasets used for conditioning in AudioCraft inherit from
the [`audiocraft.data.info_audio_dataset.InfoAudioDataset`](../audiocraft/data/info_audio_dataset.py) class,
and the corresponding metadata inherits from and implements the `SegmentWithAttributes` abstraction.
Refer to the [`audiocraft.data.music_dataset.MusicAudioDataset`](../audiocraft/data/music_dataset.py)
class as an example.
## Available conditioners
### Text conditioners
All text conditioners are expected to inherit from the `TextConditioner` class.
AudioCraft currently provides two text conditioners:
* The `LUTConditioner` that relies on look-up-table of embeddings learned at train time,
and relying on either no tokenizer or a spacy tokenizer. This conditioner is particularly
useful for simple experiments and categorical labels.
* The `T5Conditioner` that relies on a
[pre-trained T5 model](https://huggingface.co/docs/transformers/model_doc/t5)
frozen or fine-tuned at train time to extract the text embeddings.
### Waveform conditioners
All waveform conditioners are expected to inherit from the `WaveformConditioner` class and
consist of a conditioning method that takes a waveform as input. The waveform conditioner
must implement the logic to extract the embedding from the waveform and define the downsampling
factor from the waveform to the resulting embedding.
The `ChromaStemConditioner` conditioner is a waveform conditioner for the chroma features
conditioning used by MusicGen. It takes a given waveform, extracts relevant stems for melody
(namely all non drums and bass stems) using a
[pre-trained Demucs model](https://github.com/facebookresearch/demucs)
and then extracts the chromagram bins from the remaining mix of stems.
### Joint embeddings conditioners
We finally provide support for conditioning based on joint text and audio embeddings through
the `JointEmbeddingConditioner` class and the `CLAPEmbeddingConditioner` that implements such
a conditioning method relying on a [pretrained CLAP model](https://github.com/LAION-AI/CLAP).
## Classifier Free Guidance
We provide a Classifier Free Guidance implementation in AudioCraft. With the classifier free
guidance dropout, all attributes are dropped with the same probability.
## Attribute Dropout
We further provide an attribute dropout strategy. Unlike the classifier free guidance dropout,
the attribute dropout drops given attributes with a defined probability, allowing the model
not to expect all conditioning signals to be provided at once.
## Faster computation of conditions
Conditioners that require some heavy computation on the waveform can be cached, in particular
the `ChromaStemConditioner` or `CLAPEmbeddingConditioner`. You just need to provide the
`cache_path` parameter to them. We recommend running dummy jobs for filling up the cache quickly.
An example is provided in the [musicgen.musicgen_melody_32khz grid](../audiocraft/grids/musicgen/musicgen_melody_32khz.py).
================================================
FILE: docs/DATASETS.md
================================================
# AudioCraft datasets
Our dataset manifest files consist of 1-json-per-line files, potentially gzipped,
as `data.jsonl` or `data.jsonl.gz` files. This JSON contains the path to the audio
file and associated metadata. The manifest files are then provided in the configuration,
as `datasource` sub-configuration. A datasource contains the pointers to the paths of
the manifest files for each AudioCraft stage (or split) along with additional information
(e.g. maximum sample rate to use against this dataset). All the datasources are under the
`dset` group config, with a dedicated configuration file for each dataset.
## Getting started
### Example
See the provided example manifest, which allows using the example dataset
provided under the [dataset folder](../dataset/example).
The manifest files are stored in the [egs folder](../egs/example).
```shell
egs/
example/data.jsonl
```
A datasource is defined in the configuration folder, in the dset group config for this dataset
at [config/dset/audio/example](../config/dset/audio/example.yaml):
```shell
# @package __global__
datasource:
max_sample_rate: 44100
max_channels: 2
train: egs/example
valid: egs/example
evaluate: egs/example
generate: egs/example
```
For a proper dataset, one should create a manifest for each of the splits and specify the correct path
to the given manifest in the datasource for each split.
Then, using a dataset through the configuration can be done by pointing to the
corresponding dataset configuration:
```shell
dset=<dataset_name> # should match the yaml file name
# for example
dset=audio/example
```
### Creating manifest files
Assuming you want to create manifest files to load with AudioCraft's AudioDataset, you can use
the following command to create new manifest files from a given folder containing audio files:
```shell
python -m audiocraft.data.audio_dataset <path_to_dataset_folder> egs/my_dataset/my_dataset_split/data.jsonl.gz
# For example to generate the manifest for dset=audio/example
# note: we don't use any split and we don't compress the jsonl file for this dummy example
python -m audiocraft.data.audio_dataset dataset/example egs/example/data.jsonl
# More info with: python -m audiocraft.data.audio_dataset --help
```
## Additional information
### MusicDataset and metadata
The MusicDataset is an AudioDataset with additional metadata. The MusicDataset expects
the additional metadata to be stored in a JSON file that has the same path as the corresponding
audio file, but with a `.json` extension.
### SoundDataset and metadata
The SoundDataset is an AudioDataset with descriptions metadata. Similarly to the MusicDataset,
the SoundDataset expects the additional metadata to be stored in a JSON file that has the same
path as the corresponding audio file, but with a `.json` extension. Additionally, the SoundDataset
supports an additional parameter pointing to an extra folder `external_metadata_source` containing
all the JSON metadata files given they have the same filename as the audio file.
================================================
FILE: docs/ENCODEC.md
================================================
# EnCodec: High Fidelity Neural Audio Compression
AudioCraft provides the training code for EnCodec, a state-of-the-art deep learning
based audio codec supporting both mono and stereo audio, presented in the
[High Fidelity Neural Audio Compression][arxiv] paper.
Check out our [sample page][encodec_samples].
## Original EnCodec models
The EnCodec models presented in High Fidelity Neural Audio Compression can be accessed
and used with the [EnCodec repository](https://github.com/facebookresearch/encodec).
**Note**: We do not guarantee compatibility between the AudioCraft and EnCodec codebases
and released checkpoints at this stage.
## Installation
Please follow the AudioCraft installation instructions from the [README](../README.md).
## Training
The [CompressionSolver](../audiocraft/solvers/compression.py) implements the audio reconstruction
task to train an EnCodec model. Specifically, it trains an encoder-decoder with a quantization
bottleneck - a SEANet encoder-decoder with Residual Vector Quantization bottleneck for EnCodec -
using a combination of objective and perceptual losses in the forms of discriminators.
The default configuration matches a causal EnCodec training at a single bandwidth.
### Example configuration and grids
We provide sample configuration and grids for training EnCodec models.
The compression configurations are defined in
[config/solver/compression](../config/solver/compression).
The example grids are available at
[audiocraft/grids/compression](../audiocraft/grids/compression).
```shell
# base causal encodec on monophonic audio sampled at 24 khz
dora grid compression.encodec_base_24khz
# encodec model used for MusicGen on monophonic audio sampled at 32 khz
dora grid compression.encodec_musicgen_32khz
```
### Training and valid stages
The model is trained using a combination of objective and perceptual losses.
More specifically, EnCodec is trained with the MS-STFT discriminator along with
objective losses through the use of a loss balancer to effectively weight
the different losses, in an intuitive manner.
### Evaluation stage
Evaluation metrics for audio generation:
* SI-SNR: Scale-Invariant Signal-to-Noise Ratio.
* ViSQOL: Virtual Speech Quality Objective Listener.
Note: Path to the ViSQOL binary (compiled with bazel) needs to be provided in
order to run the ViSQOL metric on the reference and degraded signals.
The metric is disabled by default.
Please refer to the [metrics documentation](./METRICS.md) to learn more.
### Generation stage
The generation stage consists in generating the reconstructed audio from samples
with the current model. The number of samples generated and the batch size used are
controlled by the `dataset.generate` configuration. The output path and audio formats
are defined in the generate stage configuration.
```shell
# generate samples every 5 epochs
dora run solver=compression/encodec_base_24khz generate.every=5
# run with a different dset
dora run solver=compression/encodec_base_24khz generate.path=
# limit the number of samples or use a different batch size
dora grid solver=compression/encodec_base_24khz dataset.generate.num_samples=10 dataset.generate.batch_size=4
```
### Playing with the model
Once you have a model trained, it is possible to get the entire solver, or just
the trained model with the following functions:
```python
from audiocraft.solvers import CompressionSolver
# If you trained a custom model with signature SIG.
model = CompressionSolver.model_from_checkpoint('//sig/SIG')
# If you want to get one of the pretrained models with the `//pretrained/` prefix.
model = CompressionSolver.model_from_checkpoint('//pretrained/facebook/encodec_32khz')
# Or load from a custom checkpoint path
model = CompressionSolver.model_from_checkpoint('/my_checkpoints/foo/bar/checkpoint.th')
# If you only want to use a pretrained model, you can also directly get it
# from the CompressionModel base model class.
from audiocraft.models import CompressionModel
# Here do not put the `//pretrained/` prefix!
model = CompressionModel.get_pretrained('facebook/encodec_32khz')
model = CompressionModel.get_pretrained('dac_44khz')
# Finally, you can also retrieve the full Solver object, with its dataloader etc.
from audiocraft import train
from pathlib import Path
import logging
import os
import sys
# uncomment the following line if you want some detailed logs when loading a Solver.
logging.basicConfig(stream=sys.stderr, level=logging.INFO)
# You must always run the following function from the root directory.
os.chdir(Path(train.__file__).parent.parent)
# You can also get the full solver (only for your own experiments).
# You can provide some overrides to the parameters to make things more convenient.
solver = train.get_solver_from_sig('SIG', {'device': 'cpu', 'dataset': {'batch_size': 8}})
solver.model
solver.dataloaders
```
### Importing / Exporting models
At the moment we do not have a definitive workflow for exporting EnCodec models, for
instance to Hugging Face (HF). We are working on supporting automatic conversion between
AudioCraft and Hugging Face implementations.
We still have some support for fine tuning an EnCodec model coming from HF in AudioCraft,
using for instance `continue_from=//pretrained/facebook/encodec_32khz`.
An AudioCraft checkpoint can be exported in a more compact format (excluding the optimizer etc.)
using `audiocraft.utils.export.export_encodec`. For instance, you could run
```python
from audiocraft.utils import export
from audiocraft import train
xp = train.main.get_xp_from_sig('SIG')
export.export_encodec(
xp.folder / 'checkpoint.th',
'/checkpoints/my_audio_lm/compression_state_dict.bin')
from audiocraft.models import CompressionModel
model = CompressionModel.get_pretrained('/checkpoints/my_audio_lm/compression_state_dict.bin')
from audiocraft.solvers import CompressionSolver
# The two are strictly equivalent, but this function supports also loading from non already exported models.
model = CompressionSolver.model_from_checkpoint('//pretrained//checkpoints/my_audio_lm/compression_state_dict.bin')
```
We will then see how to use this model as a tokenizer for MusicGen/AudioGen in the
[MusicGen documentation](./MUSICGEN.md).
### Learn more
Learn more about AudioCraft training pipelines in the [dedicated section](./TRAINING.md).
## Citation
```
@article{defossez2022highfi,
title={High Fidelity Neural Audio Compression},
author={Défossez, Alexandre and Copet, Jade and Synnaeve, Gabriel and Adi, Yossi},
journal={arXiv preprint arXiv:2210.13438},
year={2022}
}
```
## License
See license information in the [README](../README.md).
[arxiv]: https://arxiv.org/abs/2210.13438
[encodec_samples]: https://ai.honu.io/papers/encodec/samples.html
================================================
FILE: docs/MBD.md
================================================
# MultiBand Diffusion
AudioCraft provides the code and models for MultiBand Diffusion, [From Discrete Tokens to High Fidelity Audio using MultiBand Diffusion][arxiv].
MultiBand diffusion is a collection of 4 models that can decode tokens from
EnCodec tokenizer into waveform audio. You can listen to some examples on the [sample page][mbd_samples].
## Installation
Please follow the AudioCraft installation instructions from the [README](../README.md).
## Usage
We offer a number of ways to use MultiBand Diffusion:
1. The MusicGen demo includes a toggle to try diffusion decoder. You can use the demo locally by running [`python -m demos.musicgen_app --share`](../demos/musicgen_app.py), or through the [MusicGen Colab](https://colab.research.google.com/drive/1JlTOjB-G0A2Hz3h8PK63vLZk4xdCI5QB?usp=sharing).
2. You can play with MusicGen by running the jupyter notebook at [`demos/musicgen_demo.ipynb`](../demos/musicgen_demo.ipynb) locally (if you have a GPU).
## API
We provide a simple API and pre-trained models for MusicGen and for EnCodec at 24 khz for 3 bitrates (1.5 kbps, 3 kbps and 6 kbps).
See below for a quick example of using MultiBandDiffusion with the MusicGen API:
```python
import torchaudio
from audiocraft.models import MusicGen, MultiBandDiffusion
from audiocraft.data.audio import audio_write
model = MusicGen.get_pretrained('facebook/musicgen-melody')
mbd = MultiBandDiffusion.get_mbd_musicgen()
model.set_generation_params(duration=8) # generate 8 seconds.
wav, tokens = model.generate_unconditional(4, return_tokens=True) # generates 4 unconditional audio samples and keep the tokens for MBD generation
descriptions = ['happy rock', 'energetic EDM', 'sad jazz']
wav_diffusion = mbd.tokens_to_wav(tokens)
wav, tokens = model.generate(descriptions, return_tokens=True) # generates 3 samples and keep the tokens.
wav_diffusion = mbd.tokens_to_wav(tokens)
melody, sr = torchaudio.load('./assets/bach.mp3')
# Generates using the melody from the given audio and the provided descriptions, returns audio and audio tokens.
wav, tokens = model.generate_with_chroma(descriptions, melody[None].expand(3, -1, -1), sr, return_tokens=True)
wav_diffusion = mbd.tokens_to_wav(tokens)
for idx, one_wav in enumerate(wav):
# Will save under {idx}.wav and {idx}_diffusion.wav, with loudness normalization at -14 db LUFS for comparing the methods.
audio_write(f'{idx}', one_wav.cpu(), model.sample_rate, strategy="loudness", loudness_compressor=True)
audio_write(f'{idx}_diffusion', wav_diffusion[idx].cpu(), model.sample_rate, strategy="loudness", loudness_compressor=True)
```
For the compression task (and to compare with [EnCodec](https://github.com/facebookresearch/encodec)):
```python
import torch
from audiocraft.models import MultiBandDiffusion
from encodec import EncodecModel
from audiocraft.data.audio import audio_read, audio_write
bandwidth = 3.0 # 1.5, 3.0, 6.0
mbd = MultiBandDiffusion.get_mbd_24khz(bw=bandwidth)
encodec = EncodecModel.encodec_model_24khz()
somepath = ''
wav, sr = audio_read(somepath)
with torch.no_grad():
compressed_encodec = encodec(wav)
compressed_diffusion = mbd.regenerate(wav, sample_rate=sr)
audio_write('sample_encodec', compressed_encodec.squeeze(0).cpu(), mbd.sample_rate, strategy="loudness", loudness_compressor=True)
audio_write('sample_diffusion', compressed_diffusion.squeeze(0).cpu(), mbd.sample_rate, strategy="loudness", loudness_compressor=True)
```
## Training
The [DiffusionSolver](../audiocraft/solvers/diffusion.py) implements our diffusion training pipeline.
It generates waveform audio conditioned on the embeddings extracted from a pre-trained EnCodec model
(see [EnCodec documentation](./ENCODEC.md) for more details on how to train such model).
Note that **we do NOT provide any of the datasets** used for training our diffusion models.
We provide a dummy dataset containing just a few examples for illustrative purposes.
### Example configurations and grids
One can train diffusion models as described in the paper by using this [dora grid](../audiocraft/grids/diffusion/4_bands_base_32khz.py).
```shell
# 4 bands MBD training
dora grid diffusion.4_bands_base_32khz
```
### Learn more
Learn more about AudioCraft training pipelines in the [dedicated section](./TRAINING.md).
## Citation
```
@article{sanroman2023fromdi,
title={From Discrete Tokens to High-Fidelity Audio Using Multi-Band Diffusion},
author={San Roman, Robin and Adi, Yossi and Deleforge, Antoine and Serizel, Romain and Synnaeve, Gabriel and Défossez, Alexandre},
journal={arXiv preprint arXiv:},
year={2023}
}
```
## License
See license information in the [README](../README.md).
[arxiv]: https://dl.fbaipublicfiles.com/encodec/Diffusion/paper.pdf
[mbd_samples]: https://ai.honu.io/papers/mbd/
================================================
FILE: docs/METRICS.md
================================================
# AudioCraft objective metrics
In addition to training losses, AudioCraft provides a set of objective metrics
for audio synthesis and audio generation. As these metrics may require
extra dependencies and can be costly to compute, they are often disabled by default.
This section provides guidance for setting up and using these metrics in
the AudioCraft training pipelines.
## Available metrics
### Audio synthesis quality metrics
#### SI-SNR
We provide an implementation of the Scale-Invariant Signal-to-Noise Ratio in PyTorch.
No specific requirement is needed for this metric. Please activate the metric at the
evaluation stage with the appropriate flag:
```shell
dora run <...> evaluate.metrics.sisnr=true
```
#### ViSQOL
We provide a Python wrapper around the ViSQOL [official implementation](https://github.com/google/visqol)
to conveniently run ViSQOL within the training pipelines.
One must specify the path to the ViSQOL installation through the configuration in order
to enable ViSQOL computations in AudioCraft:
```shell
# the first parameter is used to activate visqol computation while the second specifies
# the path to visqol's library to be used by our python wrapper
dora run <...> evaluate.metrics.visqol=true metrics.visqol.bin=<path_to_visqol>
```
See an example grid: [Compression with ViSQOL](../audiocraft/grids/compression/encodec_musicgen_32khz.py)
To learn more about ViSQOL and how to build ViSQOL binary using bazel, please refer to the
instructions available in the [open source repository](https://github.com/google/visqol).
### Audio generation metrics
#### Frechet Audio Distance
Similarly to ViSQOL, we use a Python wrapper around the Frechet Audio Distance
[official implementation](https://github.com/google-research/google-research/tree/master/frechet_audio_distance)
in TensorFlow.
Note that we had to make several changes to the actual code in order to make it work.
Please refer to the [FrechetAudioDistanceMetric](../audiocraft/metrics/fad.py) class documentation
for more details. We do not plan to provide further support in obtaining a working setup for the
Frechet Audio Distance at this stage.
```shell
# the first parameter is used to activate FAD metric computation while the second specifies
# the path to FAD library to be used by our python wrapper
dora run <...> evaluate.metrics.fad=true metrics.fad.bin=<path_to_fad>
```
See an example grid: [Evaluation with FAD](../audiocraft/grids/musicgen/musicgen_pretrained_32khz_eval.py)
#### Kullback-Leibler Divergence
We provide a PyTorch implementation of the Kullback-Leibler Divergence computed over the probabilities
of the labels obtained by a state-of-the-art audio classifier. We provide our implementation of the KLD
using the [PaSST classifier](https://github.com/kkoutini/PaSST).
In order to use the KLD metric over PaSST, you must install the PaSST library as an extra dependency:
```shell
pip install 'git+https://github.com/kkoutini/passt_hear21@0.0.19#egg=hear21passt'
```
Then similarly, you can use the metric activating the corresponding flag:
```shell
# one could extend the kld metric with additional audio classifier models that can then be picked through the configuration
dora run <...> evaluate.metrics.kld=true metrics.kld.model=passt
```
#### Text consistency
We provide a text-consistency metric, similarly to the MuLan Cycle Consistency from
[MusicLM](https://arxiv.org/pdf/2301.11325.pdf) or the CLAP score used in
[Make-An-Audio](https://arxiv.org/pdf/2301.12661v1.pdf).
More specifically, we provide a PyTorch implementation of a Text consistency metric
relying on a pre-trained [Contrastive Language-Audio Pretraining (CLAP)](https://github.com/LAION-AI/CLAP).
Please install the CLAP library as an extra dependency prior to using the metric:
```shell
pip install laion_clap
```
Then similarly, you can use the metric activating the corresponding flag:
```shell
# one could extend the text consistency metric with additional audio classifier models that can then be picked through the configuration
dora run ... evaluate.metrics.text_consistency=true metrics.text_consistency.model=clap
```
Note that the text consistency metric based on CLAP will require the CLAP checkpoint to be
provided in the configuration.
#### Chroma cosine similarity
Finally, as introduced in MusicGen, we provide a Chroma Cosine Similarity metric in PyTorch.
No specific requirement is needed for this metric. Please activate the metric at the
evaluation stage with the appropriate flag:
```shell
dora run ... evaluate.metrics.chroma_cosine=true
```
#### Comparing against reconstructed audio
For all the above audio generation metrics, we offer the option to compute the metric on the reconstructed audio
fed in EnCodec instead of the generated sample using the flag `<metric>.use_gt=true`.
## Example usage
You will find examples of configuration for the different metrics introduced above in:
* The [musicgen's default solver](../config/solver/musicgen/default.yaml) for all audio generation metrics
* The [compression's default solver](../config/solver/compression/default.yaml) for all audio synthesis metrics
Similarly, we provide different examples in our grids:
* [Evaluation with ViSQOL](../audiocraft/grids/compression/encodec_musicgen_32khz.py)
* [Evaluation with FAD and others](../audiocraft/grids/musicgen/musicgen_pretrained_32khz_eval.py)
================================================
FILE: docs/MUSICGEN.md
================================================
# MusicGen: Simple and Controllable Music Generation
AudioCraft provides the code and models for MusicGen, [a simple and controllable model for music generation][arxiv].
MusicGen is a single stage auto-regressive Transformer model trained over a 32kHz
EnCodec tokenizer with 4 codebooks sampled at 50 Hz.
Unlike existing methods like [MusicLM](https://arxiv.org/abs/2301.11325), MusicGen doesn't require
a self-supervised semantic representation, and it generates all 4 codebooks in one pass. By introducing
a small delay between the codebooks, we show we can predict them in parallel, thus having only 50 auto-regressive
steps per second of audio.
Check out our [sample page][musicgen_samples] or test the available demo!
We use 20K hours of licensed music to train MusicGen. Specifically, we rely on an internal dataset
of 10K high-quality music tracks, and on the ShutterStock and Pond5 music data.
## Model Card
See [the model card](../model_cards/MUSICGEN_MODEL_CARD.md).
## Installation
Please follow the AudioCraft installation instructions from the [README](../README.md).
AudioCraft requires a GPU with at least 16 GB of memory for running inference with the medium-sized models (~1.5B parameters).
## Usage
We offer a number of ways to interact with MusicGen:
1. A demo is also available on the [`facebook/MusicGen` Hugging Face Space](https://huggingface.co/spaces/facebook/MusicGen)
(huge thanks to all the HF team for their support).
2. You can run the extended demo on a Colab:
[colab notebook](https://colab.research.google.com/drive/1JlTOjB-G0A2Hz3h8PK63vLZk4xdCI5QB?usp=sharing)
3. You can use the gradio demo locally by running [`python -m demos.musicgen_app --share`](../demos/musicgen_app.py).
4. You can play with MusicGen by running the jupyter notebook at [`demos/musicgen_demo.ipynb`](../demos/musicgen_demo.ipynb) locally (if you have a GPU).
5. Finally, check out [@camenduru Colab page](https://github.com/camenduru/MusicGen-colab)
which is regularly updated with contributions from @camenduru and the community.
## API
We provide a simple API and 4 pre-trained models. The pre-trained models are:
- `facebook/musicgen-small`: 300M model, text to music only - [🤗 Hub](https://huggingface.co/facebook/musicgen-small)
- `facebook/musicgen-medium`: 1.5B model, text to music only - [🤗 Hub](https://huggingface.co/facebook/musicgen-medium)
- `facebook/musicgen-melody`: 1.5B model, text to music and text+melody to music - [🤗 Hub](https://huggingface.co/facebook/musicgen-melody)
- `facebook/musicgen-large`: 3.3B model, text to music only - [🤗 Hub](https://huggingface.co/facebook/musicgen-large)
We observe the best trade-off between quality and compute with the `facebook/musicgen-medium` or `facebook/musicgen-melody` model.
In order to use MusicGen locally **you must have a GPU**. We recommend 16GB of memory, but smaller
GPUs will be able to generate short sequences, or longer sequences with the `facebook/musicgen-small` model.
See below for a quick example of using the API.
```python
import torchaudio
from audiocraft.models import MusicGen
from audiocraft.data.audio import audio_write
model = MusicGen.get_pretrained('facebook/musicgen-melody')
model.set_generation_params(duration=8) # generate 8 seconds.
wav = model.generate_unconditional(4) # generates 4 unconditional audio samples
descriptions = ['happy rock', 'energetic EDM', 'sad jazz']
wav = model.generate(descriptions) # generates 3 samples.
melody, sr = torchaudio.load('./assets/bach.mp3')
# generates using the melody from the given audio and the provided descriptions.
wav = model.generate_with_chroma(descriptions, melody[None].expand(3, -1, -1), sr)
for idx, one_wav in enumerate(wav):
# Will save under {idx}.wav, with loudness normalization at -14 db LUFS.
audio_write(f'{idx}', one_wav.cpu(), model.sample_rate, strategy="loudness", loudness_compressor=True)
```
## 🤗 Transformers Usage
MusicGen is available in the 🤗 Transformers library from version 4.31.0 onwards, requiring minimal dependencies
and additional packages. Steps to get started:
1. First install the 🤗 [Transformers library](https://github.com/huggingface/transformers) from main:
```shell
pip install git+https://github.com/huggingface/transformers.git
```
2. Run the following Python code to generate text-conditional audio samples:
```py
from transformers import AutoProcessor, MusicgenForConditionalGeneration
processor = AutoProcessor.from_pretrained("facebook/musicgen-small")
model = MusicgenForConditionalGeneration.from_pretrained("facebook/musicgen-small")
inputs = processor(
text=["80s pop track with bassy drums and synth", "90s rock song with loud guitars and heavy drums"],
padding=True,
return_tensors="pt",
)
audio_values = model.generate(**inputs, max_new_tokens=256)
```
3. Listen to the audio samples either in an ipynb notebook:
```py
from IPython.display import Audio
sampling_rate = model.config.audio_encoder.sampling_rate
Audio(audio_values[0].numpy(), rate=sampling_rate)
```
Or save them as a `.wav` file using a third-party library, e.g. `scipy`:
```py
import scipy
sampling_rate = model.config.audio_encoder.sampling_rate
scipy.io.wavfile.write("musicgen_out.wav", rate=sampling_rate, data=audio_values[0, 0].numpy())
```
For more details on using the MusicGen model for inference using the 🤗 Transformers library, refer to the
[MusicGen docs](https://huggingface.co/docs/transformers/main/en/model_doc/musicgen) or the hands-on
[Google Colab](https://colab.research.google.com/github/sanchit-gandhi/notebooks/blob/main/MusicGen.ipynb).
## Training
The [MusicGenSolver](../audiocraft/solvers/musicgen.py) implements MusicGen's training pipeline.
It defines an autoregressive language modeling task over multiple streams of discrete tokens
extracted from a pre-trained EnCodec model (see [EnCodec documentation](./ENCODEC.md)
for more details on how to train such model).
Note that **we do NOT provide any of the datasets** used for training MusicGen.
We provide a dummy dataset containing just a few examples for illustrative purposes.
Please read first the [TRAINING documentation](./TRAINING.md), in particular the Environment Setup section.
### Example configurations and grids
We provide configurations to reproduce the released models and our research.
MusicGen solver configurations are available in [config/solver/musicgen](../config/solver/musicgen),
in particular:
* MusicGen base model for text-to-music:
[`solver=musicgen/musicgen_base_32khz`](../config/solver/musicgen/musicgen_base_32khz.yaml)
* MusicGen model with chromagram-conditioning support:
[`solver=musicgen/musicgen_melody_32khz`](../config/solver/musicgen/musicgen_melody_32khz.yaml)
We provide 3 different scales, e.g. `model/lm/model_scale=small` (300M), or `medium` (1.5B), and `large` (3.3B).
Please find some example grids to train MusicGen at
[audiocraft/grids/musicgen](../audiocraft/grids/musicgen/).
```shell
# text-to-music
dora grid musicgen.musicgen_base_32khz --dry_run --init
# melody-guided music generation
dora grid musicgen.musicgen_melody_32khz --dry_run --init
# Remove the `--dry_run --init` flags to actually schedule the jobs once everything is setup.
```
### Music dataset and metadata
MusicGen's underlying dataset is an AudioDataset augmented with music-specific metadata.
The MusicGen dataset implementation expects the metadata to be available as `.json` files
at the same location as the audio files. Learn more in the [datasets section](./DATASETS.md).
### Audio tokenizers
We support a number of audio tokenizers: either pretrained EnCodec models, [DAC](https://github.com/descriptinc/descript-audio-codec), or your own models.
The tokenizer is controlled with the setting `compression_model_checkpoint`.
For instance,
```bash
# Using the 32kHz EnCodec trained on music
dora run solver=musicgen/debug \
compression_model_checkpoint=//pretrained/facebook/encodec_32khz \
transformer_lm.n_q=4 transformer_lm.card=2048
# Using DAC
dora run solver=musicgen/debug \
compression_model_checkpoint=//pretrained/dac_44khz \
transformer_lm.n_q=9 transformer_lm.card=1024 \
'codebooks_pattern.delay.delays=[0,1,2,3,4,5,6,7,8]'
# Using your own model after export (see ENCODEC.md)
dora run solver=musicgen/debug \
compression_model_checkpoint=//pretrained//checkpoints/my_audio_lm/compression_state_dict.bin \
transformer_lm.n_q=... transformer_lm.card=...
# Using your own model from its training checkpoint,
# where SIG is the Dora signature of the EnCodec XP.
dora run solver=musicgen/debug \
compression_model_checkpoint=//sig/SIG \
transformer_lm.n_q=... transformer_lm.card=...
```
**Warning:** you are responsible for setting the proper value for `transformer_lm.n_q` and `transformer_lm.card` (cardinality of the codebooks). You also have to update the `codebooks_pattern` to match `n_q` as shown in the example for using DAC.
### Fine tuning existing models
You can initialize your model to one of the pretrained models by using the `continue_from` argument, in particular
```bash
# Using pretrained MusicGen model.
dora run solver=musicgen/musicgen_base_32khz model/lm/model_scale=medium continue_from=//pretrained/facebook/musicgen-medium conditioner=text2music
# Using another model you already trained with a Dora signature SIG.
dora run solver=musicgen/musicgen_base_32khz model/lm/model_scale=medium continue_from=//sig/SIG conditioner=text2music
# Or providing manually a path
dora run solver=musicgen/musicgen_base_32khz model/lm/model_scale=medium continue_from=/checkpoints/my_other_xp/checkpoint.th
```
**Warning:** You are responsible for selecting the other parameters accordingly, in a way that makes it compatible
with the model you are fine tuning. Configuration is NOT automatically inherited from the model you continue from. In particular make sure to select the proper `conditioner` and `model/lm/model_scale`.
**Warning:** We currently do not support fine tuning a model with slightly different layers. If you decide
to change some parts, like the conditioning or some other parts of the model, you are responsible for manually crafting a checkpoint file from which we can safely run `load_state_dict`.
If you decide to do so, make sure your checkpoint is saved with `torch.save` and contains a dict
`{'best_state': {'model': model_state_dict_here}}`. Directly give the path to `continue_from` without a `//pretrained/` prefix.
### Caching of EnCodec tokens
It is possible to precompute the EnCodec tokens and other metadata.
An example of generating and using this cache is provided in the [musicgen.musicgen_base_cached_32khz grid](../audiocraft/grids/musicgen/musicgen_base_cached_32khz.py).
### Evaluation stage
By default, evaluation stage is also computing the cross-entropy and the perplexity over the
evaluation dataset. Indeed the objective metrics used for evaluation can be costly to run
or require some extra dependencies. Please refer to the [metrics documentation](./METRICS.md)
for more details on the requirements for each metric.
We provide an off-the-shelf configuration to enable running the objective metrics
for audio generation in
[config/solver/musicgen/evaluation/objective_eval](../config/solver/musicgen/evaluation/objective_eval.yaml).
One can then activate evaluation the following way:
```shell
# using the configuration
dora run solver=musicgen/debug solver/musicgen/evaluation=objective_eval
# specifying each of the fields, e.g. to activate KL computation
dora run solver=musicgen/debug evaluate.metrics.kld=true
```
See [an example evaluation grid](../audiocraft/grids/musicgen/musicgen_pretrained_32khz_eval.py).
### Generation stage
The generation stage allows generating samples conditionally and/or unconditionally and performing
audio continuation (from a prompt). We currently support greedy sampling (argmax), sampling
from softmax with a given temperature, top-K and top-P (nucleus) sampling. The number of samples
generated and the batch size used are controlled by the `dataset.generate` configuration
while the other generation parameters are defined in `generate.lm`.
```shell
# control sampling parameters
dora run solver=musicgen/debug generate.lm.gen_duration=10 generate.lm.use_sampling=true generate.lm.top_k=15
```
#### Listening to samples
Note that generation happens automatically every 25 epochs. You can easily access and
compare samples between models (as long as they are trained) on the same dataset using the
MOS tool. For that first `pip install Flask gunicorn`. Then
```
gunicorn -w 4 -b 127.0.0.1:8895 -t 120 'scripts.mos:app' --access-logfile -
```
And access the tool at [http://127.0.0.1:8895](http://127.0.0.1:8895).
### Playing with the model
Once you have launched some experiments, you can easily get access
to the Solver with the latest trained model using the following snippet.
```python
from audiocraft.solvers.musicgen import MusicGenSolver
solver = MusicGenSolver.get_eval_solver_from_sig('SIG', device='cpu', batch_size=8)
solver.model
solver.dataloaders
```
### Importing / Exporting models
We do not currently support loading a model from the Hugging Face implementation or exporting to it.
If you want to export your model in a way that is compatible with `audiocraft.models.MusicGen`
API, you can run:
```python
from audiocraft.utils import export
from audiocraft import train
xp = train.main.get_xp_from_sig('SIG_OF_LM')
export.export_lm(xp.folder / 'checkpoint.th', '/checkpoints/my_audio_lm/state_dict.bin')
# You also need to bundle the EnCodec model you used !!
## Case 1) you trained your own
xp_encodec = train.main.get_xp_from_sig('SIG_OF_ENCODEC')
export.export_encodec(xp_encodec.folder / 'checkpoint.th', '/checkpoints/my_audio_lm/compression_state_dict.bin')
## Case 2) you used a pretrained model. Give the name you used without the //pretrained/ prefix.
## This will actually not dump the actual model, simply a pointer to the right model to download.
export.export_pretrained_compression_model('facebook/encodec_32khz', '/checkpoints/my_audio_lm/compression_state_dict.bin')
```
Now you can load your custom model with:
```python
import audiocraft.models
musicgen = audiocraft.models.MusicGen.get_pretrained('/checkpoints/my_audio_lm/')
```
### Learn more
Learn more about AudioCraft training pipelines in the [dedicated section](./TRAINING.md).
## FAQ
#### I need help on Windows
@FurkanGozukara made a complete tutorial for [AudioCraft/MusicGen on Windows](https://youtu.be/v-YpvPkhdO4)
#### I need help for running the demo on Colab
Check [@camenduru tutorial on YouTube](https://www.youtube.com/watch?v=EGfxuTy9Eeo).
#### What are top-k, top-p, temperature and classifier-free guidance?
Check out [@FurkanGozukara tutorial](https://github.com/FurkanGozukara/Stable-Diffusion/blob/main/Tutorials/AI-Music-Generation-Audiocraft-Tutorial.md#more-info-about-top-k-top-p-temperature-and-classifier-free-guidance-from-chatgpt).
#### Should I use FSDP or autocast ?
The two are mutually exclusive (because FSDP does autocast on its own).
You can use autocast up to 1.5B (medium), if you have enough RAM on your GPU.
FSDP makes everything more complex but will free up some memory for the actual
activations by sharding the optimizer state.
## Citation
```
@article{copet2023simple,
title={Simple and Controllable Music Generation},
author={Jade Copet and Felix Kreuk and Itai Gat and Tal Remez and David Kant and Gabriel Synnaeve and Yossi Adi and Alexandre Défossez},
year={2023},
journal={arXiv preprint arXiv:2306.05284},
}
```
## License
See license information in the [model card](../model_cards/MUSICGEN_MODEL_CARD.md).
[arxiv]: https://arxiv.org/abs/2306.05284
[musicgen_samples]: https://ai.honu.io/papers/musicgen/
================================================
FILE: docs/TRAINING.md
================================================
# AudioCraft training pipelines
AudioCraft training pipelines are built on top of PyTorch as our core deep learning library
and [Flashy](https://github.com/facebookresearch/flashy) as our training pipeline design library,
and [Dora](https://github.com/facebookresearch/dora) as our experiment manager.
AudioCraft training pipelines are designed to be research and experiment-friendly.
## Environment setup
For the base installation, follow the instructions from the [README.md](../README.md).
Below are some additional instructions for setting up environment to train new models.
### Team and cluster configuration
In order to support multiple teams and clusters, AudioCraft uses an environment configuration.
The team configuration allows specifying cluster-specific configurations (e.g. SLURM configuration),
or convenient mapping of paths between the supported environments.
Each team can have a yaml file under the [configuration folder](../config). To select a team set the
`AUDIOCRAFT_TEAM` environment variable to a valid team name (e.g. `labs` or `default`):
```shell
conda env config vars set AUDIOCRAFT_TEAM=default
```
Alternatively, you can add it to your `.bashrc`:
```shell
export AUDIOCRAFT_TEAM=default
```
If not defined, the environment will default to the `default` team.
The cluster is automatically detected, but it is also possible to override it by setting
the `AUDIOCRAFT_CLUSTER` environment variable.
Based on this team and cluster, the environment is then configured with:
* The dora experiment outputs directory.
* The available slurm partitions: categorized by global and team.
* A shared reference directory: In order to facilitate sharing research models while remaining
agnostic to the used compute cluster, we created the `//reference` symbol that can be used in
YAML config to point to a defined reference folder containing shared checkpoints
(e.g. baselines, models for evaluation...).
**Important:** The default output dir for trained models and checkpoints is under `/tmp/`. This is suitable
only for quick testing. If you are doing anything serious you MUST edit the file `default.yaml` and
properly set the `dora_dir` entries.
#### Overriding environment configurations
You can set the following environment variables to bypass the team's environment configuration:
* `AUDIOCRAFT_CONFIG`: absolute path to a team config yaml file.
* `AUDIOCRAFT_DORA_DIR`: absolute path to a custom dora directory.
* `AUDIOCRAFT_REFERENCE_DIR`: absolute path to the shared reference directory.
## Training pipelines
Each task supported in AudioCraft has its own training pipeline and dedicated solver.
Learn more about solvers and key designs around AudioCraft training pipeline below.
Please refer to the documentation of each task and model for specific information on a given task.
### Solvers
The core training component in AudioCraft is the solver. A solver holds the definition
of how to solve a given task: It implements the training pipeline logic, combining the datasets,
model, optimization criterion and components and the full training loop. We refer the reader
to [Flashy](https://github.com/facebookresearch/flashy) for core principles around solvers.
AudioCraft proposes an initial solver, the `StandardSolver` that is used as the base implementation
for downstream solvers. This standard solver provides a nice base management of logging,
checkpoints loading/saving, xp restoration, etc. on top of the base Flashy implementation.
In AudioCraft, we made the assumption that all tasks are following the same set of stages:
train, valid, evaluate and generation, each relying on a dedicated dataset.
Each solver is responsible for defining the task to solve and the associated stages
of the training loop in order to leave the full ownership of the training pipeline
to the researchers. This includes loading the datasets, building the model and
optimisation components, registering them and defining the execution of each stage.
To create a new solver for a given task, one should extend the StandardSolver
and define each stage of the training loop. One can further customise its own solver
starting from scratch instead of inheriting from the standard solver.
```python
from . import base
from .. import optim
class MyNewSolver(base.StandardSolver):
def __init__(self, cfg: omegaconf.DictConfig):
super().__init__(cfg)
# one can add custom attributes to the solver
self.criterion = torch.nn.L1Loss()
def best_metric(self):
# here optionally specify which metric to use to keep track of best state
return 'loss'
def build_model(self):
# here you can instantiate your models and optimization related objects
# this method will be called by the StandardSolver init method
self.model = ...
# the self.cfg attribute contains the raw configuration
self.optimizer = optim.build_optimizer(self.model.parameters(), self.cfg.optim)
# don't forget to register the states you'd like to include in your checkpoints!
self.register_stateful('model', 'optimizer')
# keep the model best state based on the best value achieved at validation for the given best_metric
self.register_best('model')
# if you want to add EMA around the model
self.register_ema('model')
def build_dataloaders(self):
# here you can instantiate your dataloaders
# this method will be called by the StandardSolver init method
self.dataloaders = ...
...
# For both train and valid stages, the StandardSolver relies on
# a shared common_train_valid implementation that is in charge of
# accessing the appropriate loader, iterate over the data up to
# the specified number of updates_per_epoch, run the ``run_step``
# function that you need to implement to specify the behavior
# and finally update the EMA and collect the metrics properly.
@abstractmethod
def run_step(self, idx: int, batch: tp.Any, metrics: dict):
"""Perform one training or valid step on a given batch.
"""
... # provide your implementation of the solver over a batch
def train(self):
"""Train stage.
"""
return self.common_train_valid('train')
def valid(self):
"""Valid stage.
"""
return self.common_train_valid('valid')
@abstractmethod
def evaluate(self):
"""Evaluate stage.
"""
... # provide your implementation here!
@abstractmethod
def generate(self):
"""Generate stage.
"""
... # provide your implementation here!
```
### About Epochs
AudioCraft Solvers use the concept of Epoch. One epoch doesn't necessarily mean one pass over the entire
dataset, but instead represents the smallest amount of computation that we want to work with before checkpointing.
Typically, we find that having an Epoch time around 30min is ideal both in terms of safety (checkpointing often enough)
and getting updates often enough. One Epoch is at least a `train` stage that lasts for `optim.updates_per_epoch` (2000 by default),
and a `valid` stage. You can control how long the valid stage takes with `dataset.valid.num_samples`.
Other stages (`evaluate`, `generate`) will only happen every X epochs, as given by `evaluate.every` and `generate.every`.
### Models
In AudioCraft, a model is a container object that wraps one or more torch modules together
with potential processing logic to use in a solver. For example, a model would wrap an encoder module,
a quantisation bottleneck module, a decoder and some tensor processing logic. Each of the previous components
can be considered as a small « model unit » on its own but the container model is a practical component
to manipulate and train a set of modules together.
### Datasets
See the [dedicated documentation on datasets](./DATASETS.md).
### Metrics
See the [dedicated documentation on metrics](./METRICS.md).
### Conditioners
AudioCraft language models can be conditioned in various ways and the codebase offers a modular implementation
of different conditioners that can be potentially combined together.
Learn more in the [dedicated documentation on conditioning](./CONDITIONING.md).
### Configuration
AudioCraft's configuration is defined in yaml files and the framework relies on
[hydra](https://hydra.cc/docs/intro/) and [omegaconf](https://omegaconf.readthedocs.io/) to parse
and manipulate the configuration through Dora.
##### :warning: Important considerations around configurations
Our configuration management relies on Hydra and the concept of group configs to structure
and compose configurations. Updating the root default configuration files will then have
an impact on all solvers and tasks.
**One should never change the default configuration files. Instead they should use Hydra config groups in order to store custom configuration.**
Once this configuration is created and used for running experiments, you should not edit it anymore.
Note that as we are using Dora as our experiment manager, all our experiment tracking is based on
signatures computed from delta between configurations.
**One must therefore ensure backward compatibility of the configuration at all times.**
See [Dora's README](https://github.com/facebookresearch/dora) and the
[section below introducing Dora](#running-experiments-with-dora).
##### Configuration structure
The configuration is organized in config groups:
* `conditioner`: default values for conditioning modules.
* `dset`: contains all data source related information (paths to manifest files
and metadata for a given dataset).
* `model`: contains configuration for each model defined in AudioCraft and configurations
for different variants of models.
* `solver`: contains the default configuration for each solver as well as configuration
for each solver task, combining all the above components.
* `teams`: contains the cluster configuration per teams. See environment setup for more details.
The `config.yaml` file is the main configuration that composes the above groups
and contains default configuration for AudioCraft.
##### Solver's core configuration structure
The core configuration structure shared across solver is available in `solvers/default.yaml`.
##### Other configuration modules
AudioCraft configuration contains the different setups we used for our research and publications.
## Running experiments with Dora
### Launching jobs
Try launching jobs for different tasks locally with dora run:
```shell
# run compression task with lightweight encodec
dora run solver=compression/debug
```
Most of the time, the jobs are launched through dora grids, for example:
```shell
# run compression task through debug grid
dora grid compression.debug
```
Learn more about running experiments with Dora below.
### A small introduction to Dora
[Dora](https://github.com/facebookresearch/dora) is the experiment manager tool used in AudioCraft.
Check out the README to learn how Dora works. Here is a quick summary of what to know:
* An XP is a unique set of hyper-parameters with a given signature. The signature is a hash
of those hyper-parameters. We always refer to an XP with its signature, e.g. 9357e12e. We will see
below that one can retrieve the hyper-params and re-run it in a single command.
* In fact, the hash is defined as a delta between the base config and the one obtained
with the config overrides you passed from the command line. This means you must never change
the `conf/**.yaml` files directly, except for editing things like paths. Changing the default values
in the config files means the XP signature won't reflect that change, and wrong checkpoints might be reused.
I know, this is annoying, but the reason is that otherwise, any change to the config file would mean
that all XPs ran so far would see their signature change.
#### Dora commands
```shell
dora info -f 81de367c # this will show the hyper-parameter used by a specific XP.
# Be careful: some overrides might be present twice, and the rightmost one
# will give you the right value for it.
dora run -d -f 81de367c # run an XP with the hyper-parameters from XP 81de367c.
# `-d` is for distributed, it will use all available GPUs.
dora run -d -f 81de367c dataset.batch_size=32 # start from the config of XP 81de367c but change some hyper-params.
# This will give you a new XP with a new signature (e.g. 3fe9c332).
dora info -f SIG -t # will tail the log (if the XP has been scheduled).
# if you need to access the logs of the process for rank > 0, in particular because a crash didn't happen in the main
# process, then use `dora info -f SIG` to get the main log name (ending in something like `/5037674_0_0_log.out`)
# and worker K can be accessed as `/5037674_0_{K}_log.out`.
# This is only for scheduled jobs, for local distributed runs with `-d`, then you should go into the XP folder,
# and look for `worker_{K}.log` logs.
```
An XP runs from a specific folder based on its signature, under the
`<cluster_specific_path>/<user>/experiments/audiocraft/outputs/` folder.
You can safely interrupt a training and resume it, it will reuse any existing checkpoint,
as it will reuse the same folder. If you made some change to the code and need to ignore
a previous checkpoint you can use `dora run --clear [RUN ARGS]`.
If you have a Slurm cluster, you can also use the dora grid command, e.g.
```shell
# run a dummy grid located at `audiocraft/grids/my_grid_folder/my_grid_name.py`
dora grid my_grid_folder.my_grid_name
# Running the following will simply display the grid and also initialize the Dora experiments database.
# You can then simply refer to a config using its signature (e.g. as `dora run -f SIG`).
dora grid my_grid_folder.my_grid_name --dry_run --init
```
Please refer to the [Dora documentation](https://github.com/facebookresearch/dora) for more information.
#### Clearing up past experiments
```shell
# This will cancel all the XPs and delete their folder and checkpoints.
# It will then reschedule them starting from scratch.
dora grid my_grid_folder.my_grid_name --clear
# The following will delete the folder and checkpoint for a single XP,
# and then run it afresh.
dora run [-f BASE_SIG] [ARGS] --clear
```
================================================
FILE: egs/example/data.jsonl
================================================
{"path": "dataset/example/electro_1.mp3", "duration": 15.024, "sample_rate": 48000, "amplitude": null, "weight": null, "info_path": null}
{"path": "dataset/example/electro_2.mp3", "duration": 20.035918367346937, "sample_rate": 44100, "amplitude": null, "weight": null, "info_path": null}
================================================
FILE: model_cards/AUDIOGEN_MODEL_CARD.md
================================================
# AudioGen Model Card
## Model details
**Organization developing the model:** The FAIR team of Meta AI.
**Model date:** This version of AudioGen was trained between July 2023 and August 2023.
**Model version:** This is version 2 of the model, not to be confused with the original AudioGen model published in ["AudioGen: Textually Guided Audio Generation"][audiogen].
In this version (v2), AudioGen was trained on the same data, but with some other differences:
1. This model was trained on 10 seconds (vs. 5 seconds in v1).
2. The discrete representation used under the hood is extracted using a retrained EnCodec model on the environmental sound data, following the EnCodec setup detailed in the ["Simple and Controllable Music Generation" paper][musicgen].
3. No audio mixing augmentations.
**Model type:** AudioGen consists of an EnCodec model for audio tokenization, and an auto-regressive language model based on the transformer architecture for audio modeling. The released model has 1.5B parameters.
**Paper or resource for more information:** More information can be found in the paper [AudioGen: Textually Guided Audio Generation](https://arxiv.org/abs/2209.15352).
**Citation details:** See [AudioGen paper][audiogen]
**License:** Code is released under MIT, model weights are released under CC-BY-NC 4.0.
**Where to send questions or comments about the model:** Questions and comments about AudioGen can be sent via the [GitHub repository](https://github.com/facebookresearch/audiocraft) of the project, or by opening an issue.
## Intended use
**Primary intended use:** The primary use of AudioGen is research on AI-based audio generation, including:
- Research efforts, such as probing and better understanding the limitations of generative models to further improve the state of science
- Generation of sound guided by text to understand current abilities of generative AI models by machine learning amateurs
**Primary intended users:** The primary intended users of the model are researchers in audio, machine learning and artificial intelligence, as well as amateurs seeking to better understand those models.
**Out-of-scope use cases:** The model should not be used on downstream applications without further risk evaluation and mitigation. The model should not be used to intentionally create or disseminate audio pieces that create hostile or alienating environments for people. This includes generating audio that people would foreseeably find disturbing, distressing, or offensive; or content that propagates historical or current stereotypes.
## Metrics
**Models performance measures:** We used the following objective measure to evaluate the model on a standard audio benchmark:
- Frechet Audio Distance computed on features extracted from a pre-trained audio classifier (VGGish)
- Kullback-Leibler Divergence on label distributions extracted from a pre-trained audio classifier (PaSST)
Additionally, we run qualitative studies with human participants, evaluating the performance of the model with the following axes:
- Overall quality of the audio samples;
- Text relevance to the provided text input;
More details on performance measures and human studies can be found in the paper.
**Decision thresholds:** Not applicable.
## Evaluation datasets
The model was evaluated on the [AudioCaps benchmark](https://audiocaps.github.io/).
## Training datasets
The model was trained on the following data sources: a subset of AudioSet (Gemmeke et al., 2017), [BBC sound effects](https://sound-effects.bbcrewind.co.uk/), AudioCaps (Kim et al., 2019), Clotho v2 (Drossos et al., 2020), VGG-Sound (Chen et al., 2020), FSD50K (Fonseca et al., 2021), [Free To Use Sounds](https://www.freetousesounds.com/all-in-one-bundle/), [Sonniss Game Effects](https://sonniss.com/gameaudiogdc), [WeSoundEffects](https://wesoundeffects.com/we-sound-effects-bundle-2020/), [Paramount Motion - Odeon Cinematic Sound Effects](https://www.paramountmotion.com/odeon-sound-effects).
## Evaluation results
Below are the objective metrics obtained with the released model on AudioCaps (consisting of 10-second long samples). Note that the model differs from the original AudioGen model introduced in the paper, hence the difference in the metrics.
| Model | Frechet Audio Distance | KLD | Text consistency |
|---|---|---|---|
| facebook/audiogen-medium | 1.77 | 1.58 | 0.30 |
More information can be found in the paper [AudioGen: Textually Guided Audio Generation][audiogen], in the Experiments section.
## Limitations and biases
**Limitations:**
- The model is not able to generate realistic vocals.
- The model has been trained with English descriptions and will not perform as well in other languages.
- It is sometimes difficult to assess what types of text descriptions provide the best generations. Prompt engineering may be required to obtain satisfying results.
**Biases:** The datasets used for training may lack diversity and are not representative of all possible sound events. The generated samples from the model will reflect the biases from the training data.
**Risks and harms:** Biases and limitations of the model may lead to generation of samples that may be considered as biased, inappropriate or offensive. We believe that providing the code to reproduce the research and train new models will allow to broaden the application to new and more representative data.
**Use cases:** Users must be aware of the biases, limitations and risks of the model. AudioGen is a model developed for artificial intelligence research on audio generation. As such, it should not be used for downstream applications without further investigation and mitigation of risks.
[musicgen]: https://arxiv.org/abs/2306.05284
[audiogen]: https://arxiv.org/abs/2209.15352
================================================
FILE: model_cards/MUSICGEN_MODEL_CARD.md
================================================
# MusicGen Model Card
## Model details
**Organization developing the model:** The FAIR team of Meta AI.
**Model date:** MusicGen was trained between April 2023 and May 2023.
**Model version:** This is version 1 of the model.
**Model type:** MusicGen consists of an EnCodec model for audio tokenization, and an auto-regressive language model based on the transformer architecture for music modeling. The model comes in different sizes: 300M, 1.5B and 3.3B parameters; and two variants: a model trained for text-to-music generation task and a model trained for melody-guided music generation.
**Paper or resources for more information:** More information can be found in the paper [Simple and Controllable Music Generation][arxiv].
**Citation details:** See [our paper][arxiv]
**License:** Code is released under MIT, model weights are released under CC-BY-NC 4.0.
**Where to send questions or comments about the model:** Questions and comments about MusicGen can be sent via the [GitHub repository](https://github.com/facebookresearch/audiocraft) of the project, or by opening an issue.
## Intended use
**Primary intended use:** The primary use of MusicGen is research on AI-based music generation, including:
- Research efforts, such as probing and better understanding the limitations of generative models to further improve the state of science
- Generation of music guided by text or melody to understand current abilities of generative AI models by machine learning amateurs
**Primary intended users:** The primary intended users of the model are researchers in audio, machine learning and artificial intelligence, as well as amateurs seeking to better understand those models.
**Out-of-scope use cases:** The model should not be used on downstream applications without further risk evaluation and mitigation. The model should not be used to intentionally create or disseminate music pieces that create hostile or alienating environments for people. This includes generating music that people would foreseeably find disturbing, distressing, or offensive; or content that propagates historical or current stereotypes.
## Metrics
**Models performance measures:** We used the following objective measure to evaluate the model on a standard music benchmark:
- Frechet Audio Distance computed on features extracted from a pre-trained audio classifier (VGGish)
- Kullback-Leibler Divergence on label distributions extracted from a pre-trained audio classifier (PaSST)
- CLAP Score between audio embedding and text embedding extracted from a pre-trained CLAP model
Additionally, we run qualitative studies with human participants, evaluating the performance of the model with the following axes:
- Overall quality of the music samples;
- Text relevance to the provided text input;
- Adherence to the melody for melody-guided music generation.
More details on performance measures and human studies can be found in the paper.
**Decision thresholds:** Not applicable.
## Evaluation datasets
The model was evaluated on the [MusicCaps benchmark](https://www.kaggle.com/datasets/googleai/musiccaps) and on an in-domain held-out evaluation set, with no artist overlap with the training set.
## Training datasets
The model was trained on licensed data using the following sources: the [Meta Music Initiative Sound Collection](https://www.fb.com/sound), [Shutterstock music collection](https://www.shutterstock.com/music) and the [Pond5 music collection](https://www.pond5.com/). See the paper for more details about the training set and corresponding preprocessing.
## Evaluation results
Below are the objective metrics obtained on MusicCaps with the released model. Note that for the publicly released models, we had all the datasets go through a state-of-the-art music source separation method, namely using the open source [Hybrid Transformer for Music Source Separation](https://github.com/facebookresearch/demucs) (HT-Demucs), in order to keep only the instrumental part. This explains the difference in objective metrics with the models used in the paper.
| Model | Frechet Audio Distance | KLD | Text Consistency | Chroma Cosine Similarity |
|---|---|---|---|---|
| facebook/musicgen-small | 4.88 | 1.42 | 0.27 | - |
| facebook/musicgen-medium | 5.14 | 1.38 | 0.28 | - |
| facebook/musicgen-large | 5.48 | 1.37 | 0.28 | - |
| facebook/musicgen-melody | 4.93 | 1.41 | 0.27 | 0.44 |
More information can be found in the paper [Simple and Controllable Music Generation][arxiv], in the Results section.
## Limitations and biases
**Data:** The data sources used to train the model are created by music professionals and covered by legal agreements with the right holders. The model is trained on 20K hours of data, we believe that scaling the model on larger datasets can further improve the performance of the model.
**Mitigations:** Vocals have been removed from the data source using corresponding tags, and then using a state-of-the-art music source separation method, namely using the open source [Hybrid Transformer for Music Source Separation](https://github.com/facebookresearch/demucs) (HT-Demucs).
**Limitations:**
- The model is not able to generate realistic vocals.
- The model has been trained with English descriptions and will not perform as well in other languages.
- The model does not perform equally well for all music styles and cultures.
- The model sometimes generates end of songs, collapsing to silence.
- It is sometimes difficult to assess what types of text descriptions provide the best generations. Prompt engineering may be required to obtain satisfying results.
**Biases:** The source of data is potentially lacking diversity and all music cultures are not equally represented in the dataset. The model may not perform equally well on the wide variety of music genres that exists. The generated samples from the model will reflect the biases from the training data. Further work on this model should include methods for balanced and just representations of cultures, for example, by scaling the training data to be both diverse and inclusive.
**Risks and harms:** Biases and limitations of the model may lead to generation of samples that may be considered as biased, inappropriate or offensive. We believe that providing the code to reproduce the research and train new models will allow to broaden the application to new and more representative data.
**Use cases:** Users must be aware of the biases, limitations and risks of the model. MusicGen is a model developed for artificial intelligence research on controllable music generation. As such, it should not be used for downstream applications without further investigation and mitigation of risks.
[arxiv]: https://arxiv.org/abs/2306.05284
================================================
FILE: models/Put your models here.txt
================================================
nothing here
================================================
FILE: mypy.ini
================================================
[mypy]
[mypy-treetable,torchaudio.*,soundfile,einops.*,av.*,tqdm.*,num2words.*,spacy,xformers.*,scipy,huggingface_hub,transformers,dac.*]
ignore_missing_imports = True
================================================
FILE: requirements.txt
================================================
# please make sure you have already a pytorch install that is cuda enabled!
av
einops
flashy>=0.0.1
hydra-core>=1.1
hydra_colorlog
julius
num2words
numpy
sentencepiece
spacy==3.5.2
torch>=2.0.0
torchaudio>=2.0.0
huggingface_hub
tqdm
transformers>=4.31.0 # need Encodec there.
xformers
demucs
librosa
gradio
torchmetrics
encodec
pytaglib
================================================
FILE: scripts/__init__.py
================================================
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
================================================
FILE: scripts/mos.py
================================================
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
"""
To run this script, from the root of the repo. Make sure to have Flask installed
FLASK_DEBUG=1 FLASK_APP=scripts.mos flask run -p 4567
# or if you have gunicorn
gunicorn -w 4 -b 127.0.0.1:8895 -t 120 'scripts.mos:app' --access-logfile -
"""
from collections import defaultdict
from functools import wraps
from hashlib import sha1
import json
import math
from pathlib import Path
import random
import typing as tp
from flask import Flask, redirect, render_template, request, session, url_for
from audiocraft import train
from audiocraft.utils.samples.manager import get_samples_for_xps
# Maximum number of samples shown (and rated) on a single survey page.
SAMPLES_PER_PAGE = 8
# Ratings are integers between 1 and MAX_RATING, both included.
MAX_RATING = 5
# Everything this app persists lives in `mos_storage`, inside the Dora directory.
# Note that `mkdir` is called without `parents=True`: the Dora directory must already exist.
storage = Path(train.main.dora.dir / 'mos_storage')
storage.mkdir(exist_ok=True)
# Each survey gets its own sub-folder of `surveys`, named after the survey signature.
surveys = storage / 'surveys'
surveys.mkdir(exist_ok=True)
# Two levels above the `audiocraft.train` module file (presumably the repository root),
# used to locate the static files and templates under `scripts/`.
magma_root = Path(train.__file__).parent.parent
app = Flask('mos', static_folder=str(magma_root / 'scripts/static'),
            template_folder=str(magma_root / 'scripts/templates'))
# NOTE(review): this secret key is hard-coded and public. Flask uses it to sign the
# session cookie, so anyone can forge a session; load it from the environment
# before exposing this app beyond a trusted network.
app.secret_key = b'audiocraft makes the best songs'
def normalize_path(path: Path):
    """Express `path` relative to the `xps` folder of the Dora directory.

    Both `path` and the Dora directory are resolved first. `Path.relative_to`
    raises a `ValueError` if `path` is not located under that folder.
    """
    return path.resolve().relative_to(train.main.dora.dir.resolve() / 'xps')
def get_full_path(normalized_path: Path):
    """Inverse of `normalize_path`: prepend the resolved `xps` folder of the Dora directory.
    """
    xps_root = train.main.dora.dir.resolve() / 'xps'
    return xps_root / normalized_path
def get_signature(xps: tp.List[str]):
    """Compute a short identifier for a list of XP signatures.

    The list is serialized to JSON, hashed with SHA-1, and the first 10 hex
    characters of the digest are returned. The result depends on the order
    of the items in `xps`.
    """
    payload = json.dumps(xps).encode()
    digest = sha1(payload).hexdigest()
    return digest[:10]
def ensure_logged(func):
    """Decorator for views that require a logged-in user.

    When the session holds no `user`, the request is redirected to the `login`
    view, passing the current URL as `redirect_to`. Otherwise `func` is called
    with the original arguments.
    """
    @wraps(func)
    def _guarded(*args, **kwargs):
        if session.get('user') is not None:
            return func(*args, **kwargs)
        return redirect(url_for('login', redirect_to=request.url))

    return _guarded
@app.route('/login', methods=['GET', 'POST'])
def login():
    """Login user if not already, then redirect.

    On GET without a session user, or on POST with an invalid user name, the
    login form is rendered (with an error message in the latter case). Once a
    user is known it is stored in the session and the browser is redirected to
    the `redirect_to` query argument, or to the index page when that argument
    is missing or points to another host.
    """
    from urllib.parse import urlparse  # local import, only needed by this view

    user = session.get('user')
    if user is None:
        error = None
        if request.method == 'POST':
            # `.get` so that a missing field shows the form error instead of a bare 400.
            user = request.form.get('user')
            if not user:
                error = 'User cannot be empty'
            elif '/' in user or '\\' in user:
                # The user name is later embedded in the name of the result file
                # written by the survey view: it must not be able to change folder.
                error = 'User cannot contain path separators'
        if user is None or error:
            return render_template('login.html', error=error)
    assert user
    session['user'] = user

    redirect_to = request.args.get('redirect_to')
    if redirect_to is not None:
        # `redirect_to` comes from the query string: only follow it when it stays on
        # this host, otherwise the login page could be used as an open redirect.
        target = urlparse(redirect_to)
        same_host = target.netloc in ('', request.host)
        plain_scheme = target.scheme in ('', 'http', 'https')
        if not (same_host and plain_scheme) or '\\' in redirect_to:
            redirect_to = None
    if redirect_to is None:
        redirect_to = url_for('index')
    return redirect(redirect_to)
@app.route('/', methods=['GET', 'POST'])
@ensure_logged
def index():
    """Offer to create a new study.

    On POST, the `xps` form field is read as a whitespace separated list of
    names. Each name is either an XP signature (a folder under `xps` in the
    Dora directory) or a grid name (a folder under `grids`, whose symlinked
    children are taken as XP signatures). When every name is valid, a survey
    folder holding a `manifest.json` is created and the user is redirected to
    the survey. Otherwise the index page is rendered with the errors.
    """
    errors = []
    if request.method == 'POST':
        dora_dir = train.main.dora.dir
        xps = set()
        # `split()` without argument already discards surrounding whitespace.
        for xp_or_grid in request.form['xps'].split():
            if (dora_dir / 'xps' / xp_or_grid).exists():
                xps.add(xp_or_grid)
                continue
            grid_path = dora_dir / 'grids' / xp_or_grid
            if grid_path.exists():
                xps.update(child.name for child in grid_path.iterdir() if child.is_symlink())
                # A valid grid must not fall through to the error below.
                continue
            errors.append(f'{xp_or_grid} is neither an XP nor a grid!')
        if not xps and not errors:
            # Empty submission (or grid without any XP): report it rather than crash.
            errors.append('Please provide at least one XP or grid.')
        blind = 'true' if request.form.get('blind') == 'on' else 'false'
        if not errors:
            # Sorted so that the same set of XPs always yields the same signature:
            # the iteration order of a set of strings changes between processes.
            xps = sorted(xps)
            signature = get_signature(xps)
            manifest = {
                'xps': xps,
            }
            survey_path = surveys / signature
            survey_path.mkdir(exist_ok=True)
            with open(survey_path / 'manifest.json', 'w') as f:
                json.dump(manifest, f, indent=2)
            return redirect(url_for('survey', blind=blind, signature=signature))
    return render_template('index.html', errors=errors)
@app.route('/survey/<signature>', methods=['GET', 'POST'])
@ensure_logged
def survey(signature):
    """Show one page of a survey and store the ratings submitted for it.

    The content of the page only depends on the survey and on the `seed` query
    argument: sample ids are sorted, shuffled with `random.Random(seed)` and the
    first `SAMPLES_PER_PAGE` are kept. When `blind` is set, the models of each
    displayed sample are shuffled with the same generator.

    On a POST where every displayed model received a valid rating, the ratings
    are written to `results/{user}_{seed}.json` in the survey folder and the
    user is redirected to the page for `seed + 1`. Otherwise the page is
    rendered again, with an error attached to each model lacking a rating.

    Args:
        signature: signature of the survey, taken from the URL.
    """
    accepted_flags = ['true', 'on', 'True']
    query = request.args
    success = query.get('success', False)
    seed = int(query.get('seed', 4321))
    blind = query.get('blind', 'false') in accepted_flags
    exclude_prompted = query.get('exclude_prompted', 'false') in accepted_flags
    exclude_unprompted = query.get('exclude_unprompted', 'false') in accepted_flags
    max_epoch = int(query.get('max_epoch', '-1'))

    survey_path = surveys / signature
    assert survey_path.exists(), survey_path

    user = session['user']
    result_folder = survey_path / 'results'
    result_folder.mkdir(exist_ok=True)
    result_file = result_folder / f'{user}_{seed}.json'

    with open(survey_path / 'manifest.json') as f:
        manifest = json.load(f)

    xps = [train.main.get_xp_from_sig(xp) for xp in manifest['xps']]
    names, ref_name = train.main.get_names(xps)

    samples_kwargs = {
        'exclude_prompted': exclude_prompted,
        'exclude_unprompted': exclude_unprompted,
        'max_epoch': max_epoch,
    }
    matched_samples = get_samples_for_xps(xps, epoch=-1, **samples_kwargs)  # fetch latest epoch
    # For each sample id, one entry per XP, in the same order as `xps`.
    models_by_id = {
        sample_id: [{
            'xp': xps[idx],
            'xp_name': names[idx],
            'model_id': f'{xps[idx].sig}-{sample.id}',
            'sample': sample,
            'is_prompted': sample.prompt is not None,
            'errors': [],
        } for idx, sample in enumerate(samples)]
        for sample_id, samples in matched_samples.items()
    }
    # The epoch displayed for each XP is read from the first group of matched samples.
    first_samples = next(iter(matched_samples.values()), [])
    experiments = [
        {'xp': xp, 'name': names[idx], 'epoch': first_samples[idx].epoch}
        for idx, xp in enumerate(xps)
    ]

    # Sort before shuffling so that the page only depends on the seed.
    keys = sorted(matched_samples.keys())
    rng = random.Random(seed)
    rng.shuffle(keys)
    model_ids = keys[:SAMPLES_PER_PAGE]

    if blind:
        for key in model_ids:
            rng.shuffle(models_by_id[key])

    ok = True
    if request.method == 'POST':
        all_samples_results = []
        for sample_id in model_ids:
            models = models_by_id[sample_id]
            result = {
                'id': sample_id,
                'is_prompted': models[0]['is_prompted'],
                'models': {}
            }
            all_samples_results.append(result)
            for model in models:
                # `.get`: a model left unrated may simply be absent from the form, which
                # must lead to the error message below rather than to a 400 response.
                raw_rating = request.form.get(model['model_id'])
                try:
                    rating = int(raw_rating) if raw_rating else None
                except ValueError:
                    rating = None
                if rating is not None and 1 <= rating <= MAX_RATING:
                    result['models'][model['xp'].sig] = rating
                    model['rating'] = rating
                else:
                    # Missing, non numeric or out of range: ask again.
                    ok = False
                    model['errors'].append('Please rate this model.')
        if ok:
            result = {
                'results': all_samples_results,
                'seed': seed,
                'user': user,
                'blind': blind,
                'exclude_prompted': exclude_prompted,
                'exclude_unprompted': exclude_unprompted,
            }
            print(result)
            with open(result_file, 'w') as f:
                json.dump(result, f)
            seed = seed + 1
            return redirect(url_for(
                'survey', signature=signature, blind=blind, seed=seed,
                exclude_prompted=exclude_prompted, exclude_unprompted=exclude_unprompted,
                max_epoch=max_epoch, success=True))

    ratings = list(range(1, MAX_RATING + 1))
    return render_template(
        'survey.html', ratings=ratings, blind=blind, seed=seed, signature=signature, success=success,
        exclude_prompted=exclude_prompted, exclude_unprompted=exclude_unprompted, max_epoch=max_epoch,
        experiments=experiments, models_by_id=models_by_id, model_ids=model_ids, errors=[],
        ref_name=ref_name, already_filled=result_file.exists())
@app.route('/audio/<path:path>')
def audio(path: str):
    """Serve the `.mp3` or `.wav` file located at the absolute path `/{path}`.

    Args:
        path: path of the file, taken from the URL, without its leading `/`.
    Returns:
        The file content with the matching `Content-Type`, or a 404 response
        when the extension is not supported or the file does not exist.
    """
    # NOTE(review): this view is not protected by `ensure_logged` and can read any
    # .mp3/.wav file on the host. Consider restricting it to the Dora directory.
    content_types = {'.mp3': 'audio/mpeg', '.wav': 'audio/wav'}
    full_path = Path('/') / path
    content_type = content_types.get(full_path.suffix)
    # Explicit check rather than `assert`: assertions are removed when running with -O,
    # which would turn this view into an arbitrary file reader.
    if content_type is None or not full_path.is_file():
        return 'Audio file not found', 404
    return full_path.read_bytes(), {'Content-Type': content_type}
def mean(x):
    """Return the arithmetic mean of the values in `x`.

    Raises `ZeroDivisionError` when `x` is empty.
    """
    total = sum(x)
    count = len(x)
    return total / count
def std(x):
    """Return the population standard deviation of `x` (squared deviations divided by `len(x)`).
    """
    center = mean(x)
    squared_deviations = [(value - center) ** 2 for value in x]
    variance = sum(squared_deviations) / len(x)
    return math.sqrt(variance)
@app.route('/results/<signature>')
@ensure_logged
def results(signature):
    """Aggregate all the ratings collected for a survey and display them per model.

    Every `.json` file of the `results` folder of the survey is read. For each
    model (XP signature), the page shows the number of ratings, their mean and
    the half width of a 95% confidence interval on that mean.

    Args:
        signature: signature of the survey, taken from the URL.
    """
    survey_path = surveys / signature
    assert survey_path.exists(), survey_path
    result_folder = survey_path / 'results'
    result_folder.mkdir(exist_ok=True)

    # All ratings received by each model, over every user and every page.
    ratings_per_model = defaultdict(list)
    users = []
    # Sorted so that the order of `users` does not depend on the file system.
    for result_file in sorted(result_folder.iterdir()):
        if result_file.suffix != '.json':
            continue
        with open(result_file) as f:
            payload = json.load(f)
        # One entry per result file: a user who rated several pages appears several times.
        users.append(payload['user'])
        for result in payload['results']:
            for sig, rating in result['models'].items():
                ratings_per_model[sig].append(rating)

    fmt = '{:.2f}'
    models = []
    for model in sorted(ratings_per_model.keys()):
        ratings = ratings_per_model[model]
        models.append({
            'sig': model,
            'samples': len(ratings),
            'mean_rating': fmt.format(mean(ratings)),
            # Despite its name this is not the standard deviation, but 1.96 times the
            # standard error of the mean, i.e. the half width of a 95% confidence
            # interval when assuming gaussianity.
            'std_rating': fmt.format(1.96 * std(ratings) / len(ratings)**0.5),
        })
    return render_template('results.html', signature=signature, models=models, users=users)
================================================
FILE: scripts/resample_dataset.py
================================================
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
"""Resampling script.
"""
import argparse
from pathlib import Path
import shutil
import typing as tp
import submitit
import tqdm
from audiocraft.data.audio import audio_read, audio_write
from audiocraft.data.audio_dataset import load_audio_meta, find_audio_files
from audiocraft.data.audio_utils import convert_audio
from audiocraft.environment import AudioCraftEnvironment
def read_txt_files(path: tp.Union[str, Path]):
    """Read a list of audio file paths from a text file, one path per line.

    Trailing whitespace is removed from each line. Blank lines and paths ending
    with `.json`, `.txt` or `.csv` (metadata rather than audio) are dropped.

    Args:
        path: location of the text file to read.
    Returns:
        The list of remaining paths, in file order.
    """
    # Read from the `path` argument: this function used to rely on a global `args`,
    # which only exists when the module is run as a script.
    with open(path) as f:
        lines = [line.rstrip() for line in f]
    print(f"Read {len(lines)} in .txt")
    excluded_suffixes = {'.json', '.txt', '.csv'}
    lines = [line for line in lines if line and Path(line).suffix not in excluded_suffixes]
    print(f"Filtered and keep {len(lines)} from .txt")
    return lines
def read_egs_files(path: tp.Union[str, Path]):
    """List the audio paths referenced by a metadata manifest.

    Args:
        path: either a manifest file, or a directory containing a `data.jsonl`
            or `data.jsonl.gz` manifest (looked up in that order).
    Returns:
        The `path` attribute of every entry returned by `load_audio_meta`.
    Raises:
        ValueError: if `path` is a directory containing neither manifest.
    """
    path = Path(path)
    if path.is_dir():
        candidates = (path / name for name in ('data.jsonl', 'data.jsonl.gz'))
        manifest = next((candidate for candidate in candidates if candidate.exists()), None)
        if manifest is None:
            raise ValueError("Don't know where to read metadata from in the dir. "
                             "Expecting either a data.jsonl or data.jsonl.gz file but none found.")
        path = manifest
    return [entry.path for entry in load_audio_meta(path)]
def process_dataset(args, n_shards: int, node_index: int, task_index: tp.Optional[int] = None):
    """Resample the files assigned to one shard and write them under `args.out_path`.

    The file list comes from `args.files_path` (a `.txt` list or an egs manifest) or,
    when that is `None`, from scanning `args.root_path`. The file at position `idx` of
    that list is handled by the shard for which `idx % n_shards == shard_index`.
    A `.done` token is created next to each output so that files already processed are
    skipped. Errors raised while handling a file are printed and the file is skipped.

    Args:
        args: Parsed command line arguments of this script.
        n_shards (int): Total number of shards the file list is split into.
        node_index (int): Index of the submitted job, combined with `task_index` and
            `args.tasks_per_node` to derive the shard index.
        task_index (int, optional): Index of the task within the job. When `None`,
            it is read from `submitit.JobEnvironment().global_rank`.
    """
    if task_index is None:
        task_index = submitit.JobEnvironment().global_rank
    shard_index = node_index * args.tasks_per_node + task_index

    # Build the list of candidate files.
    if args.files_path is not None:
        if Path(args.files_path).suffix == '.txt':
            print(f"Reading file list from .txt file: {args.files_path}")
            lines = read_txt_files(args.files_path)
        else:
            print(f"Reading file list from egs: {args.files_path}")
            lines = read_egs_files(args.files_path)
    else:
        lines = [m.path for m in find_audio_files(args.root_path, resolve=False, progress=True, workers=8)]

    total_files = len(lines)
    print(
        f"Total of {total_files} processed with {n_shards} shards. " +
        f"Current idx = {shard_index} -> {total_files // n_shards} files to process"
    )

    # Prefix stripped from each input path to obtain its location relative to the root.
    root_path = str(args.root_path)
    if not root_path.endswith('/'):
        root_path += '/'

    for idx, line in tqdm.tqdm(enumerate(lines)):
        if idx % n_shards != shard_index:
            # this file belongs to another shard
            continue

        path = str(AudioCraftEnvironment.apply_dataset_mappers(line))
        assert path.startswith(str(root_path)), \
            f"Mismatch between path and provided root: {path} VS {root_path}"

        try:
            metadata_path = Path(path).with_suffix('.json')
            out_path = args.out_path / path[len(root_path):]
            out_metadata_path = out_path.with_suffix('.json')
            out_done_token = out_path.with_suffix('.done')

            if out_done_token.exists():
                # already processed
                continue

            print(idx, out_path, path)
            mix, sr = audio_read(path)

            # Requested channel count when it is a positive number, else the one of the file.
            if args.channels is not None and args.channels > 0:
                out_channels = args.channels
            else:
                out_channels = mix.size(0)
            # never output more than stereo
            if out_channels > 2:
                print(f"Mix has more than two channels: {out_channels}, enforcing 2 channels")
                out_channels = 2

            out_sr = sr if args.sample_rate is None else args.sample_rate
            out_wav = convert_audio(mix, sr, out_sr, out_channels)
            audio_write(out_path.with_suffix(''), out_wav, sample_rate=out_sr,
                        format=args.format, normalize=False, strategy='clip')

            if not metadata_path.exists():
                print(f"No metadata found at {str(metadata_path)}")
            else:
                shutil.copy(metadata_path, out_metadata_path)
            out_done_token.touch()
        except Exception as e:
            print(f"Error processing file line: {line}, {e}")
if __name__ == '__main__':
    # Command line definition; options are declared in the same order as before.
    parser = argparse.ArgumentParser(description="Resample dataset with SLURM.")
    parser.add_argument("--log_root", type=Path, default=Path.home() / 'tmp' / 'resample_logs')
    parser.add_argument("--files_path", type=Path,
                        help="List of files to process, either .txt (one file per line) or a jsonl[.gz].")
    parser.add_argument("--root_path", type=Path, required=True,
                        help="When rewriting paths, this will be the prefix to remove.")
    parser.add_argument("--out_path", type=Path, required=True,
                        help="When rewriting paths, `root_path` will be replaced by this.")
    parser.add_argument("--xp_name", type=str, default="shutterstock")
    parser.add_argument("--nodes", type=int, default=4)
    parser.add_argument("--tasks_per_node", type=int, default=20)
    parser.add_argument("--cpus_per_task", type=int, default=4)
    parser.add_argument("--memory_gb", type=int, help="Memory in GB.")
    parser.add_argument("--format", type=str, default="wav")
    parser.add_argument("--sample_rate", type=int, default=32000)
    parser.add_argument("--channels", type=int)
    parser.add_argument("--partition", default='learnfair')
    parser.add_argument("--qos")
    parser.add_argument("--account")
    parser.add_argument("--timeout", type=int, default=4320)
    parser.add_argument('--debug', action='store_true', help='debug mode (local run)')
    args = parser.parse_args()

    # One shard per task, over all nodes.
    n_shards = args.tasks_per_node * args.nodes
    if args.files_path is None:
        print("Warning: --files_path not provided, not recommended when processing more than 10k files.")

    if not args.debug:
        log_folder = Path(args.log_root) / args.xp_name / '%j'
        print(f"Logging to: {log_folder}")
        log_folder.parent.mkdir(parents=True, exist_ok=True)
        executor = submitit.AutoExecutor(folder=str(log_folder))

        # The QOS and account are only forwarded when a QOS was requested.
        slurm_kwargs = {'slurm_partition': args.partition}
        if args.qos:
            slurm_kwargs.update(slurm_qos=args.qos, slurm_account=args.account)
        executor.update_parameters(**slurm_kwargs)
        executor.update_parameters(
            slurm_job_name=args.xp_name, timeout_min=args.timeout,
            cpus_per_task=args.cpus_per_task, tasks_per_node=args.tasks_per_node, nodes=1)
        if args.memory_gb:
            executor.update_parameters(mem=f'{args.memory_gb}GB')

        # Submit one job per node inside a single batch, then wait for each of them.
        with executor.batch():
            jobs = [
                executor.submit(process_dataset, args, n_shards=n_shards, node_index=node_index)
                for node_index in range(args.nodes)
            ]
        for job in jobs:
            print(f"Waiting on job {job.job_id}")
            job.results()
    else:
        # Local run: process the first shard only, in the current process.
        print("Debugging mode")
        process_dataset(args, n_shards=n_shards, node_index=0, task_index=0)
================================================
FILE: scripts/static/style.css
================================================
/* Page-wide defaults. */
body {
background-color: #fbfbfb;
margin: 0;
}
/* Form controls never grow wider than their container. */
select, input {
font-size: 1em;
max-width: 100%;
}
.xp_name {
font-family: monospace;
}
.simple_form {
background-color: #dddddd;
padding: 1em;
margin: 0.5em;
}
textarea {
margin-top: 0.5em;
margin-bottom: 0.5em;
}
/* Rating elements: grey boxes shown with a pointer cursor. */
.rating {
background-color: grey;
padding-top: 5px;
padding-bottom: 5px;
padding-left: 8px;
padding-right: 8px;
margin-right: 2px;
cursor:pointer;
}
/* Purple background for elements carrying the rating_selected class. */
.rating_selected {
background-color: purple;
}
/* Main content column: horizontally centered, at most 1000px wide. */
.content {
font-family: sans-serif;
background-color: #f6f6f6;
padding: 40px;
margin: 0 auto;
max-width: 1000px;
}
.track label {
padding-top: 10px;
padding-bottom: 10px;
}
.track {
padding: 15px;
margin: 5px;
background-color: #c8c8c8;
}
.submit-big {
width:400px;
height:30px;
font-size: 20px;
}
/* Text emphasis and status colors. */
.error {
color: red;
}
.ratings {
margin-left: 10px;
}
.important {
font-weight: bold;
}
.survey {
margin-bottom: 100px;
}
.success {
color: #25901b;
font-weight: bold;
}
.warning {
color: #8a1f19;
font-weight: bold;
}
/* Flex rows with vertically centered children. */
.track>section {
display: flex;
align-items: center;
}
.prompt {
display: flex;
align-items: center;
}
.track>section>div {
padding-left: 10px;
}
/* Upper bounds on the size of audio elements. */
audio {
max-width: 280px;
max-height: 40px;
margin-left: 10px;
margin-right: 10px;
}
.special {
font-weight: bold;
color: #2c2c2c;
}
================================================
FILE: scripts/templates/base.html
================================================
{% block head %}
AudioCraft — MOS
{% endblock %}
Welcome {{session['user']}} to the internal MOS assistant for AudioCraft.
You can create custom surveys between your models, that you can
evaluate yourself, or with the help of your teammates, by simply
sharing a link!
{% for error in errors %}
{{error}}
{% endfor %}
Samples
{% endblock %}
================================================
FILE: setup.cfg
================================================
[pep8]
max-line-length = 120
[flake8]
max-line-length = 120
[coverage:report]
include = audiocraft/*
omit =
audiocraft/environment.py
audiocraft/solvers/*
audiocraft/utils/*
audiocraft/*/loaders.py
audiocraft/*/builders.py
================================================
FILE: setup.py
================================================
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
from pathlib import Path
from setuptools import setup, find_packages
NAME = 'audiocraft'
DESCRIPTION = 'Audio generation research library for PyTorch'

URL = 'https://github.com/facebookresearch/audiocraft'
AUTHOR = 'FAIR Speech & Audio'
EMAIL = 'defossez@meta.com, jadecopet@meta.com'
REQUIRES_PYTHON = '>=3.8.0'

# Directory containing this file; every file read below is resolved against it so the
# script does not depend on the current working directory.
HERE = Path(__file__).parent

# Extract `__version__` by executing only the line(s) of the package `__init__.py`
# that mention it, which avoids importing the package (and its dependencies).
VERSION = None
with open(HERE / 'audiocraft' / '__init__.py', encoding='utf-8') as f:
    for line in f:
        line = line.strip()
        if '__version__' in line:
            context = {}
            exec(line, context)
            VERSION = context['__version__']
if VERSION is None:
    raise RuntimeError("Unable to find `__version__` in audiocraft/__init__.py")

# Use the README as long description when available, else the short description.
try:
    with open(HERE / "README.md", encoding='utf-8') as f:
        long_description = '\n' + f.read()
except FileNotFoundError:
    long_description = DESCRIPTION

# One requirement per line; blank lines and comment lines are ignored.
with open(HERE / 'requirements.txt', encoding='utf-8') as f:
    REQUIRED = [i.strip() for i in f if i.strip() and not i.lstrip().startswith('#')]

setup(
    name=NAME,
    version=VERSION,
    description=DESCRIPTION,
    author_email=EMAIL,
    long_description=long_description,
    long_description_content_type='text/markdown',
    author=AUTHOR,
    url=URL,
    python_requires=REQUIRES_PYTHON,
    install_requires=REQUIRED,
    extras_require={
        'dev': ['coverage', 'flake8', 'mypy', 'pdoc3', 'pytest'],
    },
    packages=find_packages(),
    package_data={'audiocraft': ['py.typed']},
    include_package_data=True,
    license='MIT License',
    classifiers=[
        # Trove classifiers
        # Full list: https://pypi.python.org/pypi?%3Aaction=list_classifiers
        'License :: OSI Approved :: MIT License',
        'Topic :: Multimedia :: Sound/Audio',
        'Topic :: Scientific/Engineering :: Artificial Intelligence',
    ],
)
================================================
FILE: tests/__init__.py
================================================
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
================================================
FILE: tests/adversarial/__init__.py
================================================
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
================================================
FILE: tests/adversarial/test_discriminators.py
================================================
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import random
import torch
from audiocraft.adversarial.discriminators import (
MultiPeriodDiscriminator,
MultiScaleDiscriminator,
MultiScaleSTFTDiscriminator
)
class TestMultiPeriodDiscriminator:
    """Output structure checks for `MultiPeriodDiscriminator`."""

    def test_mpd_discriminator(self):
        """One logit and one list of feature maps is returned per period, batch size preserved."""
        batch, in_channels = 2, 2
        length = random.randrange(1, 100_000)
        x = torch.randn(batch, in_channels, length)
        periods = [1, 2, 3]

        discriminator = MultiPeriodDiscriminator(periods=periods, in_channels=in_channels)
        logits, fmaps = discriminator(x)

        assert len(logits) == len(periods)
        assert len(fmaps) == len(periods)
        for logit in logits:
            assert logit.shape[0] == batch and len(logit.shape) == 4
        for fmap in fmaps:
            for feature in fmap:
                assert feature.shape[0] == batch
class TestMultiScaleDiscriminator:
    """Output structure checks for `MultiScaleDiscriminator`."""

    def test_msd_discriminator(self):
        """One logit and one list of feature maps is returned per scale, batch size preserved."""
        batch, in_channels = 2, 2
        length = random.randrange(1, 100_000)
        x = torch.randn(batch, in_channels, length)
        scale_norms = ['weight_norm', 'weight_norm']

        discriminator = MultiScaleDiscriminator(scale_norms=scale_norms, in_channels=in_channels)
        logits, fmaps = discriminator(x)

        assert len(logits) == len(scale_norms)
        assert len(fmaps) == len(scale_norms)
        for logit in logits:
            assert logit.shape[0] == batch and len(logit.shape) == 3
        for fmap in fmaps:
            for feature in fmap:
                assert feature.shape[0] == batch
class TestMultiScaleStftDiscriminator:
    """Output structure checks for `MultiScaleSTFTDiscriminator`."""

    def test_msstftd_discriminator(self):
        """One logit and one list of feature maps is returned per STFT scale, batch size preserved."""
        batch, in_channels = 2, 2
        length = random.randrange(1, 100_000)
        x = torch.randn(batch, in_channels, length)

        # one entry per STFT scale
        n_ffts = [128, 256, 64]
        hop_lengths = [32, 64, 16]
        win_lengths = [128, 256, 64]
        discriminator = MultiScaleSTFTDiscriminator(
            filters=4, n_ffts=n_ffts, hop_lengths=hop_lengths,
            win_lengths=win_lengths, in_channels=in_channels)
        logits, fmaps = discriminator(x)

        assert len(logits) == len(n_ffts)
        assert len(fmaps) == len(n_ffts)
        for logit in logits:
            assert logit.shape[0] == batch and len(logit.shape) == 4
        for fmap in fmaps:
            for feature in fmap:
                assert feature.shape[0] == batch
================================================
FILE: tests/adversarial/test_losses.py
================================================
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import pytest
import random
import torch
from audiocraft.adversarial import (
AdversarialLoss,
get_adv_criterion,
get_real_criterion,
get_fake_criterion,
FeatureMatchingLoss,
MultiScaleDiscriminator,
)
class TestAdversarialLoss:
    """Checks of `AdversarialLoss` built on top of a `MultiScaleDiscriminator`."""

    def test_adversarial_single_multidiscriminator(self):
        """Without a feature loss, the adversarial loss is a float tensor and the feature loss is 0."""
        adv = MultiScaleDiscriminator()
        optimizer = torch.optim.Adam(
            adv.parameters(),
            lr=1e-4,
        )
        loss, loss_real, loss_fake = get_adv_criterion('mse'), get_real_criterion('mse'), get_fake_criterion('mse')
        adv_loss = AdversarialLoss(adv, optimizer, loss, loss_real, loss_fake)

        B, C, T = 4, 1, random.randint(1000, 5000)
        real = torch.randn(B, C, T)
        fake = torch.randn(B, C, T)

        disc_loss = adv_loss.train_adv(fake, real)
        assert isinstance(disc_loss, torch.Tensor) and isinstance(disc_loss.item(), float)

        loss, loss_feat = adv_loss(fake, real)
        assert isinstance(loss, torch.Tensor) and isinstance(loss.item(), float)
        # we did not specify feature loss
        assert loss_feat.item() == 0.

    def test_adversarial_feat_loss(self):
        """With a feature matching loss, both returned losses are float tensors."""
        adv = MultiScaleDiscriminator()
        optimizer = torch.optim.Adam(
            adv.parameters(),
            lr=1e-4,
        )
        loss, loss_real, loss_fake = get_adv_criterion('mse'), get_real_criterion('mse'), get_fake_criterion('mse')
        feat_loss = FeatureMatchingLoss()
        adv_loss = AdversarialLoss(adv, optimizer, loss, loss_real, loss_fake, feat_loss)

        B, C, T = 4, 1, random.randint(1000, 5000)
        real = torch.randn(B, C, T)
        fake = torch.randn(B, C, T)

        loss, loss_feat = adv_loss(fake, real)
        assert isinstance(loss, torch.Tensor) and isinstance(loss.item(), float)
        # Check the value of the feature loss itself, not the adversarial loss a second time.
        assert isinstance(loss_feat, torch.Tensor) and isinstance(loss_feat.item(), float)
class TestGeneratorAdversarialLoss:
    """Expected values of the generator adversarial criteria on small inputs."""

    def test_hinge_generator_adv_loss(self):
        """Hinge criterion on an empty tensor and on a known vector."""
        criterion = get_adv_criterion(loss_type='hinge')
        empty = torch.randn(1, 2, 0)
        values = torch.FloatTensor([1.0, 2.0, 3.0])

        assert criterion(empty).item() == 0.0
        assert criterion(values).item() == -2.0

    def test_mse_generator_adv_loss(self):
        """MSE criterion on an empty tensor, on all ones and on a known vector."""
        criterion = get_adv_criterion(loss_type='mse')
        empty = torch.randn(1, 2, 0)
        ones = torch.FloatTensor([1.0, 1.0, 1.0])
        values = torch.FloatTensor([2.0, 5.0, 5.0])

        assert criterion(empty).item() == 0.0
        assert criterion(ones).item() == 0.0
        assert criterion(values).item() == 11.0
class TestDiscriminatorAdversarialLoss:
    """Expected values of the discriminator criteria (fake term plus real term)."""

    def _disc_loss(self, loss_type: str, fake: torch.Tensor, real: torch.Tensor):
        """Return the fake criterion applied to `fake` plus the real criterion applied to `real`."""
        real_criterion = get_real_criterion(loss_type)
        fake_criterion = get_fake_criterion(loss_type)
        return fake_criterion(fake) + real_criterion(real)

    def test_hinge_discriminator_adv_loss(self):
        zeros = torch.FloatTensor([0.0, 0.0, 0.0])
        values = torch.FloatTensor([1.0, 2.0, 3.0])

        assert self._disc_loss('hinge', zeros, zeros).item() == 2.0
        assert self._disc_loss('hinge', values, values).item() == 3.0

    def test_mse_discriminator_adv_loss(self):
        zeros = torch.FloatTensor([0.0, 0.0, 0.0])
        ones = torch.FloatTensor([1.0, 1.0, 1.0])

        assert self._disc_loss('mse', zeros, zeros).item() == 1.0
        assert self._disc_loss('mse', ones, zeros).item() == 2.0
class TestFeatureMatchingLoss:
    """Checks of `FeatureMatchingLoss` values and input validation."""

    def test_features_matching_loss_base(self):
        """Identical feature lists give a zero loss."""
        criterion = FeatureMatchingLoss()
        n_steps = random.randrange(1, 100_000)
        feat = torch.randn(1, 2, n_steps)

        value = criterion([feat], [feat])
        assert isinstance(value, torch.Tensor)
        assert value.item() == 0.0

    def test_features_matching_loss_raises_exception(self):
        """Empty lists, lists of different lengths and mismatched shapes are rejected."""
        criterion = FeatureMatchingLoss()
        n_steps = random.randrange(1, 100_000)
        feat = torch.randn(1, 2, n_steps)
        longer_feat = torch.randn(1, 2, n_steps + 1)

        with pytest.raises(AssertionError):
            criterion([], [])

        with pytest.raises(AssertionError):
            criterion([feat], [feat, feat])

        with pytest.raises(AssertionError):
            criterion([feat], [longer_feat])

    def test_features_matching_loss_output(self):
        """Expected values with and without normalization."""
        plain = FeatureMatchingLoss(normalize=False)
        normalized = FeatureMatchingLoss(normalize=True)

        n_steps = random.randrange(1, 100_000)
        feat_a = torch.randn(1, 2, n_steps)
        feat_b = torch.randn(1, 2, n_steps)
        assert plain([feat_a, feat_b], [feat_a, feat_b]).item() == 0.0
        assert normalized([feat_a, feat_b], [feat_a, feat_b]).item() == 0.0

        source = torch.FloatTensor([1.0, 2.0, 3.0])
        target = torch.FloatTensor([2.0, 10.0, 3.0])
        assert plain([source], [target]).item() == 3.0
        assert plain([source, source], [target, target]).item() == 6.0
        assert normalized([source], [target]).item() == 3.0
        assert normalized([source, source], [target, target]).item() == 3.0
================================================
FILE: tests/common_utils/__init__.py
================================================
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
# flake8: noqa
from .temp_utils import TempDirMixin
from .wav_utils import get_batch_white_noise, get_white_noise, save_wav
================================================
FILE: tests/common_utils/temp_utils.py
================================================
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import os
import tempfile
class TempDirMixin:
    """Mixin giving test classes a scratch directory, with one sub-folder per class name.
    """

    # `tempfile.TemporaryDirectory` created on first use, cached on the class.
    temp_dir_ = None

    @classmethod
    def get_base_temp_dir(cls):
        """Return the root directory under which temporary files are created.

        The `AUDIOCRAFT_TEST_DIR` environment variable takes precedence when set
        (handy for debugging); otherwise a temporary directory is created once.
        """
        override = os.environ.get("AUDIOCRAFT_TEST_DIR")
        if override is not None:
            return override
        if cls.temp_dir_ is None:
            cls.temp_dir_ = tempfile.TemporaryDirectory()
        return cls.temp_dir_.name

    @classmethod
    def tearDownClass(cls):
        """Remove the cached temporary directory, then defer to the parent class."""
        scratch = cls.temp_dir_
        if scratch is not None:
            try:
                scratch.cleanup()
            except PermissionError:
                # On Windows there is a known issue with `shutil.rmtree`,
                # which fails intermittently, so the error is ignored.
                # https://github.com/python/cpython/issues/74168
                pass
            else:
                cls.temp_dir_ = None
        super().tearDownClass()

    @property
    def id(self):
        """Name of the concrete class, used as sub-folder name."""
        return type(self).__name__

    def get_temp_path(self, *paths):
        """Return a path inside the class folder, creating its parent directory."""
        path = os.path.join(self.get_base_temp_dir(), self.id, *paths)
        parent = os.path.dirname(path)
        os.makedirs(parent, exist_ok=True)
        return path

    def get_temp_dir(self, *paths):
        """Return a directory inside the class folder, creating it if needed."""
        path = os.path.join(self.get_base_temp_dir(), self.id, *paths)
        os.makedirs(path, exist_ok=True)
        return path
================================================
FILE: tests/common_utils/wav_utils.py
================================================
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
from pathlib import Path
import typing as tp
import torch
import torchaudio
def get_white_noise(chs: int = 1, num_frames: int = 1):
    """Return a `[chs, num_frames]` tensor drawn from `torch.randn`."""
    return torch.randn(chs, num_frames)
def get_batch_white_noise(bs: int = 1, chs: int = 1, num_frames: int = 1):
    """Return a `[bs, chs, num_frames]` tensor drawn from `torch.randn`."""
    return torch.randn(bs, chs, num_frames)
def save_wav(path: str, wav: torch.Tensor, sample_rate: int):
    """Save `wav` to `path` with `torchaudio.save`, using options picked from the file suffix.

    Args:
        path (str): Destination file; `.wav` and `.mp3` suffixes get extra save options.
        wav (torch.Tensor): Audio to write, passed as is to `torchaudio.save`.
        sample_rate (int): Sample rate passed to `torchaudio.save`.
    """
    target = Path(path)
    options_by_suffix: tp.Dict[str, tp.Dict[str, tp.Any]] = {
        '.wav': {'encoding': 'PCM_S', 'bits_per_sample': 16},
        '.mp3': {'compression': 320},
    }
    # any other suffix is saved without extra options
    options = options_by_suffix.get(target.suffix, {})
    torchaudio.save(str(target), wav, sample_rate, **options)
================================================
FILE: tests/data/__init__.py
================================================
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
================================================
FILE: tests/data/test_audio.py
================================================
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
from itertools import product
import random
import numpy as np
import torch
import torchaudio
from audiocraft.data.audio import audio_info, audio_read, audio_write, _av_read
from ..common_utils import TempDirMixin, get_white_noise, save_wav
class TestInfo(TempDirMixin):
    """Checks of `audio_info` on files written in several formats."""

    def test_info_mp3(self):
        """Sample rate and channel count are reported correctly for mp3 files."""
        duration = 1.
        for sample_rate, ch in product([8000, 16_000], [1, 2]):
            noise = get_white_noise(ch, int(sample_rate * duration))
            mp3_path = self.get_temp_path('sample_wav.mp3')
            save_wav(mp3_path, noise, sample_rate)
            info = audio_info(mp3_path)
            assert info.sample_rate == sample_rate
            assert info.channels == ch
            # num_frames reported by torchaudio cannot be trusted here, so it is not checked

    def _test_info_format(self, ext: str):
        """Write white noise with extension `ext` and check sample rate, channels and duration."""
        duration = 1.
        for sample_rate, ch in product([8000, 16_000], [1, 2]):
            n_frames = int(sample_rate * duration)
            noise = get_white_noise(ch, n_frames)
            file_path = self.get_temp_path(f'sample_wav{ext}')
            save_wav(file_path, noise, sample_rate)
            info = audio_info(file_path)
            assert info.sample_rate == sample_rate
            assert info.channels == ch
            assert np.isclose(info.duration, duration, atol=1e-5)

    def test_info_wav(self):
        self._test_info_format('.wav')

    def test_info_flac(self):
        self._test_info_format('.flac')

    def test_info_ogg(self):
        self._test_info_format('.ogg')

    def test_info_m4a(self):
        # TODO: generate m4a file programmatically
        # self._test_info_format('.m4a')
        pass
class TestRead(TempDirMixin):
    """Checks of `audio_read` on wav files: full read, partial read, seek and padding."""

    def test_read_full_wav(self):
        """Reading a whole file gives back the saved waveform."""
        duration = 1.
        for sample_rate, ch in product([8000, 16_000], [1, 2]):
            n_frames = int(sample_rate * duration)
            reference = get_white_noise(ch, n_frames).clamp(-0.99, 0.99)
            wav_path = self.get_temp_path('sample_wav.wav')
            save_wav(wav_path, reference, sample_rate)
            loaded, loaded_sr = audio_read(wav_path)
            assert loaded_sr == sample_rate
            assert loaded.shape[0] == reference.shape[0]
            assert loaded.shape[1] == reference.shape[1]
            assert torch.allclose(loaded, reference, rtol=1e-03, atol=1e-04)

    def test_read_partial_wav(self):
        """Reading a random duration from the start gives the matching prefix."""
        duration = 1.
        read_duration = torch.rand(1).item()
        for sample_rate, ch in product([8000, 16_000], [1, 2]):
            n_frames = int(sample_rate * duration)
            read_frames = int(sample_rate * read_duration)
            reference = get_white_noise(ch, n_frames).clamp(-0.99, 0.99)
            wav_path = self.get_temp_path('sample_wav.wav')
            save_wav(wav_path, reference, sample_rate)
            loaded, loaded_sr = audio_read(wav_path, 0, read_duration)
            assert loaded_sr == sample_rate
            assert loaded.shape[0] == reference.shape[0]
            assert loaded.shape[1] == read_frames
            assert torch.allclose(
                loaded[..., 0:read_frames], reference[..., 0:read_frames], rtol=1e-03, atol=1e-04)

    def test_read_seek_time_wav(self):
        """Reading from a random seek time gives the remainder of the file."""
        duration = 1.
        read_duration = 1.
        for sample_rate, ch in product([8000, 16_000], [1, 2]):
            n_frames = int(sample_rate * duration)
            reference = get_white_noise(ch, n_frames).clamp(-0.99, 0.99)
            wav_path = self.get_temp_path('sample_wav.wav')
            save_wav(wav_path, reference, sample_rate)
            seek_time = torch.rand(1).item()
            loaded, loaded_sr = audio_read(wav_path, seek_time, read_duration)
            seek_frames = int(sample_rate * seek_time)
            expected_frames = n_frames - seek_frames
            assert loaded_sr == sample_rate
            assert loaded.shape[0] == reference.shape[0]
            assert loaded.shape[1] == expected_frames
            assert torch.allclose(loaded, reference[..., seek_frames:], rtol=1e-03, atol=1e-04)

    def test_read_seek_time_wav_padded(self):
        """With `pad=True`, the read is completed with zeros up to the requested duration."""
        duration = 1.
        read_duration = 1.
        for sample_rate, ch in product([8000, 16_000], [1, 2]):
            n_frames = int(sample_rate * duration)
            read_frames = int(sample_rate * read_duration)
            reference = get_white_noise(ch, n_frames).clamp(-0.99, 0.99)
            wav_path = self.get_temp_path('sample_wav.wav')
            save_wav(wav_path, reference, sample_rate)
            seek_time = torch.rand(1).item()
            seek_frames = int(sample_rate * seek_time)
            expected_frames = n_frames - seek_frames
            loaded, loaded_sr = audio_read(wav_path, seek_time, read_duration, pad=True)
            expected_padding = torch.zeros(reference.shape[0], read_frames - expected_frames)
            assert loaded_sr == sample_rate
            assert loaded.shape[0] == reference.shape[0]
            assert loaded.shape[1] == read_frames
            assert torch.allclose(
                loaded[..., :expected_frames], reference[..., seek_frames:], rtol=1e-03, atol=1e-04)
            assert torch.allclose(loaded[..., expected_frames:], expected_padding)
class TestAvRead(TempDirMixin):
    """Checks of the number of frames returned by `_av_read` for various seek positions."""

    def test_avread_seek_base(self):
        """Segments fully contained in the file have exactly the requested length."""
        duration = 2.
        for sample_rate, ch in product([8000, 16_000], [1, 2]):
            n_frames = int(sample_rate * duration)
            noise = get_white_noise(ch, n_frames)
            wav_path = self.get_temp_path(f'reference_a_{sample_rate}_{ch}.wav')
            save_wav(wav_path, noise, sample_rate)
            for _ in range(100):
                # the segment always fits inside the 2 seconds file
                seek_time = random.uniform(0.0, 1.0)
                seek_duration = random.uniform(0.001, 1.0)
                loaded, loaded_sr = _av_read(wav_path, seek_time, seek_duration)
                assert loaded_sr == sample_rate
                assert loaded.shape[0] == noise.shape[0]
                assert loaded.shape[-1] == int(seek_duration * sample_rate)

    def test_avread_seek_partial(self):
        """Segments running past the end of the file are truncated."""
        duration = 1.
        for sample_rate, ch in product([8000, 16_000], [1, 2]):
            n_frames = int(sample_rate * duration)
            noise = get_white_noise(ch, n_frames)
            wav_path = self.get_temp_path(f'reference_b_{sample_rate}_{ch}.wav')
            save_wav(wav_path, noise, sample_rate)
            for _ in range(100):
                # the segment always runs past the end of the file
                seek_time = random.uniform(0.5, 1.)
                seek_duration = 1.
                expected_num_frames = n_frames - int(seek_time * sample_rate)
                loaded, loaded_sr = _av_read(wav_path, seek_time, seek_duration)
                assert loaded_sr == sample_rate
                assert loaded.shape[0] == noise.shape[0]
                assert loaded.shape[-1] == expected_num_frames

    def test_avread_seek_outofbound(self):
        """Seeking after the end of the file returns no frame."""
        duration = 1.
        for sample_rate, ch in product([8000, 16_000], [1, 2]):
            n_frames = int(sample_rate * duration)
            noise = get_white_noise(ch, n_frames)
            wav_path = self.get_temp_path(f'reference_c_{sample_rate}_{ch}.wav')
            save_wav(wav_path, noise, sample_rate)
            seek_time = 1.5
            loaded, loaded_sr = _av_read(wav_path, seek_time, 1.)
            assert loaded_sr == sample_rate
            assert loaded.shape[0] == noise.shape[0]
            assert loaded.shape[-1] == 0

    def test_avread_seek_edge(self):
        """Seeking to the last frame of the file."""
        # some of these frame counts will have
        # int(((frames - 1) / sample_rate) * sample_rate) != (frames - 1)
        for sample_rate, ch, frames in product([8000, 16_000], [1, 2], [1000, 1001, 1002]):
            duration = frames / sample_rate
            noise = get_white_noise(ch, frames)
            wav_path = self.get_temp_path(f'reference_d_{sample_rate}_{ch}.wav')
            save_wav(wav_path, noise, sample_rate)
            seek_time = (frames - 1) / sample_rate
            seek_frames = int(seek_time * sample_rate)
            loaded, loaded_sr = _av_read(wav_path, seek_time, duration)
            assert loaded_sr == sample_rate
            assert loaded.shape[0] == noise.shape[0]
            assert loaded.shape[-1] == (frames - seek_frames)
class TestAudioWrite(TempDirMixin):
    """Round-trip checks of `audio_write` followed by `torchaudio.load`."""

    def test_audio_write_wav(self):
        """Written wav files keep their shape and, once rescaled, their content."""
        torch.manual_seed(1234)
        strategies = ["peak", "clip", "rms"]
        formats = ["wav", "mp3"]
        for sample_rate, ch, frames in product([8000, 16_000], [1, 2], [1000, 1001, 1002]):
            for format_, strategy in product(formats, strategies):
                noise = get_white_noise(ch, frames)
                stem = self.get_temp_path(f'pred_{sample_rate}_{ch}')
                audio_write(stem, noise, sample_rate, format_, strategy=strategy)
                loaded, loaded_sr = torchaudio.load(f'{stem}.{format_}')
                is_wav = format_ == "wav"
                if is_wav:
                    assert loaded.shape == noise.shape

                if is_wav and strategy in ["peak", "rms"]:
                    rescaled = loaded / loaded.abs().max() * noise.abs().max()
                    # For a Gaussian, the typical max scale will be less than ~5x the std.
                    # The error when writing to disk will be ~ 1/2**15, and 5x that after rescaling.
                    # For the RMS target, rescaling leaves more headroom by default, leading
                    # to a 20x rescaling typically.
                    atol = (5 if strategy == "peak" else 20) / 2**15
                    delta = (rescaled - noise).abs().max()
                    assert torch.allclose(noise, rescaled, rtol=0, atol=atol), (delta, atol)
            formats = ["wav"]  # faster unit tests
================================================
FILE: tests/data/test_audio_dataset.py
================================================
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
from functools import partial
from itertools import product
import json
import math
import os
import random
import typing as tp
import pytest
import torch
from torch.utils.data import DataLoader
from audiocraft.data.audio_dataset import (
AudioDataset,
AudioMeta,
_get_audio_meta,
load_audio_meta,
save_audio_meta
)
from audiocraft.data.zip import PathInZip
from ..common_utils import TempDirMixin, get_white_noise, save_wav
class TestAudioMeta(TempDirMixin):
    """Checks of audio metadata extraction and of its jsonl save / load helpers."""

    def test_get_audio_meta(self):
        """Minimal metadata holds path, sample rate and duration, and nothing else."""
        duration = 1.
        for sample_rate, ch in product([8000, 16_000], [1, 2]):
            n_frames = int(duration * sample_rate)
            noise = get_white_noise(ch, n_frames)
            path = self.get_temp_path('sample.wav')
            save_wav(path, noise, sample_rate)
            meta = _get_audio_meta(path, minimal=True)
            assert meta.path == path, 'path does not match'
            assert meta.sample_rate == sample_rate, 'sample rate does not match'
            assert meta.duration == duration, 'duration does not match'
            assert meta.amplitude is None
            assert meta.info_path is None

    def test_save_audio_meta(self):
        """Entries written by `save_audio_meta` can be parsed back line by line."""
        audio_meta = [
            AudioMeta("mypath1", 1., 16_000, None, None, PathInZip('/foo/bar.zip:/relative/file1.json')),
            AudioMeta("mypath2", 2., 16_000, None, None, PathInZip('/foo/bar.zip:/relative/file2.json'))
        ]
        empty_audio_meta = []
        for idx, expected in enumerate([audio_meta, empty_audio_meta]):
            path = self.get_temp_path(f'data_{idx}_save.jsonl')
            save_audio_meta(path, expected)
            with open(path, 'r') as f:
                read_meta = [AudioMeta.from_dict(json.loads(line)) for line in f.readlines()]
            assert len(read_meta) == len(expected)
            for m, read_m in zip(expected, read_meta):
                assert m == read_m

    def test_load_audio_meta(self):
        """Entries written as json lines are returned by `load_audio_meta`."""
        try:
            import dora
        except ImportError:
            dora = None  # type: ignore

        audio_meta = [
            AudioMeta("mypath1", 1., 16_000, None, None, PathInZip('/foo/bar.zip:/relative/file1.json')),
            AudioMeta("mypath2", 2., 16_000, None, None, PathInZip('/foo/bar.zip:/relative/file2.json'))
        ]
        empty_meta = []
        for idx, expected in enumerate([audio_meta, empty_meta]):
            path = self.get_temp_path(f'data_{idx}_load.jsonl')
            with open(path, 'w') as f:
                for m in expected:
                    f.write(json.dumps(m.to_dict()) + '\n')
            read_meta = load_audio_meta(path)
            assert len(read_meta) == len(expected)
            for m, read_m in zip(expected, read_meta):
                if dora:
                    # when dora is installed, the expected path is made absolute before comparing
                    m.path = dora.git_save.to_absolute_path(m.path)
                assert m == read_m, f'original={m}, read={read_m}'
class TestAudioDataset(TempDirMixin):
def _create_audio_files(self,
                        root_name: str,
                        num_examples: int,
                        durations: tp.Union[float, tp.Tuple[float, float]] = (0.1, 1.),
                        sample_rate: int = 16_000,
                        channels: int = 1):
    """Write `num_examples` white noise wav files in a temp folder and return that folder.

    Args:
        root_name (str): Name of the folder, created with `get_temp_dir`.
        num_examples (int): Number of `example_{i}.wav` files to write.
        durations (float or tuple): Fixed duration, 1-tuple holding a fixed duration,
            or `(min, max)` range a duration is drawn from uniformly for each file.
        sample_rate (int): Sample rate of the files.
        channels (int): Number of channels of the files.
    """
    root_dir = self.get_temp_dir(root_name)
    for index in range(num_examples):
        if isinstance(durations, float):
            duration = durations
        elif isinstance(durations, tuple):
            if len(durations) == 1:
                duration = durations[0]
            elif len(durations) == 2:
                low, high = durations[0], durations[1]
                duration = random.uniform(low, high)
            else:
                assert False
        else:
            assert False
        noise = get_white_noise(channels, int(duration * sample_rate))
        save_wav(os.path.join(root_dir, f'example_{index}.wav'), noise, sample_rate)
    return root_dir
def _create_audio_dataset(self,
                          root_name: str,
                          total_num_examples: int,
                          durations: tp.Union[float, tp.Tuple[float, float]] = (0.1, 1.),
                          sample_rate: int = 16_000,
                          channels: int = 1,
                          segment_duration: tp.Optional[float] = None,
                          num_examples: int = 10,
                          shuffle: bool = True,
                          return_info: bool = False):
    """Write audio files with `_create_audio_files` and build an `AudioDataset` from their folder.

    `total_num_examples` is the number of files written, while `num_examples` is
    forwarded to `AudioDataset.from_path` as `num_samples`.
    """
    root_dir = self._create_audio_files(root_name, total_num_examples, durations, sample_rate, channels)
    return AudioDataset.from_path(
        root_dir,
        minimal_meta=True,
        segment_duration=segment_duration,
        num_samples=num_examples,
        sample_rate=sample_rate,
        channels=channels,
        shuffle=shuffle,
        return_info=return_info,
    )
def test_dataset_full(self):
total_examples = 10
min_duration, max_duration = 1., 4.
sample_rate = 16_000
channels = 1
dataset = self._create_audio_dataset(
'dset', total_examples, durations=(min_duration, max_duration),
sample_rate=sample_rate, channels=channels, segment_duration=None)
assert len(dataset) == total_examples
assert dataset.sample_rate == sample_rate
assert dataset.channels == channels
for idx in range(len(dataset)):
sample = dataset[idx]
assert sample.shape[0] == channels
assert sample.shape[1] <= int(max_duration * sample_rate)
assert sample.shape[1] >= int(min_duration * sample_rate)
def test_dataset_segment(self):
total_examples = 10
num_samples = 20
min_duration, max_duration = 1., 4.
segment_duration = 1.
sample_rate = 16_000
channels = 1
dataset = self._create_audio_dataset(
'dset', total_examples, durations=(min_duration, max_duration), sample_rate=sample_rate,
channels=channels, segment_duration=segment_duration, num_examples=num_samples)
assert len(dataset) == num_samples
assert dataset.sample_rate == sample_rate
assert dataset.channels == channels
for idx in range(len(dataset)):
sample = dataset[idx]
assert sample.shape[0] == channels
assert sample.shape[1] == int(segment_duration * sample_rate)
def test_dataset_equal_audio_and_segment_durations(self):
total_examples = 1
num_samples = 2
audio_duration = 1.
segment_duration = 1.
sample_rate = 16_000
channels = 1
dataset = self._create_audio_dataset(
'dset', total_examples, durations=audio_duration, sample_rate=sample_rate,
channels=channels, segment_duration=segment_duration, num_examples=num_samples)
assert len(dataset) == num_samples
assert dataset.sample_rate == sample_rate
assert dataset.channels == channels
for idx in range(len(dataset)):
sample = dataset[idx]
assert sample.shape[0] == channels
assert sample.shape[1] == int(segment_duration * sample_rate)
# the random seek_time adds variability on audio read
sample_1 = dataset[0]
sample_2 = dataset[1]
assert not torch.allclose(sample_1, sample_2)
def test_dataset_samples(self):
total_examples = 1
num_samples = 2
audio_duration = 1.
segment_duration = 1.
sample_rate = 16_000
channels = 1
create_dataset = partial(
self._create_audio_dataset,
'dset', total_examples, durations=audio_duration, sample_rate=sample_rate,
channels=channels, segment_duration=segment_duration, num_examples=num_samples,
)
dataset = create_dataset(shuffle=True)
# when shuffle = True, we have different inputs for the same index across epoch
sample_1 = dataset[0]
sample_2 = dataset[0]
assert not torch.allclose(sample_1, sample_2)
dataset_noshuffle = create_dataset(shuffle=False)
# when shuffle = False, we have same inputs for the same index across epoch
sample_1 = dataset_noshuffle[0]
sample_2 = dataset_noshuffle[0]
assert torch.allclose(sample_1, sample_2)
def test_dataset_return_info(self):
total_examples = 10
num_samples = 20
min_duration, max_duration = 1., 4.
segment_duration = 1.
sample_rate = 16_000
channels = 1
dataset = self._create_audio_dataset(
'dset', total_examples, durations=(min_duration, max_duration), sample_rate=sample_rate,
channels=channels, segment_duration=segment_duration, num_examples=num_samples, return_info=True)
assert len(dataset) == num_samples
assert dataset.sample_rate == sample_rate
assert dataset.channels == channels
for idx in range(len(dataset)):
sample, segment_info = dataset[idx]
assert sample.shape[0] == channels
assert sample.shape[1] == int(segment_duration * sample_rate)
assert segment_info.sample_rate == sample_rate
assert segment_info.total_frames == int(segment_duration * sample_rate)
assert segment_info.n_frames <= int(segment_duration * sample_rate)
assert segment_info.seek_time >= 0
def test_dataset_return_info_no_segment_duration(self):
total_examples = 10
num_samples = 20
min_duration, max_duration = 1., 4.
segment_duration = None
sample_rate = 16_000
channels = 1
dataset = self._create_audio_dataset(
'dset', total_examples, durations=(min_duration, max_duration), sample_rate=sample_rate,
channels=channels, segment_duration=segment_duration, num_examples=num_samples, return_info=True)
assert len(dataset) == total_examples
assert dataset.sample_rate == sample_rate
assert dataset.channels == channels
for idx in range(len(dataset)):
sample, segment_info = dataset[idx]
assert sample.shape[0] == channels
assert sample.shape[1] == segment_info.total_frames
assert segment_info.sample_rate == sample_rate
assert segment_info.n_frames <= segment_info.total_frames
def test_dataset_collate_fn(self):
total_examples = 10
num_samples = 20
min_duration, max_duration = 1., 4.
segment_duration = 1.
sample_rate = 16_000
channels = 1
dataset = self._create_audio_dataset(
'dset', total_examples, durations=(min_duration, max_duration), sample_rate=sample_rate,
channels=channels, segment_duration=segment_duration, num_examples=num_samples, return_info=False)
batch_size = 4
dataloader = DataLoader(
dataset,
batch_size=batch_size,
num_workers=0
)
for idx, batch in enumerate(dataloader):
assert batch.shape[0] == batch_size
@pytest.mark.parametrize("segment_duration", [1.0, None])
def test_dataset_with_meta_collate_fn(self, segment_duration):
total_examples = 10
num_samples = 20
min_duration, max_duration = 1., 4.
segment_duration = 1.
sample_rate = 16_000
channels = 1
dataset = self._create_audio_dataset(
'dset', total_examples, durations=(min_duration, max_duration), sample_rate=sample_rate,
channels=channels, segment_duration=segment_duration, num_examples=num_samples, return_info=True)
batch_size = 4
dataloader = DataLoader(
dataset,
batch_size=batch_size,
collate_fn=dataset.collater,
num_workers=0
)
for idx, batch in enumerate(dataloader):
wav, infos = batch
assert wav.shape[0] == batch_size
assert len(infos) == batch_size
@pytest.mark.parametrize("segment_duration,sample_on_weight,sample_on_duration,a_hist,b_hist,c_hist", [
[1, True, True, 0.5, 0.5, 0.0],
[1, False, True, 0.25, 0.5, 0.25],
[1, True, False, 0.666, 0.333, 0.0],
[1, False, False, 0.333, 0.333, 0.333],
[None, False, False, 0.333, 0.333, 0.333]])
def test_sample_with_weight(self, segment_duration, sample_on_weight, sample_on_duration, a_hist, b_hist, c_hist):
random.seed(1234)
rng = torch.Generator()
rng.manual_seed(1234)
def _get_histogram(dataset, repetitions=20_000):
counts = {file_meta.path: 0. for file_meta in meta}
for _ in range(repetitions):
file_meta = dataset.sample_file(0, rng)
counts[file_meta.path] += 1
return {name: count / repetitions for name, count in counts.items()}
meta = [
AudioMeta(path='a', duration=5, sample_rate=1, weight=2),
AudioMeta(path='b', duration=10, sample_rate=1, weight=None),
AudioMeta(path='c', duration=5, sample_rate=1, weight=0),
]
dataset = AudioDataset(
meta, segment_duration=segment_duration, sample_on_weight=sample_on_weight,
sample_on_duration=sample_on_duration)
hist = _get_histogram(dataset)
assert math.isclose(hist['a'], a_hist, abs_tol=0.01)
assert math.isclose(hist['b'], b_hist, abs_tol=0.01)
assert math.isclose(hist['c'], c_hist, abs_tol=0.01)
def test_meta_duration_filter_all(self):
meta = [
AudioMeta(path='a', duration=5, sample_rate=1, weight=2),
AudioMeta(path='b', duration=10, sample_rate=1, weight=None),
AudioMeta(path='c', duration=5, sample_rate=1, weight=0),
]
try:
AudioDataset(meta, segment_duration=11, min_segment_ratio=1)
assert False
except AssertionError:
assert True
def test_meta_duration_filter_long(self):
meta = [
AudioMeta(path='a', duration=5, sample_rate=1, weight=2),
AudioMeta(path='b', duration=10, sample_rate=1, weight=None),
AudioMeta(path='c', duration=5, sample_rate=1, weight=0),
]
dataset = AudioDataset(meta, segment_duration=None, min_segment_ratio=1, max_audio_duration=7)
assert len(dataset) == 2
================================================
FILE: tests/data/test_audio_utils.py
================================================
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import julius
import torch
import pytest
from audiocraft.data.audio_utils import (
_clip_wav,
convert_audio_channels,
convert_audio,
normalize_audio
)
from ..common_utils import get_batch_white_noise
class TestConvertAudioChannels:
    """Shape checks for ``convert_audio_channels`` on batched white noise."""

    def test_convert_audio_channels_downmix(self):
        """Requesting 2 channels from a 3-channel input keeps batch and time sizes."""
        batch, in_channels, n_frames = 2, 3, 100
        noise = get_batch_white_noise(batch, in_channels, n_frames)
        result = convert_audio_channels(noise, channels=2)
        assert list(result.shape) == [batch, 2, n_frames]

    def test_convert_audio_channels_nochange(self):
        """Requesting the input's own channel count leaves the shape untouched."""
        batch, in_channels, n_frames = 2, 3, 100
        noise = get_batch_white_noise(batch, in_channels, n_frames)
        result = convert_audio_channels(noise, channels=in_channels)
        assert list(result.shape) == list(noise.shape)

    def test_convert_audio_channels_upmix(self):
        """A mono input can be expanded to 3 channels."""
        batch, in_channels, n_frames = 2, 1, 100
        noise = get_batch_white_noise(batch, in_channels, n_frames)
        result = convert_audio_channels(noise, channels=3)
        assert list(result.shape) == [batch, 3, n_frames]

    def test_convert_audio_channels_upmix_error(self):
        """Expanding a 2-channel input to 3 channels raises ``ValueError``."""
        batch, in_channels, n_frames = 2, 2, 100
        noise = get_batch_white_noise(batch, in_channels, n_frames)
        with pytest.raises(ValueError):
            convert_audio_channels(noise, channels=3)
class TestConvertAudio:
    """Checks ``convert_audio`` for channel conversion and resampling."""

    def test_convert_audio_channels_downmix(self):
        """Same-rate conversion from 3 to 2 channels keeps batch and length."""
        batch, in_channels, seconds = 2, 3, 4.
        rate = 128
        noise = get_batch_white_noise(batch, in_channels, int(rate * seconds))
        converted = convert_audio(noise, from_rate=rate, to_rate=rate, to_channels=2)
        expected_shape = [noise.shape[0], 2, noise.shape[-1]]
        assert list(converted.shape) == expected_shape

    def test_convert_audio_channels_upmix(self):
        """Same-rate conversion from 1 to 3 channels keeps batch and length."""
        batch, in_channels, seconds = 2, 1, 4.
        rate = 128
        noise = get_batch_white_noise(batch, in_channels, int(rate * seconds))
        converted = convert_audio(noise, from_rate=rate, to_rate=rate, to_channels=3)
        expected_shape = [noise.shape[0], 3, noise.shape[-1]]
        assert list(converted.shape) == expected_shape

    def test_convert_audio_upsample(self):
        """Going from rate 2 to rate 3 matches ``julius.resample.resample_frac``."""
        batch, in_channels, seconds = 2, 1, 4.
        rate = 2
        target_rate = 3
        noise = get_batch_white_noise(batch, in_channels, int(rate * seconds))
        converted = convert_audio(noise, from_rate=rate, to_rate=target_rate, to_channels=in_channels)
        reference = julius.resample.resample_frac(noise, old_sr=rate, new_sr=target_rate)
        assert torch.allclose(converted, reference)

    def test_convert_audio_resample(self):
        """Going from rate 3 to rate 2 matches ``julius.resample.resample_frac``."""
        batch, in_channels, seconds = 2, 1, 4.
        rate = 3
        target_rate = 2
        noise = get_batch_white_noise(batch, in_channels, int(rate * seconds))
        converted = convert_audio(noise, from_rate=rate, to_rate=target_rate, to_channels=in_channels)
        reference = julius.resample.resample_frac(noise, old_sr=rate, new_sr=target_rate)
        assert torch.allclose(converted, reference)
class TestNormalizeAudio:
    """Checks that clipping and each normalization strategy bound the peak by 1."""

    def test_clip_wav(self):
        """``_clip_wav`` modifies the tensor in place so its peak is at most 1."""
        batch, n_channels, seconds = 2, 1, 4.
        rate = 3
        loud = 10.0 * get_batch_white_noise(batch, n_channels, int(rate * seconds))
        _clip_wav(loud)
        peak = loud.abs().max()
        assert peak <= 1

    def test_normalize_audio_clip(self):
        """The 'clip' strategy returns audio whose peak is at most 1."""
        batch, n_channels, seconds = 2, 1, 4.
        rate = 3
        loud = 10.0 * get_batch_white_noise(batch, n_channels, int(rate * seconds))
        normalized = normalize_audio(loud, strategy='clip')
        peak = normalized.abs().max()
        assert peak <= 1

    def test_normalize_audio_rms(self):
        """The 'rms' strategy returns audio whose peak is at most 1."""
        batch, n_channels, seconds = 2, 1, 4.
        rate = 3
        loud = 10.0 * get_batch_white_noise(batch, n_channels, int(rate * seconds))
        normalized = normalize_audio(loud, strategy='rms')
        peak = normalized.abs().max()
        assert peak <= 1

    def test_normalize_audio_peak(self):
        """The 'peak' strategy returns audio whose peak is at most 1."""
        batch, n_channels, seconds = 2, 1, 4.
        rate = 3
        loud = 10.0 * get_batch_white_noise(batch, n_channels, int(rate * seconds))
        normalized = normalize_audio(loud, strategy='peak')
        peak = normalized.abs().max()
        assert peak <= 1
================================================
FILE: tests/losses/__init__.py
================================================
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
================================================
FILE: tests/losses/test_losses.py
================================================
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import random
import torch
from audiocraft.losses import (
MelSpectrogramL1Loss,
MultiScaleMelSpectrogramLoss,
MRSTFTLoss,
SISNR,
STFTLoss,
)
def test_mel_l1_loss():
    """``MelSpectrogramL1Loss`` returns a tensor and is exactly zero on identical inputs."""
    batch, n_channels, n_frames = 2, 2, random.randrange(1000, 100_000)
    signal_a = torch.randn(batch, n_channels, n_frames)
    signal_b = torch.randn(batch, n_channels, n_frames)

    criterion = MelSpectrogramL1Loss(sample_rate=22_050)
    loss_diff = criterion(signal_a, signal_b)
    loss_identical = criterion(signal_a, signal_a)

    for value in (loss_diff, loss_identical):
        assert isinstance(value, torch.Tensor)
    assert loss_identical.item() == 0.0
def test_msspec_loss():
    """``MultiScaleMelSpectrogramLoss`` returns a tensor and is zero on identical inputs."""
    batch, n_channels, n_frames = 2, 2, random.randrange(1000, 100_000)
    signal_a = torch.randn(batch, n_channels, n_frames)
    signal_b = torch.randn(batch, n_channels, n_frames)

    criterion = MultiScaleMelSpectrogramLoss(sample_rate=22_050)
    loss_diff = criterion(signal_a, signal_b)
    loss_identical = criterion(signal_a, signal_a)

    for value in (loss_diff, loss_identical):
        assert isinstance(value, torch.Tensor)
    assert loss_identical.item() == 0.0
def test_mrstft_loss():
    """``MRSTFTLoss`` evaluates to a tensor on two random signals."""
    batch, n_channels, n_frames = 2, 2, random.randrange(1000, 100_000)
    signal_a = torch.randn(batch, n_channels, n_frames)
    signal_b = torch.randn(batch, n_channels, n_frames)

    criterion = MRSTFTLoss()
    result = criterion(signal_a, signal_b)
    assert isinstance(result, torch.Tensor)
def test_sisnr_loss():
    """``SISNR`` evaluates to a tensor on two random signals."""
    batch, n_channels, n_frames = 2, 2, random.randrange(1000, 100_000)
    signal_a = torch.randn(batch, n_channels, n_frames)
    signal_b = torch.randn(batch, n_channels, n_frames)

    criterion = SISNR()
    result = criterion(signal_a, signal_b)
    assert isinstance(result, torch.Tensor)
def test_stft_loss():
    """``STFTLoss`` evaluates to a tensor on two random signals."""
    batch, n_channels, n_frames = 2, 2, random.randrange(1000, 100_000)
    signal_a = torch.randn(batch, n_channels, n_frames)
    signal_b = torch.randn(batch, n_channels, n_frames)

    criterion = STFTLoss()
    result = criterion(signal_a, signal_b)
    assert isinstance(result, torch.Tensor)
================================================
FILE: tests/models/test_audiogen.py
================================================
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import pytest
import torch
from audiocraft.models import AudioGen
class TestAudioGenModel:
    """Smoke tests for the 'debug' pretrained ``AudioGen`` model on CPU."""

    def get_audiogen(self):
        """Load the debug model and configure a 2 second generation."""
        model = AudioGen.get_pretrained(name='debug', device='cpu')
        model.set_generation_params(duration=2.0, extend_stride=2.)
        return model

    def test_base(self):
        """The debug model exposes the expected frame rate, sample rate and channels."""
        model = self.get_audiogen()
        assert model.frame_rate == 25
        assert model.sample_rate == 16000
        assert model.audio_channels == 1

    def test_generate_continuation(self):
        """Continuations have the expected shape; a description count mismatch asserts."""
        model = self.get_audiogen()

        # prompt only
        prompt = torch.randn(3, 1, 16000)
        out = model.generate_continuation(prompt, 16000)
        assert list(out.shape) == [3, 1, 32000]

        # prompt with one description per item
        prompt = torch.randn(2, 1, 16000)
        descriptions = ['youpi', 'lapin dort']
        out = model.generate_continuation(prompt, 16000, descriptions)
        assert list(out.shape) == [2, 1, 32000]

        # more descriptions than prompts
        prompt = torch.randn(2, 1, 16000)
        with pytest.raises(AssertionError):
            model.generate_continuation(
                prompt, 16000, ['youpi', 'lapin dort', 'one too many'])

    def test_generate(self):
        """Text-conditioned generation returns one 32000-sample clip per description."""
        model = self.get_audiogen()
        out = model.generate(['youpi', 'lapin dort'])
        assert list(out.shape) == [2, 1, 32000]

    def test_generate_long(self):
        """A 4 second request with ``max_duration`` of 3 still yields 4 seconds of audio."""
        model = self.get_audiogen()
        model.max_duration = 3.
        model.set_generation_params(duration=4., extend_stride=2.)
        out = model.generate(['youpi', 'lapin dort'])
        assert list(out.shape) == [2, 1, 16000 * 4]
================================================
FILE: tests/models/test_encodec_model.py
================================================
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import random
import numpy as np
import torch
from audiocraft.models import EncodecModel
from audiocraft.modules import SEANetEncoder, SEANetDecoder
from audiocraft.quantization import DummyQuantizer
class TestEncodecModel:
    """Tests for ``EncodecModel`` built from small SEANet modules and a dummy quantizer."""

    def _create_encodec_model(self,
                              sample_rate: int,
                              channels: int,
                              dim: int = 5,
                              n_filters: int = 3,
                              n_residual_layers: int = 1,
                              ratios: 'list | None' = None,
                              **kwargs):
        """Build a small ``EncodecModel`` for testing.

        Args:
            sample_rate: Sample rate given to the model.
            channels: Number of audio channels for encoder, decoder and model.
            dim: Latent dimension of the SEANet encoder/decoder.
            n_filters: Base number of filters of the SEANet modules.
            n_residual_layers: Number of residual layers of the SEANet modules.
            ratios: Ratios handed to the SEANet modules; ``None`` selects
                ``[5, 4, 3, 2]``. A ``None`` sentinel replaces the former mutable
                list default so no list object is shared between calls.
            **kwargs: Extra keyword arguments forwarded to ``EncodecModel``.

        Returns:
            The assembled ``EncodecModel``.
        """
        if ratios is None:
            ratios = [5, 4, 3, 2]
        frame_rate = np.prod(ratios)
        encoder = SEANetEncoder(channels=channels, dimension=dim, n_filters=n_filters,
                                n_residual_layers=n_residual_layers, ratios=ratios)
        decoder = SEANetDecoder(channels=channels, dimension=dim, n_filters=n_filters,
                                n_residual_layers=n_residual_layers, ratios=ratios)
        quantizer = DummyQuantizer()
        model = EncodecModel(encoder, decoder, quantizer, frame_rate=frame_rate,
                             sample_rate=sample_rate, channels=channels, **kwargs)
        return model

    def test_model(self):
        """The forward pass reconstructs a tensor with the input's shape for random lengths."""
        random.seed(1234)
        sample_rate = 24_000
        channels = 1
        model = self._create_encodec_model(sample_rate, channels)
        for _ in range(10):
            length = random.randrange(1, 10_000)
            x = torch.randn(2, channels, length)
            res = model(x)
            assert res.x.shape == x.shape

    def test_model_renorm(self):
        """``encode`` returns scales only when the model renormalizes its input.

        The non-renormalized result used to be computed and then discarded without
        any check; it is now asserted too.
        """
        random.seed(1234)
        sample_rate = 24_000
        channels = 1
        model_nonorm = self._create_encodec_model(sample_rate, channels, renormalize=False)
        model_renorm = self._create_encodec_model(sample_rate, channels, renormalize=True)

        for _ in range(10):
            length = random.randrange(1, 10_000)
            x = torch.randn(2, channels, length)
            # NOTE(review): expects encode() to return None scales when
            # renormalize=False -- confirm against EncodecModel.encode.
            _, scales_nonorm = model_nonorm.encode(x)
            assert scales_nonorm is None
            _, scales_renorm = model_renorm.encode(x)
            assert scales_renorm is not None
================================================
FILE: tests/models/test_multibanddiffusion.py
================================================
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import random
import numpy as np
import torch
from audiocraft.models.multibanddiffusion import MultiBandDiffusion, DiffusionProcess
from audiocraft.models import EncodecModel, DiffusionUnet
from audiocraft.modules import SEANetEncoder, SEANetDecoder
from audiocraft.modules.diffusion_schedule import NoiseSchedule
from audiocraft.quantization import DummyQuantizer
class TestMBD:
    """Tests for ``MultiBandDiffusion`` assembled from small test-sized components."""

    def _create_mbd(self,
                    sample_rate: int,
                    channels: int,
                    n_filters: int = 3,
                    n_residual_layers: int = 1,
                    ratios: 'list | None' = None,
                    num_steps: int = 1000,
                    codec_dim: int = 128,
                    **kwargs):
        """Build a single-band ``MultiBandDiffusion`` around a small ``EncodecModel``.

        Args:
            sample_rate: Sample rate given to the compression model.
            channels: Number of audio channels (also the diffusion model's ``chin``).
            n_filters: Base number of filters of the SEANet modules.
            n_residual_layers: Number of residual layers of the SEANet modules.
            ratios: Ratios handed to the SEANet modules; ``None`` selects
                ``[5, 4, 3, 2]``. A ``None`` sentinel replaces the former mutable
                list default so no list object is shared between calls.
            num_steps: Number of steps for both the diffusion model and noise schedule.
            codec_dim: Dimension shared by the codec and the diffusion model.
            **kwargs: Extra keyword arguments forwarded to ``EncodecModel``.

        Returns:
            The assembled ``MultiBandDiffusion`` instance.
        """
        if ratios is None:
            ratios = [5, 4, 3, 2]
        frame_rate = np.prod(ratios)
        encoder = SEANetEncoder(channels=channels, dimension=codec_dim, n_filters=n_filters,
                                n_residual_layers=n_residual_layers, ratios=ratios)
        decoder = SEANetDecoder(channels=channels, dimension=codec_dim, n_filters=n_filters,
                                n_residual_layers=n_residual_layers, ratios=ratios)
        quantizer = DummyQuantizer()
        compression_model = EncodecModel(encoder, decoder, quantizer, frame_rate=frame_rate,
                                         sample_rate=sample_rate, channels=channels, **kwargs)
        diffusion_model = DiffusionUnet(chin=channels, num_steps=num_steps, codec_dim=codec_dim)
        schedule = NoiseSchedule(device='cpu', num_steps=num_steps)
        DP = DiffusionProcess(model=diffusion_model, noise_schedule=schedule)
        mbd = MultiBandDiffusion(DPs=[DP], codec_model=compression_model)
        return mbd

    def test_model(self):
        """``regenerate`` returns audio with the input's shape for random lengths."""
        random.seed(1234)
        sample_rate = 24_000
        channels = 1
        codec_dim = 128
        mbd = self._create_mbd(sample_rate=sample_rate, channels=channels, codec_dim=codec_dim)
        for _ in range(10):
            length = random.randrange(1, 10_000)
            x = torch.randn(2, channels, length)
            res = mbd.regenerate(x, sample_rate)
            assert res.shape == x.shape
================================================
FILE: tests/models/test_musicgen.py
================================================
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import pytest
import torch
from audiocraft.models import MusicGen
class TestMusicGenModel:
    """Smoke tests for the 'debug' pretrained ``MusicGen`` model on CPU."""

    def get_musicgen(self):
        """Load the debug model and configure a 2 second generation."""
        model = MusicGen.get_pretrained(name='debug', device='cpu')
        model.set_generation_params(duration=2.0, extend_stride=2.)
        return model

    def test_base(self):
        """The debug model exposes the expected frame rate, sample rate and channels."""
        model = self.get_musicgen()
        assert model.frame_rate == 25
        assert model.sample_rate == 32000
        assert model.audio_channels == 1

    def test_generate_unconditional(self):
        """Unconditional generation of 3 items yields 3 mono clips of 64000 samples."""
        model = self.get_musicgen()
        out = model.generate_unconditional(3)
        assert list(out.shape) == [3, 1, 64000]

    def test_generate_continuation(self):
        """Continuations have the expected shape; a description count mismatch asserts."""
        model = self.get_musicgen()

        # prompt only
        prompt = torch.randn(3, 1, 32000)
        out = model.generate_continuation(prompt, 32000)
        assert list(out.shape) == [3, 1, 64000]

        # prompt with one description per item
        prompt = torch.randn(2, 1, 32000)
        descriptions = ['youpi', 'lapin dort']
        out = model.generate_continuation(prompt, 32000, descriptions)
        assert list(out.shape) == [2, 1, 64000]

        # more descriptions than prompts
        prompt = torch.randn(2, 1, 32000)
        with pytest.raises(AssertionError):
            model.generate_continuation(
                prompt, 32000, ['youpi', 'lapin dort', 'one too many'])

    def test_generate(self):
        """Text-conditioned generation returns one 64000-sample clip per description."""
        model = self.get_musicgen()
        out = model.generate(['youpi', 'lapin dort'])
        assert list(out.shape) == [2, 1, 64000]

    def test_generate_long(self):
        """A 4 second request with ``max_duration`` of 3 still yields 4 seconds of audio."""
        model = self.get_musicgen()
        model.max_duration = 3.
        model.set_generation_params(duration=4., extend_stride=2.)
        out = model.generate(['youpi', 'lapin dort'])
        assert list(out.shape) == [2, 1, 32000 * 4]
================================================
FILE: tests/modules/__init__.py
================================================
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
================================================
FILE: tests/modules/test_activations.py
================================================
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import torch
from torch import nn
from audiocraft.modules.activations import CustomGLU
class TestActivations:
    """Tests for the custom activation modules."""

    def test_custom_glu_calculation(self):
        """``CustomGLU`` over an identity gate halves the last dim and multiplies the halves."""
        glu = CustomGLU(nn.Identity())

        half_shape = (4, 8, 8)
        first_half = torch.ones(half_shape) * 2
        second_half = torch.ones(half_shape) * -1
        stacked = torch.cat((first_half, second_half), dim=-1)

        result = glu(stacked)

        # ensure all dimensions match initial shape
        assert result.shape == half_shape
        # ensure the gating was calculated correctly a * f(b)
        assert torch.all(result == -2).item()
================================================
FILE: tests/modules/test_codebooks_patterns.py
================================================
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import pytest
import torch
from audiocraft.modules.codebooks_patterns import (
DelayedPatternProvider,
ParallelPatternProvider,
Pattern,
UnrolledPatternProvider,
)
class TestParallelPatternProvider:
    """Layout checks for patterns produced by ``ParallelPatternProvider``."""

    @pytest.mark.parametrize("n_q", [1, 4, 32])
    @pytest.mark.parametrize("timesteps", [0, 1, 16, 100])
    def test_get_pattern(self, n_q: int, timesteps: int):
        """The layout has one entry per timestep plus the initial step."""
        layout = ParallelPatternProvider(n_q).get_pattern(timesteps).layout
        # + 1 to account for 1st step
        assert len(layout) == timesteps + 1

    @pytest.mark.parametrize("n_q", [1, 4, 32])
    @pytest.mark.parametrize("timesteps", [8, 16, 100])
    def test_pattern_content(self, n_q: int, timesteps: int):
        """At each step, the codes are ordered by codebook and share the same timestep."""
        layout = ParallelPatternProvider(n_q).get_pattern(timesteps).layout
        for step, coords in enumerate(layout):
            for position, coord in enumerate(coords):
                assert position == coord.q
                assert coord.t == step - 1  # account for the 1st empty step

    @pytest.mark.parametrize("n_q", [1, 4, 32])
    @pytest.mark.parametrize("timesteps", [8, 16, 100])
    def test_pattern_max_delay(self, n_q: int, timesteps: int):
        """The pattern reports no delay, so the valid layout spans the whole layout."""
        pattern = ParallelPatternProvider(n_q).get_pattern(timesteps)
        assert pattern.max_delay == 0
        expected_valid = len(pattern.layout) - pattern.max_delay
        assert len(pattern.valid_layout) == expected_valid
class TestDelayedPatternProvider:
    """Layout checks for patterns produced by ``DelayedPatternProvider``."""

    @pytest.mark.parametrize("n_q", [1, 4, 32])
    @pytest.mark.parametrize("timesteps", [0, 1, 16, 100])
    def test_get_pattern(self, n_q: int, timesteps: int):
        """The layout length is the timesteps plus the largest delay plus the initial step."""
        linear_delays = list(range(n_q))
        unit_delays = [0] + [1] * (n_q - 1)
        wide_delays = [0] + [4] * (n_q - 1)
        for delays in (linear_delays, unit_delays, wide_delays):
            pattern = DelayedPatternProvider(n_q, delays).get_pattern(timesteps)
            # + 1 to account for 1st step
            assert len(pattern.layout) == timesteps + max(delays) + 1

    @pytest.mark.parametrize("n_q", [1, 4, 32])
    @pytest.mark.parametrize("timesteps", [8, 16, 100])
    def test_pattern_content(self, n_q: int, timesteps: int):
        """With the provider's default delays, codebook ``q`` lags ``q`` steps behind."""
        layout = DelayedPatternProvider(n_q).get_pattern(timesteps).layout
        for step, coords in enumerate(layout):
            for position, coord in enumerate(coords):
                assert position == coord.q
                assert coord.t == max(0, step - coord.q - 1)

    @pytest.mark.parametrize("timesteps", [8, 16, 100])
    @pytest.mark.parametrize("delay", [[0, 1, 2, 3], [0, 1, 1, 1], [0, 3, 3, 3], [0, 3]])
    def test_pattern_max_delay(self, timesteps: int, delay: list):
        """``max_delay`` is the largest configured delay and shortens the valid layout."""
        pattern = DelayedPatternProvider(len(delay), delay).get_pattern(timesteps)
        assert pattern.max_delay == max(delay)
        expected_valid = len(pattern.layout) - pattern.max_delay
        assert len(pattern.valid_layout) == expected_valid
class TestUnrolledPatternProvider:
    """Layout checks for patterns produced by ``UnrolledPatternProvider``."""

    @pytest.mark.parametrize("timesteps", [0, 1, 16])
    @pytest.mark.parametrize("flattening", [[0, 1, 2], [0, 1, 1]])
    @pytest.mark.parametrize("delays", [[0, 0, 0], [0, 5, 5]])
    def test_get_pattern(self, timesteps: int, flattening: list, delays: list):
        """The layout length is the provider's virtual step count plus the largest delay."""
        provider = UnrolledPatternProvider(len(flattening), flattening, delays)
        layout = provider.get_pattern(timesteps).layout
        expected_length = provider.num_virtual_steps(timesteps) + max(delays)
        assert len(layout) == expected_length

    @pytest.mark.parametrize("timesteps", [0, 1, 16])
    @pytest.mark.parametrize("flattening", [[0, 1, 2], [0, 1, 1]])
    @pytest.mark.parametrize("delays", [[0, 0, 0], [0, 5, 5]])
    def test_pattern_max_delay(self, timesteps: int, flattening: list, delays: list):
        """``max_delay`` equals the largest configured delay."""
        provider = UnrolledPatternProvider(len(flattening), flattening, delays)
        pattern = provider.get_pattern(timesteps)
        assert pattern.max_delay == max(delays)
class TestPattern:
    """Compares ``Pattern`` build/revert methods against plain-loop reference versions."""

    def ref_build_pattern_sequence(self, z: torch.Tensor, pattern: Pattern, special_token: int):
        """Reference (loop-based) construction of the interleaved sequence from ``z``.

        Positions of the layout not covered by ``z`` keep ``special_token``.
        """
        batch, n_codebooks, n_steps = z.shape
        codes = z.cpu().numpy()
        assert n_codebooks == pattern.n_q
        assert n_steps <= pattern.timesteps
        out_shape = (batch, n_codebooks, len(pattern.layout))
        seq = torch.full(out_shape, special_token, dtype=torch.long).numpy()
        seq[:] = special_token
        for step, coords in enumerate(pattern.layout):
            for (t, q) in coords:
                if t >= n_steps:
                    continue
                seq[:, q, step] = codes[:, q, t]
        return torch.from_numpy(seq)

    def ref_revert_pattern_sequence(self, z: torch.Tensor, pattern: Pattern, special_token: int):
        """Reference (loop-based) recovery of the original codes from an interleaved sequence."""
        seq = z.cpu().numpy()
        batch, n_codebooks, _ = seq.shape
        assert pattern.n_q == n_codebooks
        out_shape = (batch, pattern.n_q, pattern.timesteps)
        codes = torch.full(out_shape, special_token, dtype=torch.long).numpy()
        codes[:] = special_token
        for step, coords in enumerate(pattern.layout):
            for (t, q) in coords:
                if t >= pattern.timesteps:
                    continue
                codes[:, q, t] = seq[:, q, step]
        return torch.from_numpy(codes)

    def ref_revert_pattern_logits(self, z: torch.Tensor, pattern: Pattern, special_token: float):
        """Reference (loop-based) recovery of per-timestep logits from sequence-ordered logits.

        The first layout entry is skipped, so sequence index ``s`` of ``z`` is
        matched with layout entry ``s + 1``.
        """
        seq_logits = z.cpu().numpy()
        batch, card, n_codebooks, seq_len = seq_logits.shape
        assert pattern.n_q == n_codebooks
        shifted_layout = pattern.layout[1:]
        out_shape = (batch, card, pattern.n_q, pattern.timesteps)
        logits = torch.full(out_shape, special_token, dtype=torch.float).numpy()
        logits[:] = special_token
        for step, coords in enumerate(shifted_layout):
            if step >= seq_len:
                continue
            for (t, q) in coords:
                if t >= pattern.timesteps:
                    continue
                logits[:, :, q, t] = seq_logits[:, :, q, step]
        return torch.from_numpy(logits)

    def _get_pattern_providers(self, n_q: int):
        """Return parallel, delayed and unrolled providers configured for ``n_q`` codebooks."""
        trailing = n_q - 1
        return [
            ParallelPatternProvider(n_q),
            DelayedPatternProvider(n_q, list(range(n_q))),
            DelayedPatternProvider(n_q, [0] + [1] * trailing),
            UnrolledPatternProvider(
                n_q, flattening=list(range(n_q)), delays=[0] * n_q
            ),
            UnrolledPatternProvider(
                n_q, flattening=[0] + [1] * trailing, delays=[0] * n_q
            ),
            UnrolledPatternProvider(
                n_q, flattening=[0] + [1] * trailing, delays=[0] + [5] * trailing
            ),
        ]

    @pytest.mark.parametrize("n_q", [1, 4, 32])
    @pytest.mark.parametrize("timesteps", [16, 72])
    def test_build_pattern_sequence(self, n_q: int, timesteps: int):
        """``build_pattern_sequence`` matches the reference and asserts on bad input shapes."""
        bs = 2
        card = 256
        special_token = card

        for provider in self._get_pattern_providers(n_q):
            pattern = provider.get_pattern(timesteps)
            # we can correctly build the sequence from the pattern
            codes = torch.randint(0, card, (bs, n_q, timesteps))
            expected = self.ref_build_pattern_sequence(codes, pattern, special_token)
            built, _, _ = pattern.build_pattern_sequence(codes, special_token)
            assert (built == expected).float().mean() == 1.0

            # expected assertion fails on the number of timesteps
            bad_timesteps = [timesteps + 1]
            if pattern.num_sequence_steps != pattern.timesteps:
                bad_timesteps.append(pattern.num_sequence_steps)
            for n_steps in bad_timesteps:
                too_long = torch.randint(0, card, (bs, n_q, n_steps))
                with pytest.raises(AssertionError):
                    pattern.build_pattern_sequence(too_long, special_token)

            # expected assertion fails on the number of codebooks
            for bad_q in (0, n_q - 1, n_q + 1):
                wrong_codebooks = torch.randint(0, card, (bs, bad_q, timesteps))
                with pytest.raises(AssertionError):
                    pattern.build_pattern_sequence(wrong_codebooks, special_token)

    @pytest.mark.parametrize("n_q", [1, 4, 32])
    @pytest.mark.parametrize("timesteps", [16, 72])
    def test_revert_pattern_sequence(self, n_q: int, timesteps: int):
        """``revert_pattern_sequence`` recovers the codes and agrees with the reference."""
        bs = 2
        card = 256
        special_token = card

        for provider in self._get_pattern_providers(n_q):
            pattern = provider.get_pattern(timesteps)
            # this works assuming previous tests are successful
            codes = torch.randint(0, card, (bs, n_q, timesteps))
            sequence = self.ref_build_pattern_sequence(codes, pattern, special_token)
            expected = self.ref_revert_pattern_sequence(sequence, pattern, special_token)
            # ensure our reference script retrieve the original sequence
            assert codes.shape == expected.shape
            assert (codes == expected).float().mean() == 1.0
            # now we can test the scatter version
            reverted, _, _ = pattern.revert_pattern_sequence(sequence, special_token)
            assert reverted.shape == expected.shape
            assert (reverted == expected).float().mean() == 1.0

    @pytest.mark.parametrize("n_q", [1, 4, 32])
    @pytest.mark.parametrize("timesteps", [16, 72])
    @pytest.mark.parametrize("card", [1, 2, 256, 1024])
    def test_revert_pattern_logits(self, n_q: int, timesteps: int, card: int):
        """``revert_pattern_logits`` agrees element-wise with the reference implementation."""
        bs = 2
        special_token = card
        logits_special_token = float('nan')

        for provider in self._get_pattern_providers(n_q):
            pattern = provider.get_pattern(timesteps)
            # this works assuming previous tests are successful
            codes = torch.randint(0, card, (bs, n_q, timesteps))
            sequence = self.ref_build_pattern_sequence(codes, pattern, special_token)
            logits = torch.randn((bs, card, n_q, sequence.shape[-1]))
            expected = self.ref_revert_pattern_logits(logits, pattern, logits_special_token)
            # ensure our reference script retrieve the original sequence
            assert expected.shape == torch.Size([bs, card, n_q, timesteps])
            # now we can test the scatter version
            reverted, _, _ = pattern.revert_pattern_logits(logits, logits_special_token)
            assert reverted.shape == expected.shape
            assert (reverted == expected).float().mean() == 1.0
================================================
FILE: tests/modules/test_conv.py
================================================
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
from itertools import product
import math
import random
import pytest
import torch
from torch import nn
from audiocraft.modules import (
NormConv1d,
NormConvTranspose1d,
StreamableConv1d,
StreamableConvTranspose1d,
pad1d,
unpad1d,
)
def test_get_extra_padding_for_conv1d():
    """Placeholder test: it performs no checks yet and therefore always passes."""
    # TODO: Implement me!
def test_pad1d_zeros():
    """Constant-mode ``pad1d`` grows the last dim by both paddings; negative paddings assert."""
    x = torch.randn(1, 1, 20)

    valid_cases = [
        ((0, 5), 25),
        ((5, 5), 30),
        ((0, 0), 20),
        ((10, 30), 60),
    ]
    for paddings, expected_length in valid_cases:
        padded = pad1d(x, paddings, mode='constant', value=0.)
        assert padded.shape[-1] == expected_length

    for negative_paddings in [(-1, 0), (0, -1), (-1, -1)]:
        with pytest.raises(AssertionError):
            pad1d(x, negative_paddings, mode='constant', value=0.)
def test_pad1d_reflect():
    """Check `pad1d` output lengths in 'reflect' mode and that negative paddings are rejected."""
    x = torch.randn(1, 1, 20)

    # (paddings, expected length of the last dimension after padding)
    valid_cases = [
        ((0, 5), 25),
        ((5, 5), 30),
        ((0, 0), 20),
        ((10, 30), 60),
    ]
    for paddings, expected_length in valid_cases:
        padded = pad1d(x, paddings, mode='reflect', value=0.)
        assert padded.shape[-1] == expected_length

    # Any negative padding amount must trip an assertion.
    for paddings in [(-1, 0), (0, -1), (-1, -1)]:
        with pytest.raises(AssertionError):
            pad1d(x, paddings, mode='reflect', value=0.)
def test_unpad1d():
    """Check `unpad1d` output lengths and that negative paddings are rejected."""
    x = torch.randn(1, 1, 20)

    # (paddings, expected length of the last dimension after trimming)
    valid_cases = [
        ((5, 5), 10),
        ((0, 5), 15),
        ((5, 0), 15),
        ((0, 0), x.shape[-1]),
    ]
    for paddings, expected_length in valid_cases:
        trimmed = unpad1d(x, paddings)
        assert trimmed.shape[-1] == expected_length

    # Any negative padding amount must trip an assertion.
    for paddings in [(-1, 0), (0, -1), (-1, -1)]:
        with pytest.raises(AssertionError):
            unpad1d(x, paddings)
class TestNormConv1d:
    """Wiring and output-shape checks for `NormConv1d` under each normalization mode."""

    def test_norm_conv1d_modules(self):
        """Build `NormConv1d` with each norm and check sub-module types and output shape.

        The input length is random but never shorter than the kernel: the
        convolution is unpadded, so a shorter input would make the forward
        pass fail and the test flaky.
        """
        C_out, kernel_size, stride = 1, 4, 1
        N, C, T = 2, 2, random.randrange(kernel_size, 100_000)
        t0 = torch.randn(N, C, T)
        # Output length of an unpadded, undilated 1d convolution.
        expected_out_length = (T - kernel_size) // stride + 1

        wn_conv = NormConv1d(C, C_out, kernel_size=kernel_size, stride=stride, norm='weight_norm')
        gn_conv = NormConv1d(C, C_out, kernel_size=kernel_size, stride=stride, norm='time_group_norm')
        nn_conv = NormConv1d(C, C_out, kernel_size=kernel_size, stride=stride, norm='none')

        assert isinstance(wn_conv.norm, nn.Identity)
        assert isinstance(wn_conv.conv, nn.Conv1d)

        assert isinstance(gn_conv.norm, nn.GroupNorm)
        assert isinstance(gn_conv.conv, nn.Conv1d)

        assert isinstance(nn_conv.norm, nn.Identity)
        assert isinstance(nn_conv.conv, nn.Conv1d)

        for conv_layer in [wn_conv, gn_conv, nn_conv]:
            out = conv_layer(t0)
            assert isinstance(out, torch.Tensor)
            assert list(out.shape) == [N, C_out, expected_out_length]
class TestNormConvTranspose1d:
    """Wiring and output-shape checks for `NormConvTranspose1d` under each normalization mode."""

    def test_normalizations(self):
        """Build the layer with each norm and check sub-module types and output shape."""
        N, C, T = 2, 2, random.randrange(1, 100_000)
        t0 = torch.randn(N, C, T)

        C_out, kernel_size, stride = 1, 4, 1
        expected_out_length = (T - 1) * stride + (kernel_size - 1) + 1

        # Norm name -> module type the test expects to find in the `.norm` attribute.
        expected_norm_module = {
            'weight_norm': nn.Identity,
            'time_group_norm': nn.GroupNorm,
            'none': nn.Identity,
        }
        layers = {
            norm_name: NormConvTranspose1d(C, C_out, kernel_size=kernel_size, stride=stride, norm=norm_name)
            for norm_name in expected_norm_module
        }

        for norm_name, layer in layers.items():
            assert isinstance(layer.norm, expected_norm_module[norm_name])
            assert isinstance(layer.convtr, nn.ConvTranspose1d)

        for layer in layers.values():
            out = layer(t0)
            assert isinstance(out, torch.Tensor)
            assert list(out.shape) == [N, C_out, expected_out_length]
class TestStreamableConv1d:
    """Output-shape checks for `StreamableConv1d` over several kernel/stride/dilation settings."""

    def get_streamable_conv1d_output_length(self, length, kernel_size, stride, dilation):
        """Return the output length this test expects for the given input length and conv params."""
        # StreamableConv1d internally pads to make sure that the last window is full
        total_pad = (kernel_size - 1) * dilation - (stride - 1)
        frame_count = (length - kernel_size + total_pad) / stride + 1
        full_length = (math.ceil(frame_count) - 1) * stride + (kernel_size - total_pad)
        return full_length // stride

    def test_streamable_conv1d(self):
        """Run causal and non-causal convolutions and compare output shapes to the expected ones."""
        N, C, T = 2, 2, random.randrange(1, 100_000)
        t0 = torch.randn(N, C, T)
        C_out = 1

        # Each entry is (kernel_size, stride, dilation).
        conv_params = [(4, 1, 1), (4, 2, 1), (3, 1, 3), (10, 5, 1), (3, 2, 3)]
        for causal, params in product([False, True], conv_params):
            kernel_size, stride, dilation = params
            expected_shape = [
                N, C_out, self.get_streamable_conv1d_output_length(T, kernel_size, stride, dilation)]
            sconv = StreamableConv1d(
                C, C_out, kernel_size=kernel_size, stride=stride, dilation=dilation, causal=causal)
            out = sconv(t0)
            assert isinstance(out, torch.Tensor)
            actual_shape = list(out.shape)
            print(actual_shape, expected_shape)
            assert actual_shape == expected_shape
class TestStreamableConvTranspose1d:
    """Argument-validation and output-shape checks for `StreamableConvTranspose1d`."""

    def get_streamable_convtr1d_output_length(self, length, kernel_size, stride):
        """Return the output length this test expects for the given input length and conv params."""
        padding_total = (kernel_size - stride)
        return (length - 1) * stride - padding_total + (kernel_size - 1) + 1

    def test_streamable_convtr1d(self):
        """Check that invalid settings are rejected and valid ones give the expected shape."""
        N, C, T = 2, 2, random.randrange(1, 100_000)
        t0 = torch.randn(N, C, T)

        C_out = 1

        # Each invalid (causal, trim_right_ratio) pair gets its own `raises` context:
        # when they share one, the first exception leaves the block and the
        # remaining constructors are never executed.
        invalid_params = [(False, 0.5), (True, -1.), (True, 2)]
        for causal, trim_right_ratio in invalid_params:
            with pytest.raises(AssertionError):
                StreamableConvTranspose1d(
                    C, C_out, kernel_size=4, causal=causal, trim_right_ratio=trim_right_ratio)

        # causal params are [(causal, trim_right)]
        causal_params = [(False, 1.0), (True, 1.0), (True, 0.5), (True, 0.0)]
        # conv params are [(kernel_size, stride)]
        conv_params = [(4, 1), (4, 2), (3, 1), (10, 5)]
        for ((causal, trim_right_ratio), (kernel_size, stride)) in product(causal_params, conv_params):
            expected_out_length = self.get_streamable_convtr1d_output_length(T, kernel_size, stride)
            sconvtr = StreamableConvTranspose1d(C, C_out, kernel_size=kernel_size, stride=stride,
                                                causal=causal, trim_right_ratio=trim_right_ratio)
            out = sconvtr(t0)
            assert isinstance(out, torch.Tensor)
            assert list(out.shape) == [N, C_out, expected_out_length]
================================================
FILE: tests/modules/test_lstm.py
================================================
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import random
import torch
from audiocraft.modules.lstm import StreamableLSTM
class TestStreamableLSTM:
    """Shape checks for `StreamableLSTM` with and without the `skip` option."""

    def test_lstm(self):
        """The output keeps the input shape when `skip=False`."""
        batch, channels, steps = 4, 2, random.randint(1, 100)
        lstm = StreamableLSTM(channels, 3, skip=False)
        inputs = torch.randn(batch, channels, steps)
        outputs = lstm(inputs)

        print(outputs.shape)
        expected_shape = torch.Size([batch, channels, steps])
        assert outputs.shape == expected_shape

    def test_lstm_skip(self):
        """The output keeps the input shape when `skip=True`."""
        batch, channels, steps = 4, 2, random.randint(1, 100)
        lstm = StreamableLSTM(channels, 3, skip=True)
        inputs = torch.randn(batch, channels, steps)
        outputs = lstm(inputs)

        expected_shape = torch.Size([batch, channels, steps])
        assert outputs.shape == expected_shape
================================================
FILE: tests/modules/test_rope.py
================================================
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import torch
from audiocraft.modules.rope import RotaryEmbedding
from audiocraft.modules.transformer import StreamingTransformer, set_efficient_attention_backend
def test_rope():
    """`RotaryEmbedding.rotate_qk` keeps the shape of queries and keys."""
    set_efficient_attention_backend('xformers')
    B, T, H, C = 8, 75, 16, 128
    shape = (B, T, H, C)

    rope = RotaryEmbedding(dim=C)
    xq = torch.rand(shape)
    xk = torch.rand(shape)
    xq_out, xk_out = rope.rotate_qk(xq, xk, start=7)

    for rotated in (xq_out, xk_out):
        assert list(rotated.shape) == list(shape)
def test_rope_io_dtypes():
    """Rotated queries keep the input dtype for both float32 and float64 rope instances."""
    set_efficient_attention_backend('xformers')
    B, T, H, C = 8, 75, 16, 128

    rope_32 = RotaryEmbedding(dim=C, dtype=torch.float32)
    rope_64 = RotaryEmbedding(dim=C, dtype=torch.float64)

    # Test bfloat16 then float32 inputs, each w/ both 32 and 64 precision rope.
    for input_dtype in (torch.bfloat16, torch.float32):
        xq = torch.rand((B, T, H, C)).to(input_dtype)
        xk = torch.rand((B, T, H, C)).to(input_dtype)
        for rope in (rope_32, rope_64):
            xq_out, xk_out = rope.rotate_qk(xq, xk)
            assert xq_out.dtype == input_dtype
def test_transformer_with_rope():
    """A `StreamingTransformer` using 'rope' or 'sin_rope' embeddings keeps the input shape."""
    set_efficient_attention_backend('xformers')
    torch.manual_seed(1234)
    for positional_embedding in ('rope', 'sin_rope'):
        transformer = StreamingTransformer(
            16, 4, 2, custom=True, dropout=0., layer_scale=0.1,
            positional_embedding=positional_embedding)
        transformer.eval()
        steps = 12
        inputs = torch.randn(3, steps, 16)

        outputs = transformer(inputs)
        assert list(outputs.shape) == list(inputs.shape)
@torch.no_grad()
def test_rope_streaming():
    """Frame-by-frame streaming with rope matches the full-sequence forward pass."""
    set_efficient_attention_backend('xformers')
    torch.manual_seed(1234)
    tr = StreamingTransformer(
        16, 4, 2, causal=True, dropout=0.,
        custom=True, positional_embedding='rope')
    tr.eval()
    steps = 12
    x = torch.randn(3, steps, 16)

    ref = tr(x)

    with tr.streaming():
        outs = []
        offset = 0
        # Feed the sequence one frame at a time, walking along the time axis.
        for frame_size in [1] * steps:
            outs.append(tr(x[:, offset:offset + frame_size]))
            offset += frame_size

    out = torch.cat(outs, dim=1)
    assert list(out.shape) == [3, steps, 16]
    delta = torch.norm(out - ref) / torch.norm(out)
    assert delta < 1e-6, delta
@torch.no_grad()
def test_rope_streaming_past_context():
    """Streaming with rope matches the full forward pass, with and without `past_context`."""
    set_efficient_attention_backend('xformers')
    torch.manual_seed(1234)

    for context in (None, 10):
        tr = StreamingTransformer(
            16, 4, 1 if context else 2,
            causal=True, past_context=context, custom=True,
            dropout=0., positional_embedding='rope')
        tr.eval()

        steps = 20
        x = torch.randn(3, steps, 16)
        ref = tr(x)

        with tr.streaming():
            outs = []
            offset = 0
            # Feed the sequence one frame at a time, walking along the time axis.
            for frame_size in [1] * steps:
                outs.append(tr(x[:, offset:offset + frame_size]))
                offset += frame_size

        out = torch.cat(outs, dim=1)
        assert list(out.shape) == [3, steps, 16]
        delta = torch.norm(out - ref) / torch.norm(out)
        assert delta < 1e-6, delta
def test_rope_memory_efficient():
    """The custom and memory-efficient rope transformers agree when sharing weights."""
    set_efficient_attention_backend('xformers')
    torch.manual_seed(1234)
    reference = StreamingTransformer(
        16, 4, 2, custom=True, dropout=0., layer_scale=0.1,
        positional_embedding='rope')
    efficient = StreamingTransformer(
        16, 4, 2, dropout=0., memory_efficient=True, layer_scale=0.1,
        positional_embedding='rope')
    efficient.load_state_dict(reference.state_dict())
    reference.eval()
    steps = 12
    inputs = torch.randn(3, steps, 16)

    with torch.no_grad():
        y = reference(inputs)
        y2 = efficient(inputs)
        # Check at float precision b/c this is the rope default.
        assert torch.allclose(y, y2, atol=1e-7), (y - y2).norm()
def test_rope_with_xpos():
    """`rotate_qk` keeps the shape of queries and keys when `xpos=True`."""
    set_efficient_attention_backend('xformers')
    B, T, H, C = 8, 75, 16, 128
    shape = (B, T, H, C)

    rope = RotaryEmbedding(dim=C, xpos=True)
    xq = torch.rand(shape)
    xk = torch.rand(shape)
    xq_out, xk_out = rope.rotate_qk(xq, xk, start=7)

    for rotated in (xq_out, xk_out):
        assert list(rotated.shape) == list(shape)
def test_positional_scale():
    """With `scale=0.0`, `rotate_qk` returns queries and keys close to the inputs."""
    set_efficient_attention_backend('xformers')
    B, T, H, C = 8, 75, 16, 128
    shape = (B, T, H, C)

    rope = RotaryEmbedding(dim=C, xpos=True, scale=0.0)
    xq = torch.rand(shape)
    xk = torch.rand(shape)
    xq_out, xk_out = rope.rotate_qk(xq, xk, start=7)

    for original, rotated in ((xq, xq_out), (xk, xk_out)):
        assert torch.allclose(original, rotated)
================================================
FILE: tests/modules/test_seanet.py
================================================
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
from itertools import product
import pytest
import torch
from audiocraft.modules.seanet import SEANetEncoder, SEANetDecoder, SEANetResnetBlock
from audiocraft.modules import StreamableConv1d, StreamableConvTranspose1d
class TestSEANetModel:
    """Shape and normalization-placement tests for `SEANetEncoder` / `SEANetDecoder`."""

    def _check_roundtrip(self, encoder, decoder):
        """Encode then decode a random signal and check the latent and output shapes."""
        x = torch.randn(1, 1, 24000)
        z = encoder(x)
        assert list(z.shape) == [1, 128, 75], z.shape
        y = decoder(z)
        assert y.shape == x.shape, (x.shape, y.shape)

    def test_base(self):
        """Default encoder/decoder pair."""
        encoder = SEANetEncoder()
        decoder = SEANetDecoder()
        self._check_roundtrip(encoder, decoder)

    def test_causal(self):
        """Causal encoder/decoder pair."""
        encoder = SEANetEncoder(causal=True)
        decoder = SEANetDecoder(causal=True)
        self._check_roundtrip(encoder, decoder)

    def test_conv_skip_connection(self):
        """Encoder/decoder pair built with `true_skip=False`."""
        encoder = SEANetEncoder(true_skip=False)
        decoder = SEANetDecoder(true_skip=False)
        self._check_roundtrip(encoder, decoder)

    def test_seanet_encoder_decoder_final_act(self):
        """Decoder built with a final 'Tanh' activation."""
        encoder = SEANetEncoder(true_skip=False)
        decoder = SEANetDecoder(true_skip=False, final_activation='Tanh')
        self._check_roundtrip(encoder, decoder)

    def _check_encoder_blocks_norm(self, encoder: SEANetEncoder, n_disable_blocks: int, norm: str):
        """Check that the first `n_disable_blocks` encoder blocks use norm 'none' and the rest `norm`."""
        n_blocks = 0
        for layer in encoder.model:
            if isinstance(layer, StreamableConv1d):
                n_blocks += 1
                # Compute the expected value first: `assert a == b if c else d` would
                # parse as `assert ((a == b) if c else d)` and never fail in the else case.
                expected_norm = 'none' if n_blocks <= n_disable_blocks else norm
                assert layer.conv.norm_type == expected_norm, (n_blocks, layer.conv.norm_type, expected_norm)
            elif isinstance(layer, SEANetResnetBlock):
                # here we add + 1 to n_blocks as we increment n_blocks just after the block
                expected_norm = 'none' if (n_blocks + 1) <= n_disable_blocks else norm
                for resnet_layer in layer.block:
                    if isinstance(resnet_layer, StreamableConv1d):
                        assert resnet_layer.conv.norm_type == expected_norm, \
                            (n_blocks, resnet_layer.conv.norm_type, expected_norm)

    def test_encoder_disable_norm(self):
        """Sweep residual-layer counts, disabled-block counts and norms on the encoder."""
        n_residuals = [0, 1, 3]
        disable_blocks = [0, 1, 2, 3, 4, 5, 6]
        norms = ['weight_norm', 'none']
        for n_res, n_disable, norm in product(n_residuals, disable_blocks, norms):
            encoder = SEANetEncoder(n_residual_layers=n_res, norm=norm,
                                    disable_norm_outer_blocks=n_disable)
            self._check_encoder_blocks_norm(encoder, n_disable, norm)

    def _check_decoder_blocks_norm(self, decoder: SEANetDecoder, n_disable_blocks: int, norm: str):
        """Check that the last `n_disable_blocks` decoder blocks use norm 'none' and the rest `norm`."""
        n_blocks = 0
        for layer in decoder.model:
            if isinstance(layer, StreamableConv1d):
                n_blocks += 1
                expected_norm = 'none' if (decoder.n_blocks - n_blocks) < n_disable_blocks else norm
                assert layer.conv.norm_type == expected_norm, (n_blocks, layer.conv.norm_type, expected_norm)
            elif isinstance(layer, StreamableConvTranspose1d):
                n_blocks += 1
                expected_norm = 'none' if (decoder.n_blocks - n_blocks) < n_disable_blocks else norm
                assert layer.convtr.norm_type == expected_norm, (n_blocks, layer.convtr.norm_type, expected_norm)
            elif isinstance(layer, SEANetResnetBlock):
                expected_norm = 'none' if (decoder.n_blocks - n_blocks) < n_disable_blocks else norm
                for resnet_layer in layer.block:
                    if isinstance(resnet_layer, StreamableConv1d):
                        assert resnet_layer.conv.norm_type == expected_norm, \
                            (n_blocks, resnet_layer.conv.norm_type, expected_norm)

    def test_decoder_disable_norm(self):
        """Sweep residual-layer counts, disabled-block counts and norms on the decoder."""
        n_residuals = [0, 1, 3]
        disable_blocks = [0, 1, 2, 3, 4, 5, 6]
        norms = ['weight_norm', 'none']
        for n_res, n_disable, norm in product(n_residuals, disable_blocks, norms):
            decoder = SEANetDecoder(n_residual_layers=n_res, norm=norm,
                                    disable_norm_outer_blocks=n_disable)
            self._check_decoder_blocks_norm(decoder, n_disable, norm)

    def test_disable_norm_raises_exception(self):
        """Out-of-range `disable_norm_outer_blocks` values must trip an assertion."""
        # Invalid disable_norm_outer_blocks values raise exceptions
        with pytest.raises(AssertionError):
            SEANetEncoder(disable_norm_outer_blocks=-1)

        with pytest.raises(AssertionError):
            SEANetEncoder(ratios=[1, 1, 2, 2], disable_norm_outer_blocks=7)

        with pytest.raises(AssertionError):
            SEANetDecoder(disable_norm_outer_blocks=-1)

        with pytest.raises(AssertionError):
            SEANetDecoder(ratios=[1, 1, 2, 2], disable_norm_outer_blocks=7)
================================================
FILE: tests/modules/test_transformer.py
================================================
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
from itertools import product
import pytest
import torch
from audiocraft.modules.transformer import (
StreamingMultiheadAttention, StreamingTransformer, set_efficient_attention_backend)
def test_transformer_causal_streaming():
    """Check causality through input gradients, then streaming vs. batch evaluation."""
    torch.manual_seed(1234)

    for context, custom in product([None, 10], [False, True]):
        # Test that causality and receptive fields are properly handled.
        # looking at the gradients
        tr = StreamingTransformer(
            16, 4, 1 if context else 2,
            causal=True, past_context=context, custom=custom,
            dropout=0.)
        steps = 20
        zero = torch.tensor(0.)
        for k in [0, 10, 15, 19]:
            x = torch.randn(4, steps, 16, requires_grad=True)
            y = tr(x)
            y[:, k].abs().sum().backward()

            # Steps after `k` must not contribute to the output at step `k`.
            if k + 1 < steps:
                future_grad = x.grad[:, k + 1:]
                assert torch.allclose(future_grad, zero), future_grad.norm()

            # Steps up to and including `k` must contribute.
            visible_grad = x.grad[:, :k + 1]
            assert not torch.allclose(visible_grad, zero), visible_grad.norm()

            # With a bounded past context, steps before the limit must not contribute.
            if context is not None and k > context:
                limit = k - context - 1
                out_of_context_grad = x.grad[:, :limit]
                assert torch.allclose(out_of_context_grad, zero), out_of_context_grad.norm()

        # Now check that streaming gives the same result at batch eval.
        x = torch.randn(4, steps, 16)
        y = tr(x)
        with tr.streaming():
            ys = [tr(x[:, k:k + 1, :]) for k in range(steps)]
        y_stream = torch.cat(ys, dim=1)
        delta = torch.norm(y_stream - y) / torch.norm(y)
        assert delta < 1e-6, delta
def test_transformer_vs_pytorch():
    """Non-causal `StreamingTransformer` matches `torch.nn.TransformerEncoder` with shared weights."""
    torch.manual_seed(1234)
    # Check that in the non causal setting, we get the same result as
    # PyTorch Transformer encoder.
    for custom in (False, True):
        tr = StreamingTransformer(
            16, 4, 2,
            causal=False, custom=custom, dropout=0., positional_scale=0.)
        ref_layer = torch.nn.TransformerEncoderLayer(16, 4, dropout=0., batch_first=True)
        tr_ref = torch.nn.TransformerEncoder(ref_layer, 2)
        tr.load_state_dict(tr_ref.state_dict())

        x = torch.randn(4, 20, 16)
        y = tr(x)
        y_ref = tr_ref(x)
        delta = torch.norm(y_ref - y) / torch.norm(y)
        assert delta < 1e-6, delta
def test_streaming_api():
    """Restoring a saved streaming state reproduces the same output; `flush` returns None."""
    tr = StreamingTransformer(16, 4, 2, causal=True, dropout=0.)
    tr.eval()
    steps = 12
    x = torch.randn(1, steps, 16)

    with torch.no_grad(), tr.streaming():
        _ = tr(x[:, :1])
        # Snapshot the state after the first frame, cloning so later steps cannot alter it.
        saved_state = {}
        for name, tensor in tr.get_streaming_state().items():
            saved_state[name] = tensor.clone()
        y = tr(x[:, 1:2])
        # Rewind to the snapshot and replay the same frame.
        tr.set_streaming_state(saved_state)
        y2 = tr(x[:, 1:2])
        assert torch.allclose(y, y2), (y - y2).norm()
        assert tr.flush() is None
def test_memory_efficient():
    """Custom and memory-efficient transformers agree for both attention backends."""
    for backend in ('torch', 'xformers'):
        torch.manual_seed(1234)
        set_efficient_attention_backend(backend)

        reference = StreamingTransformer(
            16, 4, 2, custom=True, dropout=0., layer_scale=0.1)
        efficient = StreamingTransformer(
            16, 4, 2, dropout=0., memory_efficient=True, layer_scale=0.1)
        efficient.load_state_dict(reference.state_dict())
        reference.eval()
        steps = 12
        inputs = torch.randn(3, steps, 16)

        with torch.no_grad():
            y = reference(inputs)
            y2 = efficient(inputs)
            assert torch.allclose(y, y2), ((y - y2).norm(), backend)
def test_attention_as_float32():
    """With bfloat16 weights, enabling `attention_as_float32` changes the output."""
    torch.manual_seed(1234)
    for custom in (True, False):
        tr = StreamingTransformer(16, 4, 2, dropout=0., dtype=torch.bfloat16, custom=custom)
        tr_float32 = StreamingTransformer(
            16, 4, 2, dropout=0., attention_as_float32=True, dtype=torch.bfloat16, custom=custom)
        if not custom:
            # we are not using autocast here because it doesn't really
            # work as expected on CPU, so we have to manually cast the weights of the MHA.
            for layer in tr_float32.layers:
                layer.self_attn.mha.to(torch.float32)
        tr_float32.load_state_dict(tr.state_dict())
        steps = 12
        inputs = torch.randn(3, steps, 16, dtype=torch.bfloat16)

        with torch.no_grad():
            y = tr(inputs)
            y2 = tr_float32(inputs)
            assert not torch.allclose(y, y2), (y - y2).norm()
@torch.no_grad()
def test_streaming_memory_efficient():
    """Streaming the memory-efficient transformer matches the custom one run on the full input."""
    for backend in ('torch', 'xformers'):
        torch.manual_seed(1234)
        set_efficient_attention_backend(backend)
        tr = StreamingTransformer(16, 4, 2, causal=True, dropout=0., custom=True)
        tr_mem_efficient = StreamingTransformer(
            16, 4, 2, dropout=0., memory_efficient=True, causal=True)
        tr.load_state_dict(tr_mem_efficient.state_dict())
        tr.eval()
        tr_mem_efficient.eval()
        steps = 12
        x = torch.randn(3, steps, 16)

        ref = tr(x)

        with tr_mem_efficient.streaming():
            outs = []
            offset = 0
            # Feed the sequence one frame at a time, walking along the time axis.
            for frame_size in [1] * steps:
                outs.append(tr_mem_efficient(x[:, offset:offset + frame_size]))
                offset += frame_size

        out = torch.cat(outs, dim=1)
        delta = torch.norm(out - ref) / torch.norm(out)
        assert delta < 1e-6, delta
def test_cross_attention():
    """Compare a transformer with and without cross-attention, and check argument validation.

    A zero cross-attention source must reproduce the output of the model without
    cross-attention, a random source must change it, and passing (or omitting) the
    source on the wrong kind of model must trip an assertion.
    """
    torch.manual_seed(1234)
    for norm_first in [True, False]:
        m = StreamingTransformer(
            16, 4, 2, cross_attention=False, norm_first=norm_first, dropout=0., custom=True)
        m_cross = StreamingTransformer(
            16, 4, 2, cross_attention=True, norm_first=norm_first, dropout=0., custom=True)
        # strict=False: `m` has no cross-attention weights to provide.
        m_cross.load_state_dict(m.state_dict(), strict=False)
        x = torch.randn(2, 5, 16)
        cross_x = torch.randn(2, 3, 16)
        y_ref = m(x)
        y_cross_zero = m_cross(x, cross_attention_src=0 * cross_x)
        # With norm_first, the two should be exactly the same,
        # but with norm_first=False, we get 2 normalization in a row
        # and the epsilon value leads to a tiny change.
        atol = 0. if norm_first else 1e-6
        print((y_ref - y_cross_zero).norm() / y_ref.norm())
        assert torch.allclose(y_ref, y_cross_zero, atol=atol)

        # We now expect a difference even with a generous atol of 1e-2.
        y_cross = m_cross(x, cross_attention_src=cross_x)
        assert not torch.allclose(y_cross, y_cross_zero, atol=1e-2)

        # Each misuse gets its own `raises` context: when they share one, the
        # first exception leaves the block and the second call never runs.
        with pytest.raises(AssertionError):
            _ = m_cross(x)
        with pytest.raises(AssertionError):
            _ = m(x, cross_attention_src=cross_x)
def test_cross_attention_compat():
    """Cross-attention loads `torch.nn.MultiheadAttention` weights and matches it, streaming included."""
    torch.manual_seed(1234)
    num_heads = 2
    dim = num_heads * 64
    with pytest.raises(AssertionError):
        StreamingMultiheadAttention(dim, num_heads, causal=True, cross_attention=True)

    cross_attn = StreamingMultiheadAttention(
        dim, num_heads, dropout=0, cross_attention=True, custom=True)
    ref_attn = torch.nn.MultiheadAttention(dim, num_heads, dropout=0, batch_first=True)

    # We can load the regular attention state dict
    # so we have compat when loading old checkpoints.
    cross_attn.load_state_dict(ref_attn.state_dict())

    queries = torch.randn(3, 7, dim)
    keys = torch.randn(3, 9, dim)
    values = torch.randn(3, 9, dim)

    y = cross_attn(queries, keys, values)[0]
    y_ref = ref_attn(queries, keys, values)[0]
    assert torch.allclose(y, y_ref, atol=1e-7), (y - y_ref).norm() / y_ref.norm()

    # Now let's check that streaming is working properly.
    with cross_attn.streaming():
        ys = [
            cross_attn(queries[:, step: step + 1], keys, values)[0]
            for step in range(queries.shape[1])
        ]
    y_streaming = torch.cat(ys, dim=1)
    assert torch.allclose(y_streaming, y, atol=1e-7)
def test_repeat_kv():
    """Check `kv_repeat` validation and that the custom attention keeps the input shape."""
    torch.manual_seed(1234)
    num_heads = 8
    kv_repeat = 4
    dim = num_heads * 64

    # Each invalid configuration gets its own `raises` context: when they share
    # one, the first exception leaves the block and the second is never built.
    with pytest.raises(AssertionError):
        StreamingMultiheadAttention(
            dim, num_heads, causal=True, kv_repeat=kv_repeat, cross_attention=True)
    with pytest.raises(AssertionError):
        StreamingMultiheadAttention(
            dim, num_heads, causal=True, kv_repeat=kv_repeat)

    mha = StreamingMultiheadAttention(
        dim, num_heads, causal=True, kv_repeat=kv_repeat, custom=True)
    x = torch.randn(4, 18, dim)
    y = mha(x, x, x)[0]
    assert x.shape == y.shape
def test_qk_layer_norm():
    """Run transformers with `qk_layer_norm=True`, with and without cross-attention.

    Both forward passes must keep the input shape. The first result is now
    asserted too, rather than being overwritten unchecked.
    """
    torch.manual_seed(1234)
    tr = StreamingTransformer(
        16, 4, 2, custom=True, dropout=0., qk_layer_norm=True, bias_attn=False)
    steps = 12
    x = torch.randn(3, steps, 16)
    y = tr(x)
    assert y.shape == x.shape

    tr = StreamingTransformer(
        16, 4, 2, custom=True, dropout=0., qk_layer_norm=True, cross_attention=True)
    z = torch.randn(3, 21, 16)
    y = tr(x, cross_attention_src=z)
    assert y.shape == x.shape
================================================
FILE: tests/quantization/test_vq.py
================================================
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import torch
from audiocraft.quantization.vq import ResidualVectorQuantizer
class TestResidualVectorQuantizer:
    """Shape check for `ResidualVectorQuantizer`."""

    def test_rvq(self):
        """The `x` field of the quantizer result keeps the input shape."""
        input_shape = (1, 16, 2048)
        inputs = torch.randn(*input_shape)
        quantizer = ResidualVectorQuantizer(n_q=8, dimension=16, bins=8)
        result = quantizer(inputs, 1.)
        assert result.x.shape == torch.Size(input_shape)
================================================
FILE: tests/utils/__init__.py
================================================
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.