Showing preview only (5,836K chars total). Download the full file or copy to clipboard to get everything.
Repository: facebookresearch/audiocraft
Branch: main
Commit: 896ec7c47f5e
Files: 296
Total size: 5.5 MB
Directory structure:
gitextract_1bqvithb/
├── .github/
│ ├── actions/
│ │ └── audiocraft_build/
│ │ └── action.yml
│ └── workflows/
│ ├── audiocraft_docs.yml
│ ├── audiocraft_linter.yml
│ └── audiocraft_tests.yml
├── .gitignore
├── CHANGELOG.md
├── CODE_OF_CONDUCT.md
├── CONTRIBUTING.md
├── LICENSE
├── LICENSE_weights
├── MANIFEST.in
├── Makefile
├── README.md
├── assets/
│ ├── chord_to_index_mapping.pkl
│ ├── salience_1.th
│ └── salience_2.th
├── audiocraft/
│ ├── __init__.py
│ ├── adversarial/
│ │ ├── __init__.py
│ │ ├── discriminators/
│ │ │ ├── __init__.py
│ │ │ ├── base.py
│ │ │ ├── mpd.py
│ │ │ ├── msd.py
│ │ │ └── msstftd.py
│ │ └── losses.py
│ ├── data/
│ │ ├── __init__.py
│ │ ├── audio.py
│ │ ├── audio_dataset.py
│ │ ├── audio_utils.py
│ │ ├── info_audio_dataset.py
│ │ ├── jasco_dataset.py
│ │ ├── music_dataset.py
│ │ ├── sound_dataset.py
│ │ └── zip.py
│ ├── environment.py
│ ├── grids/
│ │ ├── __init__.py
│ │ ├── _base_explorers.py
│ │ ├── audiogen/
│ │ │ ├── __init__.py
│ │ │ ├── audiogen_base_16khz.py
│ │ │ └── audiogen_pretrained_16khz_eval.py
│ │ ├── compression/
│ │ │ ├── __init__.py
│ │ │ ├── _explorers.py
│ │ │ ├── debug.py
│ │ │ ├── encodec_audiogen_16khz.py
│ │ │ ├── encodec_base_24khz.py
│ │ │ └── encodec_musicgen_32khz.py
│ │ ├── diffusion/
│ │ │ ├── 4_bands_base_32khz.py
│ │ │ ├── __init__.py
│ │ │ └── _explorers.py
│ │ ├── magnet/
│ │ │ ├── __init__.py
│ │ │ ├── audio_magnet_16khz.py
│ │ │ ├── audio_magnet_pretrained_16khz_eval.py
│ │ │ ├── magnet_32khz.py
│ │ │ └── magnet_pretrained_32khz_eval.py
│ │ ├── musicgen/
│ │ │ ├── __init__.py
│ │ │ ├── _explorers.py
│ │ │ ├── musicgen_base_32khz.py
│ │ │ ├── musicgen_base_cached_32khz.py
│ │ │ ├── musicgen_clapemb_32khz.py
│ │ │ ├── musicgen_melody_32khz.py
│ │ │ ├── musicgen_pretrained_32khz_eval.py
│ │ │ ├── musicgen_stereo_finetune_32khz.py
│ │ │ └── musicgen_style_32khz.py
│ │ └── watermarking/
│ │ ├── __init__.py
│ │ ├── _explorers.py
│ │ ├── audioseal.py
│ │ └── kbits.py
│ ├── losses/
│ │ ├── __init__.py
│ │ ├── balancer.py
│ │ ├── loudnessloss.py
│ │ ├── sisnr.py
│ │ ├── specloss.py
│ │ ├── stftloss.py
│ │ └── wmloss.py
│ ├── metrics/
│ │ ├── __init__.py
│ │ ├── chroma_cosinesim.py
│ │ ├── clap_consistency.py
│ │ ├── fad.py
│ │ ├── kld.py
│ │ ├── miou.py
│ │ ├── pesq.py
│ │ ├── rvm.py
│ │ └── visqol.py
│ ├── models/
│ │ ├── __init__.py
│ │ ├── audiogen.py
│ │ ├── builders.py
│ │ ├── encodec.py
│ │ ├── flow_matching.py
│ │ ├── genmodel.py
│ │ ├── jasco.py
│ │ ├── lm.py
│ │ ├── lm_magnet.py
│ │ ├── loaders.py
│ │ ├── magnet.py
│ │ ├── multibanddiffusion.py
│ │ ├── musicgen.py
│ │ ├── unet.py
│ │ └── watermark.py
│ ├── modules/
│ │ ├── __init__.py
│ │ ├── activations.py
│ │ ├── chroma.py
│ │ ├── codebooks_patterns.py
│ │ ├── conditioners.py
│ │ ├── conv.py
│ │ ├── diffusion_schedule.py
│ │ ├── jasco_conditioners.py
│ │ ├── lstm.py
│ │ ├── rope.py
│ │ ├── seanet.py
│ │ ├── streaming.py
│ │ ├── transformer.py
│ │ ├── unet_transformer.py
│ │ └── watermark.py
│ ├── optim/
│ │ ├── __init__.py
│ │ ├── cosine_lr_scheduler.py
│ │ ├── dadam.py
│ │ ├── ema.py
│ │ ├── fsdp.py
│ │ ├── inverse_sqrt_lr_scheduler.py
│ │ ├── linear_warmup_lr_scheduler.py
│ │ └── polynomial_decay_lr_scheduler.py
│ ├── py.typed
│ ├── quantization/
│ │ ├── __init__.py
│ │ ├── base.py
│ │ ├── core_vq.py
│ │ └── vq.py
│ ├── solvers/
│ │ ├── __init__.py
│ │ ├── audiogen.py
│ │ ├── base.py
│ │ ├── builders.py
│ │ ├── compression.py
│ │ ├── diffusion.py
│ │ ├── jasco.py
│ │ ├── magnet.py
│ │ ├── musicgen.py
│ │ └── watermark.py
│ ├── train.py
│ └── utils/
│ ├── __init__.py
│ ├── audio_effects.py
│ ├── autocast.py
│ ├── best_state.py
│ ├── cache.py
│ ├── checkpoint.py
│ ├── cluster.py
│ ├── deadlock.py
│ ├── export.py
│ ├── export_legacy.py
│ ├── notebook.py
│ ├── profiler.py
│ ├── samples/
│ │ ├── __init__.py
│ │ └── manager.py
│ └── utils.py
├── config/
│ ├── augmentations/
│ │ └── default.yaml
│ ├── conditioner/
│ │ ├── chords2music.yaml
│ │ ├── chroma2music.yaml
│ │ ├── clapemb2music.yaml
│ │ ├── drums2music.yaml
│ │ ├── jasco_chords_drums.yaml
│ │ ├── jasco_chords_drums_melody.yaml
│ │ ├── none.yaml
│ │ ├── style2music.yaml
│ │ ├── text2music.yaml
│ │ └── text2sound.yaml
│ ├── config.yaml
│ ├── dset/
│ │ ├── audio/
│ │ │ ├── audiocaps_16khz.yaml
│ │ │ ├── default.yaml
│ │ │ ├── example.yaml
│ │ │ └── musiccaps_32khz.yaml
│ │ ├── default.yaml
│ │ └── internal/
│ │ ├── music_10k_32khz.yaml
│ │ ├── music_400k_32khz.yaml
│ │ └── sounds_16khz.yaml
│ ├── model/
│ │ ├── encodec/
│ │ │ ├── default.yaml
│ │ │ ├── encodec_base_causal.yaml
│ │ │ ├── encodec_large_nq4_s320.yaml
│ │ │ └── encodec_large_nq4_s640.yaml
│ │ ├── lm/
│ │ │ ├── audiogen_lm.yaml
│ │ │ ├── default.yaml
│ │ │ ├── model_scale/
│ │ │ │ ├── base.yaml
│ │ │ │ ├── large.yaml
│ │ │ │ ├── medium.yaml
│ │ │ │ ├── small.yaml
│ │ │ │ └── xsmall.yaml
│ │ │ └── musicgen_lm.yaml
│ │ ├── none.yaml
│ │ ├── score/
│ │ │ └── basic.yaml
│ │ └── watermark/
│ │ └── default.yaml
│ ├── solver/
│ │ ├── audiogen/
│ │ │ ├── audiogen_base_16khz.yaml
│ │ │ ├── debug.yaml
│ │ │ ├── default.yaml
│ │ │ └── evaluation/
│ │ │ ├── none.yaml
│ │ │ └── objective_eval.yaml
│ │ ├── compression/
│ │ │ ├── debug.yaml
│ │ │ ├── default.yaml
│ │ │ ├── encodec_audiogen_16khz.yaml
│ │ │ ├── encodec_base_24khz.yaml
│ │ │ └── encodec_musicgen_32khz.yaml
│ │ ├── default.yaml
│ │ ├── diffusion/
│ │ │ ├── debug.yaml
│ │ │ ├── default.yaml
│ │ │ └── encodec_24khz.yaml
│ │ ├── jasco/
│ │ │ ├── chords.yaml
│ │ │ ├── chords_drums.yaml
│ │ │ ├── chords_drums_melody.yaml
│ │ │ ├── drums.yaml
│ │ │ └── jasco_32khz_base.yaml
│ │ ├── magnet/
│ │ │ ├── audio_magnet_16khz.yaml
│ │ │ └── magnet_32khz.yaml
│ │ ├── musicgen/
│ │ │ ├── debug.yaml
│ │ │ ├── default.yaml
│ │ │ ├── evaluation/
│ │ │ │ ├── none.yaml
│ │ │ │ └── objective_eval.yaml
│ │ │ ├── musicgen_base_32khz.yaml
│ │ │ ├── musicgen_melody_32khz.yaml
│ │ │ └── musicgen_style_32khz.yaml
│ │ └── watermark/
│ │ ├── debug.yaml
│ │ ├── default.yaml
│ │ └── robustness.yaml
│ └── teams/
│ ├── default.yaml
│ └── labs.yaml
├── dataset/
│ └── example/
│ ├── electro_1.json
│ └── electro_2.json
├── demos/
│ ├── audiogen_demo.ipynb
│ ├── jasco_app.py
│ ├── jasco_demo.ipynb
│ ├── magnet_app.py
│ ├── magnet_demo.ipynb
│ ├── musicgen_app.py
│ ├── musicgen_demo.ipynb
│ ├── musicgen_style_app.py
│ └── musicgen_style_demo.ipynb
├── docs/
│ ├── AUDIOGEN.md
│ ├── CONDITIONING.md
│ ├── DATASETS.md
│ ├── ENCODEC.md
│ ├── JASCO.md
│ ├── MAGNET.md
│ ├── MBD.md
│ ├── METRICS.md
│ ├── MUSICGEN.md
│ ├── MUSICGEN_STYLE.md
│ ├── TRAINING.md
│ └── WATERMARKING.md
├── egs/
│ └── example/
│ └── data.jsonl
├── jasco_demo.ipynb
├── model_cards/
│ ├── AUDIOGEN_MODEL_CARD.md
│ ├── JASCO_MODEL_CARD.md
│ ├── MAGNET_MODEL_CARD.md
│ ├── MUSICGEN_MODEL_CARD.md
│ └── MUSICGEN_STYLE_MODEL_CARD.md
├── mypy.ini
├── requirements.txt
├── scripts/
│ ├── __init__.py
│ ├── chords/
│ │ ├── build_chord_maps.py
│ │ ├── extract_chords.py
│ │ └── job_array_example.sh
│ ├── mos.py
│ ├── resample_dataset.py
│ ├── static/
│ │ └── style.css
│ └── templates/
│ ├── base.html
│ ├── index.html
│ ├── login.html
│ ├── results.html
│ └── survey.html
├── setup.cfg
├── setup.py
└── tests/
├── __init__.py
├── adversarial/
│ ├── __init__.py
│ ├── test_discriminators.py
│ └── test_losses.py
├── common_utils/
│ ├── __init__.py
│ ├── temp_utils.py
│ └── wav_utils.py
├── data/
│ ├── __init__.py
│ ├── test_audio.py
│ ├── test_audio_dataset.py
│ └── test_audio_utils.py
├── losses/
│ ├── __init__.py
│ └── test_losses.py
├── metrics/
│ ├── __init__.py
│ └── test_pesq.py
├── models/
│ ├── test_audiogen.py
│ ├── test_encodec_model.py
│ ├── test_multibanddiffusion.py
│ ├── test_musicgen.py
│ └── test_watermark.py
├── modules/
│ ├── __init__.py
│ ├── test_activations.py
│ ├── test_codebooks_patterns.py
│ ├── test_conv.py
│ ├── test_lstm.py
│ ├── test_rope.py
│ ├── test_seanet.py
│ └── test_transformer.py
├── quantization/
│ └── test_vq.py
└── utils/
├── __init__.py
└── test_audio_effects.py
================================================
FILE CONTENTS
================================================
================================================
FILE: .github/actions/audiocraft_build/action.yml
================================================
# Composite action reused by the audiocraft workflows: sets up Python, restores (or
# builds) a virtualenv in `env`, and installs the system packages the project needs.
name: audiocraft_build
description: 'Build audiocraft env.'
runs:
using: "composite"
steps:
# NOTE(review): python-version is an unquoted YAML number; this is fine for 3.9, but a
# value such as 3.10 would be parsed as 3.1 unless quoted -- confirm before bumping.
- uses: actions/setup-python@v2
with:
python-version: 3.9
# Restore the `env` directory from the cache. The key is derived from the hash of every
# requirements.txt in the repository, so any change to one of them invalidates the cache.
- uses: actions/cache@v3
id: cache
with:
path: env
key: audiocraft_env-${{ hashFiles('**/requirements.txt') }}
# Only on a cache miss: install system packages, create the virtualenv in `env`, and
# install the pinned numpy/torch/torchvision/torchaudio/xformers versions followed by
# the package itself in editable mode with the `dev` and `wm` extras.
- if: ${{ steps.cache.outputs.cache-hit != 'true' }}
name: Install dependencies
shell: bash
run: |
sudo apt-get update
sudo apt-get install libsndfile1-dev ffmpeg
python3 -m venv env
. env/bin/activate
python -m pip install --upgrade pip
pip install 'numpy<2' torch==2.1.0 torchvision==0.16.0 torchaudio==2.1.0
pip install xformers==0.0.22.post7
pip install -e '.[dev,wm]'
# This step has no `if` condition, so it always runs: the system packages are installed
# even when `env` was restored from the cache and the step above was skipped.
- name: System Dependencies
shell: bash
run: |
sudo apt-get update
sudo apt-get install libsndfile1-dev ffmpeg
================================================
FILE: .github/workflows/audiocraft_docs.yml
================================================
# Workflow that rebuilds the API docs and publishes them on the `gh-docs` branch.
name: audiocraft_docs
# Triggered only by pushes to `main`.
on:
push:
branches: [ main ]
jobs:
run_docs:
name: Run docs
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v2
# Local composite action that prepares the `env` virtualenv.
- uses: ./.github/actions/audiocraft_build
# Identity used for the automated docs commit below.
- name: Config git
run: |
git config --global user.email "defossez@fb.com"
git config --global user.name "Alexandre Défossez (autodoc)"
# Force `gh-docs` to point at `main` and switch to it.
- name: Reset branch
run: |
git branch -f gh-docs main
git checkout gh-docs
# Build the docs with `make api_docs` and commit the output; `git add -f` is required
# because `/api_docs` is listed in .gitignore.
- name: Make docs
run: |
. env/bin/activate
make api_docs
git add -f api_docs
git commit -m api_docs
# Force-push, overwriting the remote `gh-docs` branch.
- name: Push branch
run: |
git push -f -u origin gh-docs
================================================
FILE: .github/workflows/audiocraft_linter.yml
================================================
# Workflow that runs the project linter (`make linter`).
name: audiocraft_linter
# Triggered by pushes to `main` and by pull requests targeting `main` or
# `audiocraft_pub_main`.
on:
push:
branches: [ main ]
pull_request:
branches: [ main, audiocraft_pub_main ]
jobs:
run_linter:
name: Run linter
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v2
# Local composite action that prepares the `env` virtualenv activated below.
- uses: ./.github/actions/audiocraft_build
- run: |
. env/bin/activate
make linter
================================================
FILE: .github/workflows/audiocraft_tests.yml
================================================
# Workflow that runs the unit tests (`make tests`) and then the integration tests
# (`make tests_integ`).
name: audiocraft_tests
# Triggered by pushes to `main` and by pull requests targeting `main` or
# `audiocraft_pub_main`.
on:
push:
branches: [ main ]
pull_request:
branches: [ main, audiocraft_pub_main ]
jobs:
run_tests:
name: Run tests
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v2
# Local composite action that prepares the `env` virtualenv activated by each step below.
- uses: ./.github/actions/audiocraft_build
- name: Run unit tests
run: |
. env/bin/activate
make tests
- name: Run integration tests
run: |
. env/bin/activate
make tests_integ
================================================
FILE: .gitignore
================================================
# Byte-compiled / optimized / DLL files
__pycache__
*.py[cod]
*$py.class
# C extensions
*.so
# macOS dir files
.DS_Store
# Distribution / packaging
.Python
env/
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
*.egg-info/
.installed.cfg
*.egg
.ipynb_checkpoints
# Tests and linter
.pytest_cache/
.mypy_cache/
.coverage
# docs
/api_docs
# dotenv
.env
.envrc
# virtualenv
.venv
venv/
ENV/
# egs with manifest files: ignore everything under egs/ except the example
# directory, which is re-included by the negated pattern
egs/*
!egs/example
# local datasets: ignore everything under dataset/ except the example
# directory, which is re-included by the negated pattern
dataset/*
!dataset/example
# personal notebooks & scripts
*/local_scripts
*/notes
.vscode/
/notebooks
/local_scripts
/notes
================================================
FILE: CHANGELOG.md
================================================
# Changelog
All notable changes to this project will be documented in this file.
The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/).
## [1.4.0a2] - 2025-01-14
Add training and inference code for JASCO (https://arxiv.org/abs/2406.10970) along with the [hf checkpoints](https://huggingface.co/facebook/jasco-chords-drums-melody-1B).
## [1.4.0a1] - 2024-06-03
Adding new metric PesqMetric ([Perceptual Evaluation of Speech Quality](https://doi.org/10.5281/zenodo.6549559))
Adding multiple audio augmentation functions: generating pink noises, up-/downsampling, low-/highpass filtering, bandpass filtering, smoothing, duck masking, boosting. All are wrapped in `audiocraft.utils.audio_effects.AudioEffects` and can be called with the API `audiocraft.utils.audio_effects.select_audio_effects`.
Add training code for AudioSeal (https://arxiv.org/abs/2401.17264) along with the [hf checkpoints](https://huggingface.co/facebook/audioseal).
## [1.3.0] - 2024-05-02
Adding the MAGNeT model (https://arxiv.org/abs/2401.04577) along with hf checkpoints and a gradio demo app.
Typo fixes.
Fixing setup.py to install only audiocraft, not the unit tests and scripts.
Fix FSDP support with PyTorch 2.1.0.
## [1.2.0] - 2024-01-11
Adding stereo models.
Fixed the commitment loss, which was until now only applied to the first RVQ layer.
Removed compression model state from the LM checkpoints; for consistency, it
should always be loaded from the original `compression_model_checkpoint`.
## [1.1.0] - 2023-11-06
Not using torchaudio anymore when writing audio files, relying instead directly on the commandline ffmpeg. Also not using it anymore for reading audio files, for similar reasons.
Fixed DAC support with non default number of codebooks.
Fixed bug when `two_step_cfg` was overridden when calling `generate()`.
Fixed samples being always prompted with audio, rather than having both prompted and unprompted.
**Backward incompatible change:** A `torch.no_grad` around the computation of the conditioning made its way in the public release.
The released models were trained without it. It affects the linear layers applied to the output of the T5 or melody conditioners.
We removed it, so you might need to retrain models.
**Backward incompatible change:** Fixing wrong sample rate in CLAP (WARNING if you trained model with CLAP before).
**Backward incompatible change:** Renamed VALLEPattern to CoarseFirstPattern, as it was wrongly named. Probably no one
retrained a model with this pattern, so hopefully this won't impact you!
## [1.0.0] - 2023-09-07
Major revision, added training code for EnCodec, AudioGen, MusicGen, and MultiBandDiffusion.
Added pretrained model for AudioGen and MultiBandDiffusion.
## [0.0.2] - 2023-08-01
Improved demo, fixed top p (thanks @jnordberg).
Compressor tanh on output to avoid clipping with some style (especially piano).
Now repeating the conditioning periodically if it is too short.
More options when launching Gradio app locally (thanks @ashleykleynhans).
Testing out PyTorch 2.0 memory efficient attention.
Added extended generation (infinite length) by slowly moving the windows.
Note that other implementations exist: https://github.com/camenduru/MusicGen-colab.
## [0.0.1] - 2023-06-09
Initial release, with model evaluation only.
================================================
FILE: CODE_OF_CONDUCT.md
================================================
# Code of Conduct
## Our Pledge
In the interest of fostering an open and welcoming environment, we as
contributors and maintainers pledge to make participation in our project and
our community a harassment-free experience for everyone, regardless of age, body
size, disability, ethnicity, sex characteristics, gender identity and expression,
level of experience, education, socio-economic status, nationality, personal
appearance, race, religion, or sexual identity and orientation.
## Our Standards
Examples of behavior that contributes to creating a positive environment
include:
* Using welcoming and inclusive language
* Being respectful of differing viewpoints and experiences
* Gracefully accepting constructive criticism
* Focusing on what is best for the community
* Showing empathy towards other community members
Examples of unacceptable behavior by participants include:
* The use of sexualized language or imagery and unwelcome sexual attention or
advances
* Trolling, insulting/derogatory comments, and personal or political attacks
* Public or private harassment
* Publishing others' private information, such as a physical or electronic
address, without explicit permission
* Other conduct which could reasonably be considered inappropriate in a
professional setting
## Our Responsibilities
Project maintainers are responsible for clarifying the standards of acceptable
behavior and are expected to take appropriate and fair corrective action in
response to any instances of unacceptable behavior.
Project maintainers have the right and responsibility to remove, edit, or
reject comments, commits, code, wiki edits, issues, and other contributions
that are not aligned to this Code of Conduct, or to ban temporarily or
permanently any contributor for other behaviors that they deem inappropriate,
threatening, offensive, or harmful.
## Scope
This Code of Conduct applies within all project spaces, and it also applies when
an individual is representing the project or its community in public spaces.
Examples of representing a project or community include using an official
project e-mail address, posting via an official social media account, or acting
as an appointed representative at an online or offline event. Representation of
a project may be further defined and clarified by project maintainers.
This Code of Conduct also applies outside the project spaces when there is a
reasonable belief that an individual's behavior may have a negative impact on
the project or its community.
## Enforcement
Instances of abusive, harassing, or otherwise unacceptable behavior may be
reported by contacting the project team at <opensource-conduct@fb.com>. All
complaints will be reviewed and investigated and will result in a response that
is deemed necessary and appropriate to the circumstances. The project team is
obligated to maintain confidentiality with regard to the reporter of an incident.
Further details of specific enforcement policies may be posted separately.
Project maintainers who do not follow or enforce the Code of Conduct in good
faith may face temporary or permanent repercussions as determined by other
members of the project's leadership.
## Attribution
This Code of Conduct is adapted from the [Contributor Covenant][homepage], version 1.4,
available at https://www.contributor-covenant.org/version/1/4/code-of-conduct.html
[homepage]: https://www.contributor-covenant.org
For answers to common questions about this code of conduct, see
https://www.contributor-covenant.org/faq
================================================
FILE: CONTRIBUTING.md
================================================
# Contributing to AudioCraft
We want to make contributing to this project as easy and transparent as
possible.
## Pull Requests
AudioCraft is the implementation of a research paper.
Therefore, we do not plan on accepting many pull requests for new features.
We certainly welcome them for bug fixes.
1. Fork the repo and create your branch from `main`.
2. If you've added code that should be tested, add tests.
3. If you've changed APIs, update the documentation.
4. Ensure the test suite passes.
5. Make sure your code lints.
6. If you haven't already, complete the Contributor License Agreement ("CLA").
## Contributor License Agreement ("CLA")
In order to accept your pull request, we need you to submit a CLA. You only need
to do this once to work on any of Meta's open source projects.
Complete your CLA here: <https://code.facebook.com/cla>
## Issues
We use GitHub issues to track public bugs. Please ensure your description is
clear and has sufficient instructions to be able to reproduce the issue.
Meta has a [bounty program](https://www.facebook.com/whitehat/) for the safe
disclosure of security bugs. In those cases, please go through the process
outlined on that page and do not file a public issue.
## License
By contributing to AudioCraft, you agree that your contributions will be licensed
under the LICENSE file in the root directory of this source tree.
================================================
FILE: LICENSE
================================================
MIT License
Copyright (c) Meta Platforms, Inc. and affiliates.
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
================================================
FILE: LICENSE_weights
================================================
Attribution-NonCommercial 4.0 International
=======================================================================
Creative Commons Corporation ("Creative Commons") is not a law firm and
does not provide legal services or legal advice. Distribution of
Creative Commons public licenses does not create a lawyer-client or
other relationship. Creative Commons makes its licenses and related
information available on an "as-is" basis. Creative Commons gives no
warranties regarding its licenses, any material licensed under their
terms and conditions, or any related information. Creative Commons
disclaims all liability for damages resulting from their use to the
fullest extent possible.
Using Creative Commons Public Licenses
Creative Commons public licenses provide a standard set of terms and
conditions that creators and other rights holders may use to share
original works of authorship and other material subject to copyright
and certain other rights specified in the public license below. The
following considerations are for informational purposes only, are not
exhaustive, and do not form part of our licenses.
Considerations for licensors: Our public licenses are
intended for use by those authorized to give the public
permission to use material in ways otherwise restricted by
copyright and certain other rights. Our licenses are
irrevocable. Licensors should read and understand the terms
and conditions of the license they choose before applying it.
Licensors should also secure all rights necessary before
applying our licenses so that the public can reuse the
material as expected. Licensors should clearly mark any
material not subject to the license. This includes other CC-
licensed material, or material used under an exception or
limitation to copyright. More considerations for licensors:
wiki.creativecommons.org/Considerations_for_licensors
Considerations for the public: By using one of our public
licenses, a licensor grants the public permission to use the
licensed material under specified terms and conditions. If
the licensor's permission is not necessary for any reason--for
example, because of any applicable exception or limitation to
copyright--then that use is not regulated by the license. Our
licenses grant only permissions under copyright and certain
other rights that a licensor has authority to grant. Use of
the licensed material may still be restricted for other
reasons, including because others have copyright or other
rights in the material. A licensor may make special requests,
such as asking that all changes be marked or described.
Although not required by our licenses, you are encouraged to
respect those requests where reasonable. More_considerations
for the public:
wiki.creativecommons.org/Considerations_for_licensees
=======================================================================
Creative Commons Attribution-NonCommercial 4.0 International Public
License
By exercising the Licensed Rights (defined below), You accept and agree
to be bound by the terms and conditions of this Creative Commons
Attribution-NonCommercial 4.0 International Public License ("Public
License"). To the extent this Public License may be interpreted as a
contract, You are granted the Licensed Rights in consideration of Your
acceptance of these terms and conditions, and the Licensor grants You
such rights in consideration of benefits the Licensor receives from
making the Licensed Material available under these terms and
conditions.
Section 1 -- Definitions.
a. Adapted Material means material subject to Copyright and Similar
Rights that is derived from or based upon the Licensed Material
and in which the Licensed Material is translated, altered,
arranged, transformed, or otherwise modified in a manner requiring
permission under the Copyright and Similar Rights held by the
Licensor. For purposes of this Public License, where the Licensed
Material is a musical work, performance, or sound recording,
Adapted Material is always produced where the Licensed Material is
synched in timed relation with a moving image.
b. Adapter's License means the license You apply to Your Copyright
and Similar Rights in Your contributions to Adapted Material in
accordance with the terms and conditions of this Public License.
c. Copyright and Similar Rights means copyright and/or similar rights
closely related to copyright including, without limitation,
performance, broadcast, sound recording, and Sui Generis Database
Rights, without regard to how the rights are labeled or
categorized. For purposes of this Public License, the rights
specified in Section 2(b)(1)-(2) are not Copyright and Similar
Rights.
d. Effective Technological Measures means those measures that, in the
absence of proper authority, may not be circumvented under laws
fulfilling obligations under Article 11 of the WIPO Copyright
Treaty adopted on December 20, 1996, and/or similar international
agreements.
e. Exceptions and Limitations means fair use, fair dealing, and/or
any other exception or limitation to Copyright and Similar Rights
that applies to Your use of the Licensed Material.
f. Licensed Material means the artistic or literary work, database,
or other material to which the Licensor applied this Public
License.
g. Licensed Rights means the rights granted to You subject to the
terms and conditions of this Public License, which are limited to
all Copyright and Similar Rights that apply to Your use of the
Licensed Material and that the Licensor has authority to license.
h. Licensor means the individual(s) or entity(ies) granting rights
under this Public License.
i. NonCommercial means not primarily intended for or directed towards
commercial advantage or monetary compensation. For purposes of
this Public License, the exchange of the Licensed Material for
other material subject to Copyright and Similar Rights by digital
file-sharing or similar means is NonCommercial provided there is
no payment of monetary compensation in connection with the
exchange.
j. Share means to provide material to the public by any means or
process that requires permission under the Licensed Rights, such
as reproduction, public display, public performance, distribution,
dissemination, communication, or importation, and to make material
available to the public including in ways that members of the
public may access the material from a place and at a time
individually chosen by them.
k. Sui Generis Database Rights means rights other than copyright
resulting from Directive 96/9/EC of the European Parliament and of
the Council of 11 March 1996 on the legal protection of databases,
as amended and/or succeeded, as well as other essentially
equivalent rights anywhere in the world.
l. You means the individual or entity exercising the Licensed Rights
under this Public License. Your has a corresponding meaning.
Section 2 -- Scope.
a. License grant.
1. Subject to the terms and conditions of this Public License,
the Licensor hereby grants You a worldwide, royalty-free,
non-sublicensable, non-exclusive, irrevocable license to
exercise the Licensed Rights in the Licensed Material to:
a. reproduce and Share the Licensed Material, in whole or
in part, for NonCommercial purposes only; and
b. produce, reproduce, and Share Adapted Material for
NonCommercial purposes only.
2. Exceptions and Limitations. For the avoidance of doubt, where
Exceptions and Limitations apply to Your use, this Public
License does not apply, and You do not need to comply with
its terms and conditions.
3. Term. The term of this Public License is specified in Section
6(a).
4. Media and formats; technical modifications allowed. The
Licensor authorizes You to exercise the Licensed Rights in
all media and formats whether now known or hereafter created,
and to make technical modifications necessary to do so. The
Licensor waives and/or agrees not to assert any right or
authority to forbid You from making technical modifications
necessary to exercise the Licensed Rights, including
technical modifications necessary to circumvent Effective
Technological Measures. For purposes of this Public License,
simply making modifications authorized by this Section 2(a)
(4) never produces Adapted Material.
5. Downstream recipients.
a. Offer from the Licensor -- Licensed Material. Every
recipient of the Licensed Material automatically
receives an offer from the Licensor to exercise the
Licensed Rights under the terms and conditions of this
Public License.
b. No downstream restrictions. You may not offer or impose
any additional or different terms or conditions on, or
apply any Effective Technological Measures to, the
Licensed Material if doing so restricts exercise of the
Licensed Rights by any recipient of the Licensed
Material.
6. No endorsement. Nothing in this Public License constitutes or
may be construed as permission to assert or imply that You
are, or that Your use of the Licensed Material is, connected
with, or sponsored, endorsed, or granted official status by,
the Licensor or others designated to receive attribution as
provided in Section 3(a)(1)(A)(i).
b. Other rights.
1. Moral rights, such as the right of integrity, are not
licensed under this Public License, nor are publicity,
privacy, and/or other similar personality rights; however, to
the extent possible, the Licensor waives and/or agrees not to
assert any such rights held by the Licensor to the limited
extent necessary to allow You to exercise the Licensed
Rights, but not otherwise.
2. Patent and trademark rights are not licensed under this
Public License.
3. To the extent possible, the Licensor waives any right to
collect royalties from You for the exercise of the Licensed
Rights, whether directly or through a collecting society
under any voluntary or waivable statutory or compulsory
licensing scheme. In all other cases the Licensor expressly
reserves any right to collect such royalties, including when
the Licensed Material is used other than for NonCommercial
purposes.
Section 3 -- License Conditions.
Your exercise of the Licensed Rights is expressly made subject to the
following conditions.
a. Attribution.
1. If You Share the Licensed Material (including in modified
form), You must:
a. retain the following if it is supplied by the Licensor
with the Licensed Material:
i. identification of the creator(s) of the Licensed
Material and any others designated to receive
attribution, in any reasonable manner requested by
the Licensor (including by pseudonym if
designated);
ii. a copyright notice;
iii. a notice that refers to this Public License;
iv. a notice that refers to the disclaimer of
warranties;
v. a URI or hyperlink to the Licensed Material to the
extent reasonably practicable;
b. indicate if You modified the Licensed Material and
retain an indication of any previous modifications; and
c. indicate the Licensed Material is licensed under this
Public License, and include the text of, or the URI or
hyperlink to, this Public License.
2. You may satisfy the conditions in Section 3(a)(1) in any
reasonable manner based on the medium, means, and context in
which You Share the Licensed Material. For example, it may be
reasonable to satisfy the conditions by providing a URI or
hyperlink to a resource that includes the required
information.
3. If requested by the Licensor, You must remove any of the
information required by Section 3(a)(1)(A) to the extent
reasonably practicable.
4. If You Share Adapted Material You produce, the Adapter's
License You apply must not prevent recipients of the Adapted
Material from complying with this Public License.
Section 4 -- Sui Generis Database Rights.
Where the Licensed Rights include Sui Generis Database Rights that
apply to Your use of the Licensed Material:
a. for the avoidance of doubt, Section 2(a)(1) grants You the right
to extract, reuse, reproduce, and Share all or a substantial
portion of the contents of the database for NonCommercial purposes
only;
b. if You include all or a substantial portion of the database
contents in a database in which You have Sui Generis Database
Rights, then the database in which You have Sui Generis Database
Rights (but not its individual contents) is Adapted Material; and
c. You must comply with the conditions in Section 3(a) if You Share
all or a substantial portion of the contents of the database.
For the avoidance of doubt, this Section 4 supplements and does not
replace Your obligations under this Public License where the Licensed
Rights include other Copyright and Similar Rights.
Section 5 -- Disclaimer of Warranties and Limitation of Liability.
a. UNLESS OTHERWISE SEPARATELY UNDERTAKEN BY THE LICENSOR, TO THE
EXTENT POSSIBLE, THE LICENSOR OFFERS THE LICENSED MATERIAL AS-IS
AND AS-AVAILABLE, AND MAKES NO REPRESENTATIONS OR WARRANTIES OF
ANY KIND CONCERNING THE LICENSED MATERIAL, WHETHER EXPRESS,
IMPLIED, STATUTORY, OR OTHER. THIS INCLUDES, WITHOUT LIMITATION,
WARRANTIES OF TITLE, MERCHANTABILITY, FITNESS FOR A PARTICULAR
PURPOSE, NON-INFRINGEMENT, ABSENCE OF LATENT OR OTHER DEFECTS,
ACCURACY, OR THE PRESENCE OR ABSENCE OF ERRORS, WHETHER OR NOT
KNOWN OR DISCOVERABLE. WHERE DISCLAIMERS OF WARRANTIES ARE NOT
ALLOWED IN FULL OR IN PART, THIS DISCLAIMER MAY NOT APPLY TO YOU.
b. TO THE EXTENT POSSIBLE, IN NO EVENT WILL THE LICENSOR BE LIABLE
TO YOU ON ANY LEGAL THEORY (INCLUDING, WITHOUT LIMITATION,
NEGLIGENCE) OR OTHERWISE FOR ANY DIRECT, SPECIAL, INDIRECT,
INCIDENTAL, CONSEQUENTIAL, PUNITIVE, EXEMPLARY, OR OTHER LOSSES,
COSTS, EXPENSES, OR DAMAGES ARISING OUT OF THIS PUBLIC LICENSE OR
USE OF THE LICENSED MATERIAL, EVEN IF THE LICENSOR HAS BEEN
ADVISED OF THE POSSIBILITY OF SUCH LOSSES, COSTS, EXPENSES, OR
DAMAGES. WHERE A LIMITATION OF LIABILITY IS NOT ALLOWED IN FULL OR
IN PART, THIS LIMITATION MAY NOT APPLY TO YOU.
c. The disclaimer of warranties and limitation of liability provided
above shall be interpreted in a manner that, to the extent
possible, most closely approximates an absolute disclaimer and
waiver of all liability.
Section 6 -- Term and Termination.
a. This Public License applies for the term of the Copyright and
Similar Rights licensed here. However, if You fail to comply with
this Public License, then Your rights under this Public License
terminate automatically.
b. Where Your right to use the Licensed Material has terminated under
Section 6(a), it reinstates:
1. automatically as of the date the violation is cured, provided
it is cured within 30 days of Your discovery of the
violation; or
2. upon express reinstatement by the Licensor.
For the avoidance of doubt, this Section 6(b) does not affect any
right the Licensor may have to seek remedies for Your violations
of this Public License.
c. For the avoidance of doubt, the Licensor may also offer the
Licensed Material under separate terms or conditions or stop
distributing the Licensed Material at any time; however, doing so
will not terminate this Public License.
d. Sections 1, 5, 6, 7, and 8 survive termination of this Public
License.
Section 7 -- Other Terms and Conditions.
a. The Licensor shall not be bound by any additional or different
terms or conditions communicated by You unless expressly agreed.
b. Any arrangements, understandings, or agreements regarding the
Licensed Material not stated herein are separate from and
independent of the terms and conditions of this Public License.
Section 8 -- Interpretation.
a. For the avoidance of doubt, this Public License does not, and
shall not be interpreted to, reduce, limit, restrict, or impose
conditions on any use of the Licensed Material that could lawfully
be made without permission under this Public License.
b. To the extent possible, if any provision of this Public License is
deemed unenforceable, it shall be automatically reformed to the
minimum extent necessary to make it enforceable. If the provision
cannot be reformed, it shall be severed from this Public License
without affecting the enforceability of the remaining terms and
conditions.
c. No term or condition of this Public License will be waived and no
failure to comply consented to unless expressly agreed to by the
Licensor.
d. Nothing in this Public License constitutes or may be interpreted
as a limitation upon, or waiver of, any privileges and immunities
that apply to the Licensor or You, including from the legal
processes of any jurisdiction or authority.
=======================================================================
Creative Commons is not a party to its public
licenses. Notwithstanding, Creative Commons may elect to apply one of
its public licenses to material it publishes and in those instances
will be considered the “Licensor.” The text of the Creative Commons
public licenses is dedicated to the public domain under the CC0 Public
Domain Dedication. Except for the limited purpose of indicating that
material is shared under a Creative Commons public license or as
otherwise permitted by the Creative Commons policies published at
creativecommons.org/policies, Creative Commons does not authorize the
use of the trademark "Creative Commons" or any other trademark or logo
of Creative Commons without its prior written consent including,
without limitation, in connection with any unauthorized modifications
to any of its public licenses or any other arrangements,
understandings, or agreements concerning use of licensed material. For
the avoidance of doubt, this paragraph does not form part of the
public licenses.
Creative Commons may be contacted at creativecommons.org.
================================================
FILE: MANIFEST.in
================================================
include Makefile
include LICENSE
include LICENSE_weights
include *.md
include *.ini
include requirements.txt
include audiocraft/py.typed
include assets/*.mp3
include datasets/*.mp3
recursive-include config *.yaml
recursive-include demos *.py
recursive-include demos *.ipynb
recursive-include scripts *.py
recursive-include model_cards *.md
recursive-include docs *.md
================================================
FILE: Makefile
================================================
# Base `dora run` command shared by most integration tests: CPU only, a single
# epoch, no dataloader workers and only a handful of samples per stage.
INTEG=AUDIOCRAFT_DORA_DIR="/tmp/magma_$(USER)" python3 -m dora -v run --clear device=cpu dataset.num_workers=0 optim.epochs=1 \
	dataset.train.num_samples=10 dataset.valid.num_samples=10 \
	dataset.evaluate.num_samples=10 dataset.generate.num_samples=2 sample_rate=16000 \
	logging.level=DEBUG

# Per-solver integration commands. The MusicGen and AudioGen runs load the
# compression checkpoint through its dora signature (//sig/5091833e).
INTEG_COMPRESSION = $(INTEG) solver=compression/debug rvq.n_q=2 rvq.bins=48 checkpoint.save_last=true # SIG is 5091833e
INTEG_MUSICGEN = $(INTEG) solver=musicgen/debug dset=audio/example compression_model_checkpoint=//sig/5091833e \
	transformer_lm.n_q=2 transformer_lm.card=48 transformer_lm.dim=16 checkpoint.save_last=false # Using compression model from 5091833e
INTEG_AUDIOGEN = $(INTEG) solver=audiogen/debug dset=audio/example compression_model_checkpoint=//sig/5091833e \
	transformer_lm.n_q=2 transformer_lm.card=48 transformer_lm.dim=16 checkpoint.save_last=false # Using compression model from 5091833e
INTEG_MBD = $(INTEG) solver=diffusion/debug dset=audio/example \
	checkpoint.save_last=false # Using compression model from 616d7b3c

# The watermark run does not reuse $(INTEG): it has its own dora directory and sample counts.
INTEG_WATERMARK = AUDIOCRAFT_DORA_DIR="/tmp/wm_$(USER)" dora run device=cpu dataset.num_workers=0 optim.epochs=1 \
	dataset.train.num_samples=10 dataset.valid.num_samples=10 dataset.evaluate.num_samples=10 dataset.generate.num_samples=10 \
	logging.level=DEBUG solver=watermark/robustness checkpoint.save_last=false dset=audio/example

default: linter tests

install:
	pip install -U pip
	pip install -U -e '.[dev]'

linter:
	flake8 audiocraft && mypy audiocraft
	flake8 tests && mypy tests

tests:
	coverage run -m pytest tests
	coverage report

# The compression run goes first: it produces the checkpoint (sig 5091833e)
# that the MusicGen and AudioGen runs load.
tests_integ:
	$(INTEG_COMPRESSION)
	$(INTEG_MBD)
	$(INTEG_MUSICGEN)
	$(INTEG_AUDIOGEN)
	$(INTEG_WATERMARK)

api_docs:
	pdoc3 --html -o api_docs -f audiocraft

dist:
	python setup.py sdist

# Every target here is a command, not a file to build; declaring all of them
# phony keeps a same-named file or directory from masking the target.
.PHONY: default install linter tests tests_integ api_docs dist
================================================
FILE: README.md
================================================
# AudioCraft



AudioCraft is a PyTorch library for deep learning research on audio generation. AudioCraft contains inference and training code
for two state-of-the-art AI generative models producing high-quality audio: AudioGen and MusicGen.
## Installation
AudioCraft requires Python 3.9 and PyTorch 2.1.0. To install AudioCraft, you can run the following:
```shell
# Best to make sure you have torch installed first, in particular before installing xformers.
# Don't run this if you already have PyTorch installed.
python -m pip install 'torch==2.1.0'
# You might need the following before trying to install the packages
python -m pip install setuptools wheel
# Then proceed to one of the following
python -m pip install -U audiocraft # stable release
python -m pip install -U git+https://git@github.com/facebookresearch/audiocraft#egg=audiocraft # bleeding edge
python -m pip install -e . # or if you cloned the repo locally (mandatory if you want to train).
python -m pip install -e '.[wm]' # if you want to train a watermarking model
```
We also recommend having `ffmpeg` installed, either through your system or Anaconda:
```bash
sudo apt-get install ffmpeg
# Or if you are using Anaconda or Miniconda
conda install "ffmpeg<5" -c conda-forge
```
## Models
At the moment, AudioCraft contains the training code and inference code for:
* [MusicGen](./docs/MUSICGEN.md): A state-of-the-art controllable text-to-music model.
* [AudioGen](./docs/AUDIOGEN.md): A state-of-the-art text-to-sound model.
* [EnCodec](./docs/ENCODEC.md): A state-of-the-art high fidelity neural audio codec.
* [Multi Band Diffusion](./docs/MBD.md): An EnCodec compatible decoder using diffusion.
* [MAGNeT](./docs/MAGNET.md): A state-of-the-art non-autoregressive model for text-to-music and text-to-sound.
* [AudioSeal](./docs/WATERMARKING.md): A state-of-the-art audio watermarking model.
* [MusicGen Style](./docs/MUSICGEN_STYLE.md): A state-of-the-art text-and-style-to-music model.
* [JASCO](./docs/JASCO.md): A high-quality text-to-music model conditioned on chords, melodies and drum tracks.
## Training code
AudioCraft contains PyTorch components for deep learning research in audio and training pipelines for the developed models.
For a general introduction of AudioCraft design principles and instructions to develop your own training pipeline, refer to
the [AudioCraft training documentation](./docs/TRAINING.md).
For reproducing existing work and using the developed training pipelines, refer to the instructions for each specific model
that provides pointers to configuration, example grids and model/task-specific information and FAQ.
## API documentation
We provide some [API documentation](https://facebookresearch.github.io/audiocraft/api_docs/audiocraft/index.html) for AudioCraft.
## FAQ
#### Is the training code available?
Yes! We provide the training code for [EnCodec](./docs/ENCODEC.md), [MusicGen](./docs/MUSICGEN.md), [Multi Band Diffusion](./docs/MBD.md) and [JASCO](./docs/JASCO.md).
#### Where are the models stored?
Hugging Face stores the models in a specific location, which can be overridden by setting the `AUDIOCRAFT_CACHE_DIR` environment variable for the AudioCraft models.
In order to change the cache location of the other Hugging Face models, please check out the [Hugging Face Transformers documentation for the cache setup](https://huggingface.co/docs/transformers/installation#cache-setup).
Finally, if you use a model that relies on Demucs (e.g. `musicgen-melody`) and want to change the download location for Demucs, refer to the [Torch Hub documentation](https://pytorch.org/docs/stable/hub.html#where-are-my-downloaded-models-saved).
## License
* The code in this repository is released under the MIT license as found in the [LICENSE file](LICENSE).
* The model weights in this repository are released under the CC-BY-NC 4.0 license as found in the [LICENSE_weights file](LICENSE_weights).
## Citation
For the general framework of AudioCraft, please cite the following.
```
@inproceedings{copet2023simple,
title={Simple and Controllable Music Generation},
author={Jade Copet and Felix Kreuk and Itai Gat and Tal Remez and David Kant and Gabriel Synnaeve and Yossi Adi and Alexandre Défossez},
booktitle={Thirty-seventh Conference on Neural Information Processing Systems},
year={2023},
}
```
When referring to a specific model, please cite as mentioned in the model specific README, e.g.
[./docs/MUSICGEN.md](./docs/MUSICGEN.md), [./docs/AUDIOGEN.md](./docs/AUDIOGEN.md), etc.
================================================
FILE: audiocraft/__init__.py
================================================
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
"""
AudioCraft is a general framework for training audio generative models.
At the moment we provide the training code for:
- [MusicGen](https://arxiv.org/abs/2306.05284), a state-of-the-art
text-to-music and melody+text autoregressive generative model.
For the solver, see `audiocraft.solvers.musicgen.MusicGenSolver`, and for the model,
`audiocraft.models.musicgen.MusicGen`.
- [AudioGen](https://arxiv.org/abs/2209.15352), a state-of-the-art
text-to-general-audio generative model.
- [EnCodec](https://arxiv.org/abs/2210.13438), efficient and high fidelity
neural audio codec which provides an excellent tokenizer for autoregressive language models.
See `audiocraft.solvers.compression.CompressionSolver`, and `audiocraft.models.encodec.EncodecModel`.
- [MultiBandDiffusion](https://arxiv.org/abs/2308.02560), alternative diffusion-based decoder compatible with EnCodec that
improves the perceived quality and reduces the artifacts coming from adversarial decoders.
- [JASCO](https://arxiv.org/abs/2406.10970) Joint Audio and Symbolic Conditioning for Temporally Controlled
Text-to-Music Generation.
"""
# flake8: noqa
from . import data, modules, models
__version__ = '1.4.0a2'
================================================
FILE: audiocraft/adversarial/__init__.py
================================================
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
"""Adversarial losses and discriminator architectures."""
# flake8: noqa
from .discriminators import (
MultiPeriodDiscriminator,
MultiScaleDiscriminator,
MultiScaleSTFTDiscriminator
)
from .losses import (
AdversarialLoss,
AdvLossType,
get_adv_criterion,
get_fake_criterion,
get_real_criterion,
FeatLossType,
FeatureMatchingLoss
)
================================================
FILE: audiocraft/adversarial/discriminators/__init__.py
================================================
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
# flake8: noqa
from .mpd import MultiPeriodDiscriminator
from .msd import MultiScaleDiscriminator
from .msstftd import MultiScaleSTFTDiscriminator
================================================
FILE: audiocraft/adversarial/discriminators/base.py
================================================
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
from abc import ABC, abstractmethod
import typing as tp
import torch
import torch.nn as nn
# Intermediate activations returned by one sub-discriminator, one tensor per recorded layer.
FeatureMapType = tp.List[torch.Tensor]
# Output tensor of one sub-discriminator.
LogitsType = torch.Tensor
# Pair of parallel lists: (logits of each sub-discriminator, feature maps of each sub-discriminator).
MultiDiscriminatorOutputType = tp.Tuple[tp.List[LogitsType], tp.List[FeatureMapType]]


class MultiDiscriminator(ABC, nn.Module):
    """Abstract base class for discriminators built from several sub-discriminators,
    each acting at a different scale.

    Subclasses must implement `forward` and the `num_discriminators` property.
    """

    def __init__(self):
        super(MultiDiscriminator, self).__init__()

    @property
    @abstractmethod
    def num_discriminators(self) -> int:
        """Number of discriminators.
        """

    @abstractmethod
    def forward(self, x: torch.Tensor) -> MultiDiscriminatorOutputType:
        """Run every sub-discriminator on `x`.

        Returns:
            A tuple `(logits, fmaps)` with one entry per sub-discriminator in each list.
        """
================================================
FILE: audiocraft/adversarial/discriminators/mpd.py
================================================
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import typing as tp
import torch
import torch.nn as nn
import torch.nn.functional as F
from ...modules import NormConv2d
from .base import MultiDiscriminator, MultiDiscriminatorOutputType
def get_padding(kernel_size: int, dilation: int = 1) -> int:
    """Return `dilation * (kernel_size - 1) / 2`, truncated to an int.

    Args:
        kernel_size (int): Size of the convolution kernel.
        dilation (int): Dilation of the convolution.
    Returns:
        int: Padding amount.
    """
    dilated_extent = dilation * (kernel_size - 1)
    half_extent = dilated_extent / 2
    return int(half_extent)
class PeriodDiscriminator(nn.Module):
    """Period sub-discriminator.

    The input of shape `[B, C, T]` is right-padded (reflection) to a multiple of `period`,
    reshaped to `[B, C, T // period, period]` and passed through a stack of 2D convolutions
    whose kernels only span the first of the two trailing dimensions.

    Args:
        period (int): Period between samples of audio.
        in_channels (int): Number of input channels.
        out_channels (int): Number of output channels.
        n_layers (int): Number of convolutional layers.
        kernel_sizes (list of int): Kernel sizes for convolutions: the first entry is used by
            the stacked convolutions, the second one by the final convolution.
        stride (int): Stride for convolutions (the last stacked convolution uses a stride of 1).
        filters (int): Initial number of filters in convolutions.
        filters_scale (int): Multiplier of number of filters as we increase depth.
        max_filters (int): Maximum number of filters.
        norm (str): Normalization method.
        activation (str): Activation function.
        activation_params (dict): Parameters to provide to the activation function.
    """
    def __init__(self, period: int, in_channels: int = 1, out_channels: int = 1,
                 n_layers: int = 5, kernel_sizes: tp.List[int] = [5, 3], stride: int = 3,
                 filters: int = 8, filters_scale: int = 4, max_filters: int = 1024,
                 norm: str = 'weight_norm', activation: str = 'LeakyReLU',
                 activation_params: dict = {'negative_slope': 0.2}):
        super().__init__()
        self.period = period
        self.n_layers = n_layers
        activation_cls = getattr(torch.nn, activation)
        self.activation = activation_cls(**activation_params)
        self.convs = nn.ModuleList()
        last_layer = self.n_layers - 1
        channels = in_channels
        for depth in range(self.n_layers):
            # Width grows geometrically with depth, capped at `max_filters`.
            width = min(filters * (filters_scale ** (depth + 1)), max_filters)
            # Every stacked layer is strided except the last one.
            layer_stride = stride if depth != last_layer else 1
            kernel = kernel_sizes[0]
            layer = NormConv2d(channels, width, kernel_size=(kernel, 1), stride=(layer_stride, 1),
                               padding=((kernel - 1) // 2, 0), norm=norm)
            self.convs.append(layer)
            channels = width
        post_kernel = kernel_sizes[1]
        self.conv_post = NormConv2d(channels, out_channels, kernel_size=(post_kernel, 1), stride=1,
                                    padding=((post_kernel - 1) // 2, 0), norm=norm)

    def forward(self, x: torch.Tensor):
        """Return `(logits, fmap)` where `fmap` lists the activation after every stacked
        convolution followed by the output of the final convolution (which is also `logits`).
        """
        batch, channels, length = x.shape
        remainder = length % self.period
        if remainder != 0:
            # Pad on the right so the time axis can be folded into rows of `period` samples.
            extra = self.period - remainder
            x = F.pad(x, (0, extra), 'reflect')
            length = length + extra
        # 1d to 2d
        x = x.view(batch, channels, length // self.period, self.period)

        feature_maps = []
        for conv in self.convs:
            x = self.activation(conv(x))
            feature_maps.append(x)
        x = self.conv_post(x)
        feature_maps.append(x)
        return x, feature_maps
class MultiPeriodDiscriminator(MultiDiscriminator):
    """Multi-Period (MPD) Discriminator.

    Holds one `PeriodDiscriminator` per entry of `periods`; all of them receive the same input.

    Args:
        in_channels (int): Number of input channels.
        out_channels (int): Number of output channels.
        periods (Sequence[int]): Periods between samples of audio for the sub-discriminators.
        **kwargs: Additional args for `PeriodDiscriminator`
    """
    def __init__(self, in_channels: int = 1, out_channels: int = 1,
                 periods: tp.Sequence[int] = [2, 3, 5, 7, 11], **kwargs):
        super().__init__()
        sub_discriminators = []
        for period in periods:
            sub_discriminators.append(PeriodDiscriminator(period, in_channels, out_channels, **kwargs))
        self.discriminators = nn.ModuleList(sub_discriminators)

    @property
    def num_discriminators(self):
        """Number of period sub-discriminators."""
        return len(self.discriminators)

    def forward(self, x: torch.Tensor) -> MultiDiscriminatorOutputType:
        """Apply each sub-discriminator to `x` and gather the results.

        Returns:
            `(logits, fmaps)`, both lists ordered like `self.discriminators`.
        """
        outputs = [disc(x) for disc in self.discriminators]
        logits = [logit for logit, _ in outputs]
        fmaps = [fmap for _, fmap in outputs]
        return logits, fmaps
================================================
FILE: audiocraft/adversarial/discriminators/msd.py
================================================
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import typing as tp
import numpy as np
import torch
import torch.nn as nn
from ...modules import NormConv1d
from .base import MultiDiscriminator, MultiDiscriminatorOutputType
class ScaleDiscriminator(nn.Module):
    """Waveform sub-discriminator.

    Layout: a padded stride-1 convolution, one strided convolution per entry of
    `downsample_scales`, one more stride-1 convolution, then the final `conv_post` projection.

    Args:
        in_channels (int): Number of input channels.
        out_channels (int): Number of output channels.
        kernel_sizes (Sequence[int]): Kernel sizes for first and last convolutions.
        filters (int): Number of initial filters for convolutions.
        max_filters (int): Maximum number of filters.
        downsample_scales (Sequence[int]): Scale for downsampling implemented as strided convolutions.
        inner_kernel_sizes (Sequence[int] or None): Kernel sizes for inner convolutions.
        groups (Sequence[int] or None): Groups for inner convolutions.
        strides (Sequence[int] or None): Strides for inner convolutions.
        paddings (Sequence[int] or None): Paddings for inner convolutions.
        norm (str): Normalization method.
        activation (str): Activation function.
        activation_params (dict): Parameters to provide to the activation function.
        pad (str): Padding for initial convolution.
        pad_params (dict): Parameters to provide to the padding module.
    """
    def __init__(self, in_channels=1, out_channels=1, kernel_sizes: tp.Sequence[int] = [5, 3],
                 filters: int = 16, max_filters: int = 1024, downsample_scales: tp.Sequence[int] = [4, 4, 4, 4],
                 inner_kernel_sizes: tp.Optional[tp.Sequence[int]] = None, groups: tp.Optional[tp.Sequence[int]] = None,
                 strides: tp.Optional[tp.Sequence[int]] = None, paddings: tp.Optional[tp.Sequence[int]] = None,
                 norm: str = 'weight_norm', activation: str = 'LeakyReLU',
                 activation_params: dict = {'negative_slope': 0.2}, pad: str = 'ReflectionPad1d',
                 pad_params: dict = {}):
        super().__init__()
        assert len(kernel_sizes) == 2
        assert kernel_sizes[0] % 2 == 1
        assert kernel_sizes[1] % 2 == 1
        # Per-layer overrides, when given, must provide one value per downsampling layer.
        for per_layer in (inner_kernel_sizes, groups, strides, paddings):
            assert (per_layer is None or len(per_layer) == len(downsample_scales))

        self.activation = getattr(torch.nn, activation)(**activation_params)
        self.convs = nn.ModuleList()

        # First layer: its kernel size is the product of the two entries of `kernel_sizes`.
        first_kernel = np.prod(kernel_sizes)
        first_pad = getattr(torch.nn, pad)((first_kernel - 1) // 2, **pad_params)
        first_conv = NormConv1d(in_channels, filters, kernel_size=first_kernel, stride=1, norm=norm)
        self.convs.append(nn.Sequential(first_pad, first_conv))

        channels = filters
        for idx, scale in enumerate(downsample_scales):
            width = min(channels * scale, max_filters)
            # Defaults derived from the downsampling scale, used when no override is given.
            fallback_kernel = scale * 10 + 1
            fallback_stride = scale
            fallback_padding = (fallback_kernel - 1) // 2
            fallback_groups = channels // 4
            if inner_kernel_sizes:
                layer_kernel = inner_kernel_sizes[idx]
            else:
                layer_kernel = fallback_kernel
            if strides:
                layer_stride = strides[idx]
            else:
                layer_stride = fallback_stride
            if groups:
                layer_groups = groups[idx]
            else:
                layer_groups = fallback_groups
            if paddings:
                layer_padding = paddings[idx]
            else:
                layer_padding = fallback_padding
            self.convs.append(
                NormConv1d(channels, width, kernel_size=layer_kernel, stride=layer_stride,
                           groups=layer_groups, padding=layer_padding, norm=norm))
            channels = width

        width = min(channels * 2, max_filters)
        self.convs.append(NormConv1d(channels, width, kernel_size=kernel_sizes[0], stride=1,
                                     padding=(kernel_sizes[0] - 1) // 2, norm=norm))
        self.conv_post = NormConv1d(width, out_channels, kernel_size=kernel_sizes[1], stride=1,
                                    padding=(kernel_sizes[1] - 1) // 2, norm=norm)

    def forward(self, x: torch.Tensor):
        """Return `(logits, fmap)` where `fmap` lists the activation after every layer of
        `self.convs` followed by the output of `conv_post` (which is also `logits`).
        """
        feature_maps = []
        for layer in self.convs:
            x = self.activation(layer(x))
            feature_maps.append(x)
        x = self.conv_post(x)
        feature_maps.append(x)
        return x, feature_maps
class MultiScaleDiscriminator(MultiDiscriminator):
    """Multi-Scale (MSD) Discriminator.

    Holds one `ScaleDiscriminator` per entry of `scale_norms`. The first one sees the input
    as given; each following one sees the input average-pooled once more than its predecessor.

    Args:
        in_channels (int): Number of input channels.
        out_channels (int): Number of output channels.
        downsample_factor (int): Downsampling factor between the different scales.
        scale_norms (Sequence[str]): Normalization for each sub-discriminator.
        **kwargs: Additional args for ScaleDiscriminator.
    """
    def __init__(self, in_channels: int = 1, out_channels: int = 1, downsample_factor: int = 2,
                 scale_norms: tp.Sequence[str] = ['weight_norm', 'weight_norm', 'weight_norm'], **kwargs):
        super().__init__()
        self.discriminators = nn.ModuleList([
            ScaleDiscriminator(in_channels, out_channels, norm=norm, **kwargs) for norm in scale_norms
        ])
        # Pooling window is twice the stride, so consecutive windows overlap.
        self.downsample = nn.AvgPool1d(downsample_factor * 2, downsample_factor, padding=downsample_factor)

    @property
    def num_discriminators(self):
        """Number of scale sub-discriminators."""
        return len(self.discriminators)

    def forward(self, x: torch.Tensor) -> MultiDiscriminatorOutputType:
        """Run each sub-discriminator on a progressively downsampled version of `x`.

        Returns:
            `(logits, fmaps)`, both lists ordered like `self.discriminators`.
        """
        logits = []
        fmaps = []
        for i, disc in enumerate(self.discriminators):
            if i != 0:
                # AvgPool1d is not in-place: the result has to be re-assigned, otherwise every
                # sub-discriminator receives the same full-resolution signal.
                x = self.downsample(x)
            logit, fmap = disc(x)
            logits.append(logit)
            fmaps.append(fmap)
        return logits, fmaps
================================================
FILE: audiocraft/adversarial/discriminators/msstftd.py
================================================
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import typing as tp
import torchaudio
import torch
from torch import nn
from einops import rearrange
from ...modules import NormConv2d
from .base import MultiDiscriminator, MultiDiscriminatorOutputType
def get_2d_padding(kernel_size: tp.Tuple[int, int], dilation: tp.Tuple[int, int] = (1, 1)):
    """Return, for each of the two dimensions, `((kernel_size - 1) * dilation) // 2`.

    Args:
        kernel_size (tuple of int): Kernel size along each of the two dimensions.
        dilation (tuple of int): Dilation along each of the two dimensions.
    Returns:
        tuple of int: Padding along each of the two dimensions.
    """
    pad_first = ((kernel_size[0] - 1) * dilation[0]) // 2
    pad_second = ((kernel_size[1] - 1) * dilation[1]) // 2
    return (pad_first, pad_second)
class DiscriminatorSTFT(nn.Module):
    """STFT sub-discriminator.

    Args:
        filters (int): Number of filters in convolutions.
        in_channels (int): Number of input channels.
        out_channels (int): Number of output channels.
        n_fft (int): Size of FFT for each scale.
        hop_length (int): Length of hop between STFT windows for each scale.
        win_length (int): Window size for each scale.
        max_filters (int): Upper bound on the channel count of the inner convolutions.
        filters_scale (int): Multiplier for the channel count; the layer at depth ``d``
            uses ``min(filters_scale ** d * filters, max_filters)`` output channels.
        kernel_size (tuple of int): Inner Conv2d kernel sizes.
        dilations (list of int): Inner Conv2d dilation on the time dimension.
        stride (tuple of int): Inner Conv2d strides.
        normalized (bool): Whether to normalize by magnitude after stft.
        norm (str): Normalization method.
        activation (str): Activation function.
        activation_params (dict): Parameters to provide to the activation function.
    """
    def __init__(self, filters: int, in_channels: int = 1, out_channels: int = 1,
                 n_fft: int = 1024, hop_length: int = 256, win_length: int = 1024, max_filters: int = 1024,
                 filters_scale: int = 1, kernel_size: tp.Tuple[int, int] = (3, 9), dilations: tp.List = [1, 2, 4],
                 stride: tp.Tuple[int, int] = (1, 2), normalized: bool = True, norm: str = 'weight_norm',
                 activation: str = 'LeakyReLU', activation_params: dict = {'negative_slope': 0.2}):
        super().__init__()
        assert len(kernel_size) == 2
        assert len(stride) == 2
        self.filters = filters
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.n_fft = n_fft
        self.hop_length = hop_length
        self.win_length = win_length
        self.normalized = normalized
        self.activation = getattr(torch.nn, activation)(**activation_params)
        self.spec_transform = torchaudio.transforms.Spectrogram(
            n_fft=self.n_fft, hop_length=self.hop_length, win_length=self.win_length, window_fn=torch.hann_window,
            normalized=self.normalized, center=False, pad_mode=None, power=None)
        # Real and imaginary parts are stacked on the channel axis in `forward`.
        spec_channels = 2 * self.in_channels
        self.convs = nn.ModuleList()
        # NOTE(review): this first layer does not forward `norm`, so it relies on the
        # default of NormConv2d. Left untouched to keep parameter names stable - confirm.
        self.convs.append(
            NormConv2d(spec_channels, self.filters, kernel_size=kernel_size, padding=get_2d_padding(kernel_size))
        )
        # The layer above emits exactly `self.filters` channels, so that is what the
        # next layer must consume. The former `min(filters_scale * filters, max_filters)`
        # only matched for `filters_scale == 1` and `filters <= max_filters`.
        in_chs = self.filters
        for i, dilation in enumerate(dilations):
            out_chs = min((filters_scale ** (i + 1)) * self.filters, max_filters)
            self.convs.append(NormConv2d(in_chs, out_chs, kernel_size=kernel_size, stride=stride,
                                         dilation=(dilation, 1), padding=get_2d_padding(kernel_size, (dilation, 1)),
                                         norm=norm))
            in_chs = out_chs
        out_chs = min((filters_scale ** (len(dilations) + 1)) * self.filters, max_filters)
        # The two last layers use a square kernel built from the first kernel dimension.
        self.convs.append(NormConv2d(in_chs, out_chs, kernel_size=(kernel_size[0], kernel_size[0]),
                                     padding=get_2d_padding((kernel_size[0], kernel_size[0])),
                                     norm=norm))
        self.conv_post = NormConv2d(out_chs, self.out_channels,
                                    kernel_size=(kernel_size[0], kernel_size[0]),
                                    padding=get_2d_padding((kernel_size[0], kernel_size[0])),
                                    norm=norm)

    def forward(self, x: torch.Tensor):
        """Compute the logits and the intermediate feature maps for ``x``.

        Args:
            x (torch.Tensor): Input given to the spectrogram transform.
        Returns:
            Tuple ``(logits, fmap)`` where ``fmap`` lists the activated output of
            every layer in ``self.convs``, in order.
        """
        fmap = []
        # With `power=None` the transform yields a complex spectrogram.
        # Assumes layout [B, C, Freq, Frames] - TODO confirm.
        z = self.spec_transform(x)
        z = torch.cat([z.real, z.imag], dim=1)
        z = rearrange(z, 'b c w t -> b c t w')
        for layer in self.convs:
            z = layer(z)
            z = self.activation(z)
            fmap.append(z)
        z = self.conv_post(z)
        return z, fmap
class MultiScaleSTFTDiscriminator(MultiDiscriminator):
    """Multi-Scale STFT (MS-STFT) discriminator.

    Builds one ``DiscriminatorSTFT`` per (n_fft, hop_length, win_length) triplet.

    Args:
        filters (int): Number of filters in convolutions.
        in_channels (int): Number of input channels.
        out_channels (int): Number of output channels.
        sep_channels (bool): Separate channels to distinct samples for stereo support.
            NOTE(review): the flag is stored but not consulted by ``forward`` - confirm intent.
        n_ffts (Sequence[int]): Size of FFT for each scale.
        hop_lengths (Sequence[int]): Length of hop between STFT windows for each scale.
        win_lengths (Sequence[int]): Window size for each scale.
        **kwargs: Additional args for STFTDiscriminator.
    """
    def __init__(self, filters: int, in_channels: int = 1, out_channels: int = 1, sep_channels: bool = False,
                 n_ffts: tp.List[int] = [1024, 2048, 512], hop_lengths: tp.List[int] = [256, 512, 128],
                 win_lengths: tp.List[int] = [1024, 2048, 512], **kwargs):
        super().__init__()
        assert len(n_ffts) == len(hop_lengths) == len(win_lengths)
        self.sep_channels = sep_channels
        per_scale = []
        for scale_idx, n_fft in enumerate(n_ffts):
            per_scale.append(
                DiscriminatorSTFT(filters, in_channels=in_channels, out_channels=out_channels,
                                  n_fft=n_fft, win_length=win_lengths[scale_idx],
                                  hop_length=hop_lengths[scale_idx], **kwargs)
            )
        self.discriminators = nn.ModuleList(per_scale)

    @property
    def num_discriminators(self):
        """Number of STFT sub-discriminators."""
        sub_discriminators = self.discriminators
        return len(sub_discriminators)

    def _separate_channels(self, x: torch.Tensor) -> torch.Tensor:
        """Reshape a 3-dim tensor to ``(-1, 1, T)`` where ``T`` is its last dimension."""
        _, _, n_steps = x.shape
        return x.view(-1, 1, n_steps)

    def forward(self, x: torch.Tensor) -> MultiDiscriminatorOutputType:
        """Apply each sub-discriminator to ``x`` and gather their logits and feature maps."""
        logits = []
        fmaps = []
        for sub_disc in self.discriminators:
            scale_logit, scale_fmap = sub_disc(x)
            logits += [scale_logit]
            fmaps += [scale_fmap]
        return logits, fmaps
================================================
FILE: audiocraft/adversarial/losses.py
================================================
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
"""
Utility module to handle adversarial losses without requiring to mess up the main training loop.
"""
import typing as tp
import flashy
import torch
import torch.nn as nn
import torch.nn.functional as F
# Loss names accepted by `get_adv_criterion`, `get_fake_criterion` and `get_real_criterion`.
ADVERSARIAL_LOSSES = ['mse', 'hinge', 'hinge2']
# A loss taking a single tensor (the logits) and returning a tensor.
AdvLossType = tp.Union[nn.Module, tp.Callable[[torch.Tensor], torch.Tensor]]
# A loss taking two tensors and returning a tensor.
FeatLossType = tp.Union[nn.Module, tp.Callable[[torch.Tensor, torch.Tensor], torch.Tensor]]
class AdversarialLoss(nn.Module):
    """Adversary training wrapper.

    Args:
        adversary (nn.Module): The adversary module will be used to estimate the logits given the fake and real samples.
            We assume here the adversary output is ``Tuple[List[torch.Tensor], List[List[torch.Tensor]]]``
            where the first item is a list of logits and the second item is a list of feature maps.
        optimizer (torch.optim.Optimizer): Optimizer used for training the given module.
        loss (AdvLossType): Loss function for generator training.
        loss_real (AdvLossType): Loss function for adversarial training on logits from real samples.
        loss_fake (AdvLossType): Loss function for adversarial training on logits from fake samples.
        loss_feat (FeatLossType): Feature matching loss function for generator training.
        normalize (bool): Whether to normalize by number of sub-discriminators.

    Example of usage:
        adv_loss = AdversarialLoss(adversaries, optimizer, loss, loss_real, loss_fake)
        for real in loader:
            noise = torch.randn(...)
            fake = model(noise)
            adv_loss.train_adv(fake, real)
            loss, _ = adv_loss(fake, real)
            loss.backward()
    """
    def __init__(self,
                 adversary: nn.Module,
                 optimizer: torch.optim.Optimizer,
                 loss: AdvLossType,
                 loss_real: AdvLossType,
                 loss_fake: AdvLossType,
                 loss_feat: tp.Optional[FeatLossType] = None,
                 normalize: bool = True):
        super().__init__()
        self.adversary: nn.Module = adversary
        flashy.distrib.broadcast_model(self.adversary)
        self.optimizer = optimizer
        self.loss = loss
        self.loss_real = loss_real
        self.loss_fake = loss_fake
        self.loss_feat = loss_feat
        self.normalize = normalize

    def _save_to_state_dict(self, destination, prefix, keep_vars):
        # Add the optimizer state dict inside our own.
        super()._save_to_state_dict(destination, prefix, keep_vars)
        destination[prefix + 'optimizer'] = self.optimizer.state_dict()
        return destination

    def _load_from_state_dict(self, state_dict, prefix, *args, **kwargs):
        # Load optimizer state. The entry is popped so that the parent implementation
        # does not see a key it knows nothing about.
        self.optimizer.load_state_dict(state_dict.pop(prefix + 'optimizer'))
        super()._load_from_state_dict(state_dict, prefix, *args, **kwargs)

    def get_adversary_pred(self, x):
        """Run adversary model, validating expected output format."""
        logits, fmaps = self.adversary(x)
        assert isinstance(logits, list) and all(isinstance(t, torch.Tensor) for t in logits), \
            f'Expecting a list of tensors as logits but {type(logits)} found.'
        assert isinstance(fmaps, list), f'Expecting a list of features maps but {type(fmaps)} found.'
        for fmap in fmaps:
            assert isinstance(fmap, list) and all(isinstance(f, torch.Tensor) for f in fmap), \
                f'Expecting a list of tensors as feature maps but {type(fmap)} found.'
        return logits, fmaps

    def train_adv(self, fake: torch.Tensor, real: torch.Tensor) -> torch.Tensor:
        """Train the adversary with the given fake and real example.

        We assume the adversary output is the following format: Tuple[List[torch.Tensor], List[List[torch.Tensor]]].
        The first item being the logits and second item being a list of feature maps for each sub-discriminator.

        This will automatically synchronize gradients (with `flashy.distrib.eager_sync_model`)
        and call the optimizer.
        """
        loss = torch.tensor(0., device=fake.device)
        # Inputs are detached: this step only updates the adversary.
        all_logits_fake_is_fake, _ = self.get_adversary_pred(fake.detach())
        all_logits_real_is_fake, _ = self.get_adversary_pred(real.detach())
        n_sub_adversaries = len(all_logits_fake_is_fake)
        for logit_fake_is_fake, logit_real_is_fake in zip(all_logits_fake_is_fake, all_logits_real_is_fake):
            loss += self.loss_fake(logit_fake_is_fake) + self.loss_real(logit_real_is_fake)

        if self.normalize:
            loss /= n_sub_adversaries

        self.optimizer.zero_grad()
        with flashy.distrib.eager_sync_model(self.adversary):
            loss.backward()
        self.optimizer.step()

        return loss

    def forward(self, fake: torch.Tensor, real: torch.Tensor) -> tp.Tuple[torch.Tensor, torch.Tensor]:
        """Return the loss for the generator, i.e. trying to fool the adversary,
        and feature matching loss if provided.
        """
        adv = torch.tensor(0., device=fake.device)
        feat = torch.tensor(0., device=fake.device)
        with flashy.utils.readonly(self.adversary):
            all_logits_fake_is_fake, all_fmap_fake = self.get_adversary_pred(fake)
            # Only the feature maps of the real samples are used below.
            _, all_fmap_real = self.get_adversary_pred(real)
            n_sub_adversaries = len(all_logits_fake_is_fake)
            for logit_fake_is_fake in all_logits_fake_is_fake:
                adv += self.loss(logit_fake_is_fake)
            # Explicit None check: truthiness of a module or callable is not a reliable
            # "was it provided" test (e.g. a container module with `__len__() == 0`).
            if self.loss_feat is not None:
                for fmap_fake, fmap_real in zip(all_fmap_fake, all_fmap_real):
                    feat += self.loss_feat(fmap_fake, fmap_real)

        if self.normalize:
            adv /= n_sub_adversaries
            feat /= n_sub_adversaries

        return adv, feat
def get_adv_criterion(loss_type: str) -> tp.Callable:
    """Return the generator-side adversarial loss function registered under `loss_type`."""
    assert loss_type in ADVERSARIAL_LOSSES
    # Built at call time: the loss functions are defined further down in the module.
    criteria = {
        'mse': mse_loss,
        'hinge': hinge_loss,
        'hinge2': hinge2_loss,
    }
    if loss_type not in criteria:
        raise ValueError('Unsupported loss')
    return criteria[loss_type]
def get_fake_criterion(loss_type: str) -> tp.Callable:
    """Return the adversary-side loss function applied to logits of fake samples."""
    assert loss_type in ADVERSARIAL_LOSSES
    # Both hinge variants share the same loss on fake samples.
    criteria = {
        'mse': mse_fake_loss,
        'hinge': hinge_fake_loss,
        'hinge2': hinge_fake_loss,
    }
    if loss_type not in criteria:
        raise ValueError('Unsupported loss')
    return criteria[loss_type]
def get_real_criterion(loss_type: str) -> tp.Callable:
    """Return the adversary-side loss function applied to logits of real samples."""
    assert loss_type in ADVERSARIAL_LOSSES
    # Both hinge variants share the same loss on real samples.
    criteria = {
        'mse': mse_real_loss,
        'hinge': hinge_real_loss,
        'hinge2': hinge_real_loss,
    }
    if loss_type not in criteria:
        raise ValueError('Unsupported loss')
    return criteria[loss_type]
def mse_real_loss(x: torch.Tensor) -> torch.Tensor:
    """Mean squared error between `x` and a target of ones."""
    one = torch.tensor(1., device=x.device)
    target = one.expand_as(x)
    return F.mse_loss(x, target)
def mse_fake_loss(x: torch.Tensor) -> torch.Tensor:
    """Mean squared error between `x` and a target of zeros."""
    zero = torch.tensor(0., device=x.device)
    target = zero.expand_as(x)
    return F.mse_loss(x, target)
def hinge_real_loss(x: torch.Tensor) -> torch.Tensor:
    """Hinge loss on logits of real samples: ``-mean(min(x - 1, 0))``."""
    floor = torch.tensor(0., device=x.device).expand_as(x)
    clipped = torch.minimum(x - 1, floor)
    return -torch.mean(clipped)
def hinge_fake_loss(x: torch.Tensor) -> torch.Tensor:
    """Hinge loss on logits of fake samples: ``-mean(min(-x - 1, 0))``."""
    floor = torch.tensor(0., device=x.device).expand_as(x)
    clipped = torch.minimum(-x - 1, floor)
    return -torch.mean(clipped)
def mse_loss(x: torch.Tensor) -> torch.Tensor:
    """Generator MSE loss: squared distance of `x` to ones, or a zero tensor for empty input."""
    if x.numel() > 0:
        target = torch.tensor(1., device=x.device).expand_as(x)
        return F.mse_loss(x, target)
    return torch.tensor([0.0], device=x.device)
def hinge_loss(x: torch.Tensor) -> torch.Tensor:
    """Generator hinge loss: negated mean of `x`, or a zero tensor for empty input."""
    if x.numel() > 0:
        return -x.mean()
    return torch.tensor([0.0], device=x.device)
def hinge2_loss(x: torch.Tensor) -> torch.Tensor:
    """Generator loss for 'hinge2': ``-mean(min(x - 1, 0))``.

    Args:
        x (torch.Tensor): Logits produced by the adversary.
    Returns:
        torch.Tensor: The loss, or a one-element zero tensor when `x` is empty.
    """
    if x.numel() == 0:
        # Create the placeholder on the device of `x`, as `mse_loss` and `hinge_loss`
        # do. It used to be created on the default device.
        return torch.tensor([0.0], device=x.device)
    return -torch.mean(torch.min(x - 1, torch.tensor(0., device=x.device).expand_as(x)))
class FeatureMatchingLoss(nn.Module):
    """Feature matching loss for adversarial training.

    Args:
        loss (nn.Module): Loss to use for feature matching (default=torch.nn.L1Loss).
        normalize (bool): Whether to normalize the loss by the number of feature maps.
    """
    def __init__(self, loss: nn.Module = torch.nn.L1Loss(), normalize: bool = True):
        super().__init__()
        self.loss = loss
        self.normalize = normalize

    def forward(self, fmap_fake: tp.List[torch.Tensor], fmap_real: tp.List[torch.Tensor]) -> torch.Tensor:
        """Sum `self.loss` over paired feature maps, optionally averaged over the pairs.

        Args:
            fmap_fake (list of torch.Tensor): Feature maps obtained from fake samples.
            fmap_real (list of torch.Tensor): Feature maps obtained from real samples;
                must have the same length and per-item shapes as `fmap_fake`.
        Returns:
            torch.Tensor: Scalar loss on the device of the first fake feature map.
        """
        assert len(fmap_fake) == len(fmap_real) and len(fmap_fake) > 0
        feat_loss = torch.tensor(0., device=fmap_fake[0].device)
        n_fmaps = 0
        # The former `feat_scale` accumulator was computed per feature map but never
        # read, so it has been dropped.
        for (feat_fake, feat_real) in zip(fmap_fake, fmap_real):
            assert feat_fake.shape == feat_real.shape
            n_fmaps += 1
            feat_loss += self.loss(feat_fake, feat_real)

        if self.normalize:
            feat_loss /= n_fmaps

        return feat_loss
================================================
FILE: audiocraft/data/__init__.py
================================================
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
"""Audio loading and writing support. Datasets for raw audio
or also including some metadata."""
# flake8: noqa
from . import audio, audio_dataset, info_audio_dataset, music_dataset, sound_dataset, jasco_dataset
================================================
FILE: audiocraft/data/audio.py
================================================
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
"""
Audio IO methods are defined in this module (info, read, write).
We rely on the av library (FFmpeg bindings) for most formats, and on soundfile for .flac and .ogg files.
"""
from dataclasses import dataclass
from pathlib import Path
import logging
import typing as tp
import numpy as np
import soundfile
import torch
from torch.nn import functional as F
import av
import subprocess as sp
from .audio_utils import f32_pcm, normalize_audio
_av_initialized = False
def _init_av():
global _av_initialized
if _av_initialized:
return
logger = logging.getLogger('libav.mp3')
logger.setLevel(logging.ERROR)
_av_initialized = True
@dataclass(frozen=True)
class AudioFileInfo:
    """Immutable record of basic audio file properties, as returned by `audio_info`."""
    sample_rate: int  # sample rate reported by the backend
    duration: float  # duration reported by the backend (presumably in seconds - confirm)
    channels: int  # number of audio channels
def _av_info(filepath: tp.Union[str, Path]) -> AudioFileInfo:
    """Read sample rate, duration and channel count of the first audio stream using PyAV."""
    _init_av()
    with av.open(str(filepath)) as container:
        audio_stream = container.streams.audio[0]
        return AudioFileInfo(
            audio_stream.codec_context.sample_rate,
            # The stream duration is expressed in `time_base` units.
            float(audio_stream.duration * audio_stream.time_base),
            audio_stream.channels,
        )
def _soundfile_info(filepath: tp.Union[str, Path]) -> AudioFileInfo:
    """Read sample rate, duration and channel count using soundfile."""
    sf_info = soundfile.info(filepath)
    sample_rate = sf_info.samplerate
    duration = sf_info.duration
    channels = sf_info.channels
    return AudioFileInfo(sample_rate, duration, channels)
def audio_info(filepath: tp.Union[str, Path]) -> AudioFileInfo:
    """Return basic information on an audio file, picking the backend from its extension.

    Args:
        filepath (str or Path): Path to the audio file.
    Returns:
        AudioFileInfo: Sample rate, duration and number of channels.
    """
    # torchaudio no longer returns useful duration information for some formats like mp3s.
    filepath = Path(filepath)
    # Compare the lowercased suffix so that e.g. `.FLAC` is routed like `.flac`;
    # `find_audio_files` also matches extensions case-insensitively.
    if filepath.suffix.lower() in ['.flac', '.ogg']:  # TODO: Validate .ogg can be safely read with av_info
        # ffmpeg has some weird issue with flac.
        return _soundfile_info(filepath)
    else:
        return _av_info(filepath)
def _av_read(filepath: tp.Union[str, Path], seek_time: float = 0, duration: float = -1.) -> tp.Tuple[torch.Tensor, int]:
    """FFMPEG-based audio file reading using PyAV bindings.
    Soundfile cannot read mp3 and av_read is more efficient than torchaudio.

    Args:
        filepath (str or Path): Path to audio file to read.
        seek_time (float): Time at which to start reading in the file.
        duration (float): Duration to read from the file. If set to -1, the whole file is read.
            NOTE(review): a duration of 0 gives `num_frames == 0`, which the `> 0` checks
            below also treat as "read everything" - confirm this is intended.
    Returns:
        tuple of torch.Tensor, int: Tuple containing audio data and sample rate
    """
    _init_av()
    with av.open(str(filepath)) as af:
        stream = af.streams.audio[0]
        sr = stream.codec_context.sample_rate
        # Requested length and start position, expressed in samples.
        num_frames = int(sr * duration) if duration >= 0 else -1
        frame_offset = int(sr * seek_time)
        # we need a small negative offset otherwise we get some edge artifact
        # from the mp3 decoder.
        af.seek(int(max(0, (seek_time - 0.1)) / stream.time_base), stream=stream)
        frames = []
        length = 0
        for frame in af.decode(streams=stream.index):
            # Position of this decoded frame in samples, derived from its timestamp.
            current_offset = int(frame.rate * frame.pts * frame.time_base)
            # Leading samples to drop when the frame starts before the requested offset.
            strip = max(0, frame_offset - current_offset)
            buf = torch.from_numpy(frame.to_ndarray())
            if buf.shape[0] != stream.channels:
                # presumably an interleaved (packed) layout; reshape to [channels, samples] - confirm
                buf = buf.view(-1, stream.channels).t()
            buf = buf[:, strip:]
            frames.append(buf)
            length += buf.shape[1]
            # Stop decoding as soon as enough samples were collected.
            if num_frames > 0 and length >= num_frames:
                break
        assert frames
        # If the above assert fails, it is likely because we seeked past the end of file point,
        # in which case ffmpeg returns a single frame with only zeros, and a weird timestamp.
        # This will need proper debugging, in due time.
        wav = torch.cat(frames, dim=1)
        assert wav.shape[0] == stream.channels
        if num_frames > 0:
            # The last decoded frame may overshoot the requested length.
            wav = wav[:, :num_frames]
        return f32_pcm(wav), sr
def audio_read(filepath: tp.Union[str, Path], seek_time: float = 0.,
               duration: float = -1.0, pad: bool = False) -> tp.Tuple[torch.Tensor, int]:
    """Read audio by picking the most appropriate backend tool based on the audio format.

    Args:
        filepath (str or Path): Path to audio file to read.
        seek_time (float): Time at which to start reading in the file.
        duration (float): Duration to read from the file. If set to -1, the whole file is read.
        pad (bool): Pad output audio if not reaching expected duration.
    Returns:
        tuple of torch.Tensor, int: Tuple containing audio data and sample rate.
    """
    fp = Path(filepath)
    # Compare the lowercased suffix so that e.g. `.FLAC` is routed like `.flac`;
    # `find_audio_files` also matches extensions case-insensitively.
    if fp.suffix.lower() in ['.flac', '.ogg']:  # TODO: check if we can safely use av_read for .ogg
        # There is some bug with ffmpeg and reading flac
        info = _soundfile_info(filepath)
        # soundfile reads up to the end of the file when `frames` is -1.
        frames = -1 if duration <= 0 else int(duration * info.sample_rate)
        frame_offset = int(seek_time * info.sample_rate)
        wav, sr = soundfile.read(filepath, start=frame_offset, frames=frames, dtype=np.float32)
        assert info.sample_rate == sr, f"Mismatch of sample rates {info.sample_rate} {sr}"
        wav = torch.from_numpy(wav).t().contiguous()
        if len(wav.shape) == 1:
            # A 1-dim result gets an explicit leading channel dimension.
            wav = torch.unsqueeze(wav, 0)
    else:
        wav, sr = _av_read(filepath, seek_time, duration)
    if pad and duration > 0:
        expected_frames = int(duration * sr)
        wav = F.pad(wav, (0, expected_frames - wav.shape[-1]))
    return wav, sr
def _piping_to_ffmpeg(out_path: tp.Union[str, Path], wav: torch.Tensor, sample_rate: int, flags: tp.List[str]):
    """Encode `wav` to `out_path` by feeding raw float32 samples to an `ffmpeg` subprocess.

    Args:
        out_path (str or Path): Destination file.
        wav (torch.Tensor): 2-dim audio tensor; its first dimension is passed to ffmpeg as `-ac`.
        sample_rate (int): Sample rate passed to ffmpeg as `-ar`.
        flags (list of str): Output options inserted before the output path.
    """
    # ffmpeg is always installed and torchaudio is a bit unstable lately, so let's bypass it entirely.
    assert wav.dim() == 2, wav.shape
    n_channels = wav.shape[0]
    raw_input_args = ['-f', 'f32le', '-ar', str(sample_rate), '-ac', str(n_channels), '-i', '-']
    command = ['ffmpeg', '-loglevel', 'error', '-y'] + raw_input_args + flags + [str(out_path)]
    # Transposed before serialisation so that samples are laid out [time, channel].
    pcm_bytes = f32_pcm(wav).t().detach().cpu().numpy().tobytes()
    sp.run(command, input=pcm_bytes, check=True)
def audio_write(stem_name: tp.Union[str, Path],
                wav: torch.Tensor, sample_rate: int,
                format: str = 'wav', mp3_rate: int = 320, ogg_rate: tp.Optional[int] = None,
                normalize: bool = True, strategy: str = 'peak', peak_clip_headroom_db: float = 1,
                rms_headroom_db: float = 18, loudness_headroom_db: float = 14,
                loudness_compressor: bool = False,
                log_clipping: bool = True, make_parent_dir: bool = True,
                add_suffix: bool = True) -> Path:
    """Convenience function for saving audio to disk. Returns the filename the audio was written to.

    Args:
        stem_name (str or Path): Filename without extension which will be added automatically.
        wav (torch.Tensor): Audio data to save.
        sample_rate (int): Sample rate of audio data.
        format (str): Either "wav", "mp3", "ogg", or "flac".
        mp3_rate (int): kbps when using mp3s.
        ogg_rate (int): kbps when using ogg/vorbis. If not provided, let ffmpeg decide for itself.
        normalize (bool): if `True` (default), normalizes according to the prescribed
            strategy (see after). If `False`, the strategy is only used in case clipping
            would happen.
        strategy (str): Can be either 'clip', 'peak', or 'rms'. Default is 'peak',
            i.e. audio is normalized by its largest value. RMS normalizes by root-mean-square
            with extra headroom to avoid clipping. 'clip' just clips.
        peak_clip_headroom_db (float): Headroom in dB when doing 'peak' or 'clip' strategy.
        rms_headroom_db (float): Headroom in dB when doing 'rms' strategy. This must be much larger
            than the `peak_clip` one to avoid further clipping.
        loudness_headroom_db (float): Target loudness for loudness normalization.
        loudness_compressor (bool): Uses tanh for soft clipping when strategy is 'loudness'.
        log_clipping (bool): If True, basic logging on stderr when clipping still
            occurs despite strategy (only for 'rms').
        make_parent_dir (bool): Make parent directory if it doesn't exist.
        add_suffix (bool): If True (default), append the extension matching `format` to `stem_name`.
    Returns:
        Path: Path of the saved audio.
    Raises:
        ValueError: If `wav` has more than 2 dimensions.
        RuntimeError: If `format` is not one of the supported formats.
    """
    assert wav.dtype.is_floating_point, "wav is not floating point"
    if wav.dim() == 1:
        wav = wav[None]
    elif wav.dim() > 2:
        raise ValueError("Input wav should be at most 2 dimension.")
    assert wav.isfinite().all()
    # Resolve the output format first so that an unsupported one fails before any
    # normalization work is done.
    if format == 'mp3':
        suffix = '.mp3'
        flags = ['-f', 'mp3', '-c:a', 'libmp3lame', '-b:a', f'{mp3_rate}k']
    elif format == 'wav':
        suffix = '.wav'
        flags = ['-f', 'wav', '-c:a', 'pcm_s16le']
    elif format == 'ogg':
        suffix = '.ogg'
        flags = ['-f', 'ogg', '-c:a', 'libvorbis']
        if ogg_rate is not None:
            flags += ['-b:a', f'{ogg_rate}k']
    elif format == 'flac':
        suffix = '.flac'
        flags = ['-f', 'flac']
    else:
        raise RuntimeError(f"Invalid format {format}. Only wav, mp3, ogg or flac are supported.")
    wav = normalize_audio(wav, normalize, strategy, peak_clip_headroom_db,
                          rms_headroom_db, loudness_headroom_db, loudness_compressor,
                          log_clipping=log_clipping, sample_rate=sample_rate,
                          stem_name=str(stem_name))
    if not add_suffix:
        suffix = ''
    path = Path(str(stem_name) + suffix)
    if make_parent_dir:
        path.parent.mkdir(exist_ok=True, parents=True)
    try:
        _piping_to_ffmpeg(path, wav, sample_rate, flags)
    except Exception:
        if path.exists():
            # we do not want to leave half written files around.
            path.unlink()
        raise
    return path
def get_spec(y, sr=16000, n_fft=4096, hop_length=128, dur=8) -> np.ndarray:
    """Get the mel-spectrogram (in dB) from the raw audio.

    Args:
        y (numpy array): raw input
        sr (int): Sampling rate. Default is 16000.
        n_fft (int): Number of samples per FFT. Default is 4096.
        hop_length (int): Number of samples between successive frames. Default is 128.
        dur (float): Maximum duration to get the spectrograms.
            Currently unused by this function: the input is not cropped here.
    Returns:
        Mel-spectrogram converted to dB relative to its maximum, as a numpy array.
    """
    # Only the core package is needed here; `librosa.display` was imported but never used.
    import librosa
    spectrogram = librosa.feature.melspectrogram(
        y=y, sr=sr, n_fft=n_fft, hop_length=hop_length
    )
    spectrogram_db = librosa.power_to_db(spectrogram, ref=np.max)
    return spectrogram_db
def save_spectrograms(
    ys: tp.List[np.ndarray],
    sr: int,
    path: str,
    names: tp.List[str],
    n_fft: int = 4096,
    hop_length: int = 128,
    dur: float = 8.0,
):
    """Plot the mel-spectrograms of several waveforms, stacked vertically, and save the figure.

    Note that this function updates matplotlib's global rc settings (fonts and sizes).

    Args:
        ys: List of waveforms; each is cropped to `dur` seconds before plotting.
        sr (int): Sampling rate of the waveforms.
        path (str): Path to the plot file.
        names: name of each spectrogram plot. If empty, three default names are used.
        n_fft (int): Number of samples per FFT. Default is 4096.
        hop_length (int): Number of samples between successive frames. Default is 128.
        dur (float): Maximum duration to plot the spectrograms. Default is 8.0.
    Returns:
        None (the figure is written to `path` using matplotlib)
    """
    import matplotlib as mpl  # type: ignore
    import matplotlib.pyplot as plt  # type: ignore
    import librosa.display

    if not names:
        names = ["Ground Truth", "Audio Watermarked", "Watermark"]
    ys = [wav[: int(dur * sr)] for wav in ys]  # crop
    assert len(names) == len(
        ys
    ), f"There are {len(ys)} wavs but {len(names)} names ({names})"

    # Set matplotlib stuff
    BIGGER_SIZE = 10
    SMALLER_SIZE = 8
    linewidth = 234.8775  # linewidth in pt

    plt.rc("font", size=BIGGER_SIZE, family="serif")  # controls default text sizes
    plt.rcParams["font.family"] = "DeJavu Serif"
    plt.rcParams["font.serif"] = ["Times New Roman"]

    plt.rc("axes", titlesize=BIGGER_SIZE)  # fontsize of the axes title
    plt.rc("axes", labelsize=BIGGER_SIZE)  # fontsize of the x and y labels
    plt.rc("xtick", labelsize=BIGGER_SIZE)  # fontsize of the tick labels
    plt.rc("ytick", labelsize=SMALLER_SIZE)  # fontsize of the tick labels
    plt.rc("legend", fontsize=BIGGER_SIZE)  # legend fontsize
    plt.rc("figure", titlesize=BIGGER_SIZE)
    height = 1.6 * linewidth / 72.0
    # `squeeze=False` always yields a 2-dim array of axes, so the indexing below also
    # works for a single waveform (plt.subplots would otherwise return a bare Axes).
    fig, ax_grid = plt.subplots(
        nrows=len(ys),
        ncols=1,
        sharex=True,
        figsize=(linewidth / 72.0, height),
        squeeze=False,
    )
    ax = ax_grid[:, 0]
    fig.tight_layout()

    # Plot the spectrogram
    for i, ysi in enumerate(ys):
        spectrogram_db = get_spec(ysi, sr=sr, n_fft=n_fft, hop_length=hop_length)
        if i == 0:
            # Single colorbar spanning all rows; its range comes from the first spectrogram.
            cax = fig.add_axes(
                [
                    ax[0].get_position().x1 + 0.01,  # type: ignore
                    ax[-1].get_position().y0,
                    0.02,
                    ax[0].get_position().y1 - ax[-1].get_position().y0,
                ]
            )
            fig.colorbar(
                mpl.cm.ScalarMappable(
                    norm=mpl.colors.Normalize(
                        np.min(spectrogram_db), np.max(spectrogram_db)
                    ),
                    cmap="magma",
                ),
                ax=ax,
                orientation="vertical",
                format="%+2.0f dB",
                cax=cax,
            )
        librosa.display.specshow(
            spectrogram_db,
            sr=sr,
            hop_length=hop_length,
            x_axis="time",
            y_axis="mel",
            ax=ax[i],
        )
        ax[i].set(title=names[i])
        ax[i].yaxis.set_label_text(None)
        ax[i].label_outer()
    fig.savefig(path, bbox_inches="tight")
    # Close the figure created here rather than whichever figure is current.
    plt.close(fig)
================================================
FILE: audiocraft/data/audio_dataset.py
================================================
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
"""AudioDataset support. In order to handle a larger number of files
without having to scan again the folders, we precompute some metadata
(filename, sample rate, duration), and use that to efficiently sample audio segments.
"""
import argparse
import copy
from concurrent.futures import ThreadPoolExecutor, Future
from dataclasses import dataclass, fields
from contextlib import ExitStack
from functools import lru_cache
import gzip
import json
import logging
import os
from pathlib import Path
import random
import sys
import typing as tp
import torch
import torch.nn.functional as F
from .audio import audio_read, audio_info
from .audio_utils import convert_audio
from .zip import PathInZip
try:
import dora
except ImportError:
dora = None # type: ignore
@dataclass(order=True)
class BaseInfo:
    """Base dataclass providing conversions from and to plain dictionaries."""

    @classmethod
    def _dict2fields(cls, dictionary: dict):
        """Keep only the entries of `dictionary` whose key is a field of this dataclass."""
        selected = {}
        for spec in fields(cls):
            if spec.name in dictionary:
                selected[spec.name] = dictionary[spec.name]
        return selected

    @classmethod
    def from_dict(cls, dictionary: dict):
        """Build an instance from `dictionary`, ignoring keys that are not fields."""
        return cls(**cls._dict2fields(dictionary))

    def to_dict(self):
        """Return a dictionary mapping every field name to its current value."""
        as_dict = {}
        for spec in fields(self):
            as_dict[spec.name] = self.__getattribute__(spec.name)
        return as_dict
@dataclass(order=True)
class AudioMeta(BaseInfo):
    """Metadata describing one audio file."""
    path: str
    duration: float
    sample_rate: int
    amplitude: tp.Optional[float] = None
    weight: tp.Optional[float] = None
    # info_path is used to load additional information about the audio file that is stored in zip files.
    info_path: tp.Optional[PathInZip] = None

    @classmethod
    def from_dict(cls, dictionary: dict):
        """Build an instance from `dictionary`, wrapping a non-null `info_path` in `PathInZip`."""
        kwargs = cls._dict2fields(dictionary)
        raw_info_path = kwargs.get('info_path')
        if raw_info_path is not None:
            kwargs['info_path'] = PathInZip(raw_info_path)
        return cls(**kwargs)

    def to_dict(self):
        """Return the fields as a dictionary, with a non-null `info_path` converted to `str`."""
        as_dict = super().to_dict()
        info_path = as_dict['info_path']
        if info_path is not None:
            as_dict['info_path'] = str(info_path)
        return as_dict
@dataclass(order=True)
class SegmentInfo(BaseInfo):
    """Description of one audio segment together with the metadata of its source file."""
    meta: AudioMeta  # metadata of the file the segment comes from
    seek_time: float  # presumably the start time of the segment within the file - confirm
    # The following values are given once the audio is processed, e.g.
    # at the target sample rate and target number of channels.
    n_frames: int  # actual number of frames without padding
    total_frames: int  # total number of frames, padding included
    sample_rate: int  # actual sample rate
    channels: int  # number of audio channels.
# Default file extensions (lowercase, leading dot included) accepted by `find_audio_files`.
DEFAULT_EXTS = ['.wav', '.mp3', '.flac', '.ogg', '.m4a']
# Module-level logger.
logger = logging.getLogger(__name__)
def _get_audio_meta(file_path: str, minimal: bool = True) -> AudioMeta:
    """AudioMeta from a path to an audio file.

    Args:
        file_path (str): Resolved path of valid audio file.
        minimal (bool): Whether to only load the minimal set of metadata (takes longer if not).
            When False, the whole file is read to measure its peak absolute amplitude.
    Returns:
        AudioMeta: Audio file path and its metadata.
    """
    file_info = audio_info(file_path)
    if minimal:
        peak: tp.Optional[float] = None
    else:
        waveform, _ = audio_read(file_path)
        peak = waveform.abs().max().item()
    return AudioMeta(file_path, file_info.duration, file_info.sample_rate, peak)
def _resolve_audio_meta(m: AudioMeta, fast: bool = True) -> AudioMeta:
    """If Dora is available as a dependency, try to resolve potential relative paths
    in list of AudioMeta. This method is expected to be used when loading meta from file.

    Args:
        m (AudioMeta): Audio meta to resolve. It is modified in place.
        fast (bool): If True, uses a really fast check for determining if a file
            is already absolute or not. Only valid on Linux/Mac.
    Returns:
        AudioMeta: Audio meta with resolved path.
    """
    def is_abs(path) -> bool:
        text = str(path)
        if fast:
            # `startswith` also copes with an empty string, unlike indexing `[0]`.
            return text.startswith('/')
        # This branch used to compute the result without returning it, so every
        # path was treated as relative.
        return os.path.isabs(text)

    if not dora:
        return m

    if not is_abs(m.path):
        m.path = dora.git_save.to_absolute_path(m.path)
    if m.info_path is not None and not is_abs(m.info_path.zip_path):
        # Resolve the zip path from itself; it used to be overwritten with the
        # absolute path of the audio file.
        m.info_path.zip_path = dora.git_save.to_absolute_path(m.info_path.zip_path)
    return m
def find_audio_files(path: tp.Union[Path, str],
                     exts: tp.List[str] = DEFAULT_EXTS,
                     resolve: bool = True,
                     minimal: bool = True,
                     progress: bool = False,
                     workers: int = 0) -> tp.List[AudioMeta]:
    """Build a list of AudioMeta from a given path,
    collecting relevant audio files and fetching meta info.

    Args:
        path (str or Path): Path to folder containing audio files.
        exts (list of str): List of file extensions to consider for audio files.
        resolve (bool): Whether to pass each AudioMeta through `_resolve_audio_meta`.
        minimal (bool): Whether to only load the minimal set of metadata (takes longer if not).
        progress (bool): Whether to log progress on audio files collection.
        workers (int): number of parallel workers, if 0, use only the current thread.
    Returns:
        list of AudioMeta: List of audio file path and its metadata.
    """
    discovered = []
    pending: tp.List[Future] = []
    executor: tp.Optional[ThreadPoolExecutor] = None
    with ExitStack() as stack:
        if workers > 0:
            executor = ThreadPoolExecutor(workers)
            stack.enter_context(executor)

        # First pass: walk the tree and, with a pool, start extracting metadata right away.
        if progress:
            print("Finding audio files...")
        for root, _folders, filenames in os.walk(path, followlinks=True):
            for filename in filenames:
                candidate = Path(root) / filename
                if candidate.suffix.lower() not in exts:
                    continue
                discovered.append(candidate)
                if executor is not None:
                    pending.append(executor.submit(_get_audio_meta, str(discovered[-1]), minimal))
                if progress:
                    print(format(len(discovered), " 8d"), end='\r', file=sys.stderr)

        # Second pass: gather the metadata, skipping the files that fail.
        if progress:
            print("Getting audio metadata...")
        meta: tp.List[AudioMeta] = []
        for position, audio_path in enumerate(discovered):
            try:
                if executor is not None:
                    entry = pending[position].result()
                else:
                    entry = _get_audio_meta(str(audio_path), minimal)
                if resolve:
                    entry = _resolve_audio_meta(entry)
            except Exception as err:
                print("Error with", str(audio_path), err, file=sys.stderr)
                continue
            meta.append(entry)
            if progress:
                print(format((1 + position) / len(discovered), " 3.1%"), end='\r', file=sys.stderr)
    meta.sort()
    return meta
def load_audio_meta(path: tp.Union[str, Path],
                    resolve: bool = True, fast: bool = True) -> tp.List[AudioMeta]:
    """Load list of AudioMeta from an optionally compressed json file.

    The file holds one JSON object per line; a path ending in `.gz` (case-insensitive)
    is opened with gzip.

    Args:
        path (str or Path): Path to JSON file.
        resolve (bool): Whether to resolve the path from AudioMeta (default=True).
        fast (bool): activates some tricks to make things faster.
    Returns:
        list of AudioMeta: List of audio file path and its total duration.
    """
    is_gzipped = str(path).lower().endswith('.gz')
    opener = gzip.open if is_gzipped else open
    with opener(path, 'rb') as handle:  # type: ignore
        raw_lines = handle.readlines()
    entries = []
    for raw_line in raw_lines:
        entry = AudioMeta.from_dict(json.loads(raw_line))
        if resolve:
            entry = _resolve_audio_meta(entry, fast=fast)
        entries.append(entry)
    return entries
def save_audio_meta(path: tp.Union[str, Path], meta: tp.List[AudioMeta]):
    """Save the audio metadata to `path`, one JSON object per line.

    The parent directory is created if needed; a path ending in `.gz` (case-insensitive)
    is written with gzip.

    Args:
        path (str or Path): Path to JSON file.
        meta (list of AudioMeta): List of audio meta to save.
    """
    Path(path).parent.mkdir(exist_ok=True, parents=True)
    is_gzipped = str(path).lower().endswith('.gz')
    opener = gzip.open if is_gzipped else open
    with opener(path, 'wb') as handle:  # type: ignore
        for entry in meta:
            serialized = json.dumps(entry.to_dict()) + '\n'
            handle.write(serialized.encode('utf-8'))
class AudioDataset:
    """Base audio dataset.

    The dataset takes a list of AudioMeta and create a dataset composed of segments of audio
    and potentially additional information, by creating random segments from the list of audio
    files referenced in the metadata and applying minimal data pre-processing such as resampling,
    mixing of channels, padding, etc.

    If no segment_duration value is provided, the AudioDataset will return the full wav for each
    audio file. Otherwise, it will randomly sample audio files and create a segment of the specified
    duration, applying padding if required.

    By default, only the torch Tensor corresponding to the waveform is returned. Setting return_info=True
    allows to return a tuple containing the torch Tensor and additional metadata on the segment and the
    original audio meta.

    Note that you can call `start_epoch(epoch)` in order to get
    a deterministic "randomization" for `shuffle=True`.
    For a given epoch and dataset index, this will always return the same extract.
    You can get back some diversity by setting the `shuffle_seed` param.

    Args:
        meta (list of AudioMeta): List of audio files metadata.
        segment_duration (float, optional): Optional segment duration of audio to load.
            If not specified, the dataset will load the full audio segment from the file.
        shuffle (bool): Set to `True` to have the data reshuffled at every epoch.
        num_samples (int): Length reported by the dataset when `segment_duration` is provided.
            When `segment_duration` is None, it is replaced by the number of files kept after filtering.
        sample_rate (int): Target sample rate of the loaded audio samples.
        channels (int): Target number of channels of the loaded audio samples.
        pad (bool): If True, segments shorter than `segment_duration` are zero-padded on the right,
            and full-length files are padded to the longest item of the batch in `collater`.
        sample_on_duration (bool): Set to `True` to sample segments with probability
            dependent on audio file duration. This is only used if `segment_duration` is provided.
        sample_on_weight (bool): Set to `True` to sample segments using the `weight` entry of
            `AudioMeta`. If `sample_on_duration` is also True, the actual weight will be the product
            of the file duration and file weight. This is only used if `segment_duration` is provided.
        min_segment_ratio (float): Minimum segment ratio to use when the audio file
            is shorter than the desired segment.
        max_read_retry (int): Maximum number of retries to sample an audio segment from the dataset.
        return_info (bool): Whether to return the wav only or return wav along with segment info and metadata.
        min_audio_duration (float, optional): Minimum audio file duration, in seconds, if provided
            audio shorter than this will be filtered out.
        max_audio_duration (float, optional): Maximal audio file duration in seconds, if provided
            audio longer than this will be filtered out.
        shuffle_seed (int): can be used to further randomize
        load_wav (bool): if False, skip loading the wav but returns a tensor of 0
            with the expected segment_duration (which must be provided if load_wav is False).
        permutation_on_files (bool): only if `sample_on_weight` and `sample_on_duration`
            are False. Will ensure a permutation on files when going through the dataset.
            In that case the epoch number must be provided in order for the model
            to continue the permutation across epochs. In that case, it is assumed
            that `num_samples = total_batch_size * num_updates_per_epoch`, with
            `total_batch_size` the overall batch size accounting for all gpus.
    """
    def __init__(self,
                 meta: tp.List[AudioMeta],
                 segment_duration: tp.Optional[float] = None,
                 shuffle: bool = True,
                 num_samples: int = 10_000,
                 sample_rate: int = 48_000,
                 channels: int = 2,
                 pad: bool = True,
                 sample_on_duration: bool = True,
                 sample_on_weight: bool = True,
                 min_segment_ratio: float = 0.5,
                 max_read_retry: int = 10,
                 return_info: bool = False,
                 min_audio_duration: tp.Optional[float] = None,
                 max_audio_duration: tp.Optional[float] = None,
                 shuffle_seed: int = 0,
                 load_wav: bool = True,
                 permutation_on_files: bool = False,
                 ):
        assert len(meta) > 0, "No audio meta provided to AudioDataset. Please check loading of audio meta."
        assert segment_duration is None or segment_duration > 0
        assert segment_duration is None or min_segment_ratio >= 0
        self.segment_duration = segment_duration
        self.min_segment_ratio = min_segment_ratio
        self.max_audio_duration = max_audio_duration
        self.min_audio_duration = min_audio_duration
        if self.min_audio_duration is not None and self.max_audio_duration is not None:
            assert self.min_audio_duration <= self.max_audio_duration
        # The duration bounds must be set before filtering, as `_filter_duration` reads them.
        self.meta: tp.List[AudioMeta] = self._filter_duration(meta)
        assert len(self.meta)  # Fail fast if all data has been filtered.
        self.total_duration = sum(d.duration for d in self.meta)

        if segment_duration is None:
            # Without segmentation, one item corresponds to exactly one file.
            num_samples = len(self.meta)
        self.num_samples = num_samples
        self.shuffle = shuffle
        self.sample_rate = sample_rate
        self.channels = channels
        self.pad = pad
        self.sample_on_weight = sample_on_weight
        self.sample_on_duration = sample_on_duration
        # Needs `sample_on_weight`, `sample_on_duration` and `meta` to be set already.
        self.sampling_probabilities = self._get_sampling_probabilities()
        self.max_read_retry = max_read_retry
        self.return_info = return_info
        self.shuffle_seed = shuffle_seed
        self.current_epoch: tp.Optional[int] = None
        self.load_wav = load_wav
        if not load_wav:
            assert segment_duration is not None
        self.permutation_on_files = permutation_on_files
        if permutation_on_files:
            assert not self.sample_on_duration
            assert not self.sample_on_weight
            assert self.shuffle

    def start_epoch(self, epoch: int):
        """Record the current epoch, used to seed the per-item random generator."""
        self.current_epoch = epoch

    def __len__(self):
        return self.num_samples

    def _get_sampling_probabilities(self, normalized: bool = True):
        """Return the sampling probabilities for each file inside `self.meta`."""
        scores: tp.List[float] = []
        for file_meta in self.meta:
            score = 1.
            if self.sample_on_weight and file_meta.weight is not None:
                score *= file_meta.weight
            if self.sample_on_duration:
                score *= file_meta.duration
            scores.append(score)
        probabilities = torch.tensor(scores)
        if normalized:
            probabilities /= probabilities.sum()
        return probabilities

    @staticmethod
    @lru_cache(16)
    def _get_file_permutation(num_files: int, permutation_index: int, base_seed: int):
        """Return the permutation of `num_files` indices seeded by `base_seed + permutation_index`."""
        # Used to keep the most recent files permutation in memory implicitly.
        # will work unless someone is using a lot of Datasets in parallel.
        rng = torch.Generator()
        rng.manual_seed(base_seed + permutation_index)
        return torch.randperm(num_files, generator=rng)

    def sample_file(self, index: int, rng: torch.Generator) -> AudioMeta:
        """Sample a given file from `self.meta`. Can be overridden in subclasses.
        This is only called if `segment_duration` is not None.

        You must use the provided random number generator `rng` for reproducibility.
        You can further make use of the index accessed.
        """
        if self.permutation_on_files:
            assert self.current_epoch is not None
            # Position of this item in the stream of all items served since epoch 0.
            total_index = self.current_epoch * len(self) + index
            permutation_index = total_index // len(self.meta)
            relative_index = total_index % len(self.meta)
            permutation = AudioDataset._get_file_permutation(
                len(self.meta), permutation_index, self.shuffle_seed)
            file_index = int(permutation[relative_index])
            return self.meta[file_index]

        if not self.sample_on_weight and not self.sample_on_duration:
            file_index = int(torch.randint(len(self.sampling_probabilities), (1,), generator=rng).item())
        else:
            file_index = int(torch.multinomial(self.sampling_probabilities, 1, generator=rng).item())

        return self.meta[file_index]

    def _audio_read(self, path: str, seek_time: float = 0, duration: float = -1):
        """Read audio from `path`, or return silence of the segment length when `load_wav` is False."""
        # Override this method in subclass if needed.
        if self.load_wav:
            return audio_read(path, seek_time, duration, pad=False)
        else:
            assert self.segment_duration is not None
            n_frames = int(self.sample_rate * self.segment_duration)
            return torch.zeros(self.channels, n_frames), self.sample_rate

    def __getitem__(self, index: int) -> tp.Union[torch.Tensor, tp.Tuple[torch.Tensor, SegmentInfo]]:
        if self.segment_duration is None:
            file_meta = self.meta[index]
            out, sr = audio_read(file_meta.path)
            out = convert_audio(out, sr, self.sample_rate, self.channels)
            n_frames = out.shape[-1]
            segment_info = SegmentInfo(file_meta, seek_time=0., n_frames=n_frames, total_frames=n_frames,
                                       sample_rate=self.sample_rate, channels=out.shape[0])
        else:
            rng = torch.Generator()
            if self.shuffle:
                # We use index, plus extra randomness, either totally random if we don't know the epoch.
                # otherwise we make use of the epoch number and optional shuffle_seed.
                if self.current_epoch is None:
                    rng.manual_seed(index + self.num_samples * random.randint(0, 2**24))
                else:
                    rng.manual_seed(index + self.num_samples * (self.current_epoch + self.shuffle_seed))
            else:
                # We only use index
                rng.manual_seed(index)

            for retry in range(self.max_read_retry):
                file_meta = self.sample_file(index, rng)
                # We add some variance in the file position even if audio file is smaller than segment
                # without ending up with empty segments
                max_seek = max(0, file_meta.duration - self.segment_duration * self.min_segment_ratio)
                seek_time = torch.rand(1, generator=rng).item() * max_seek
                try:
                    # Go through `_audio_read` so that `load_wav=False` and subclass overrides are honoured.
                    out, sr = self._audio_read(file_meta.path, seek_time, self.segment_duration)
                    out = convert_audio(out, sr, self.sample_rate, self.channels)
                    n_frames = out.shape[-1]
                    target_frames = int(self.segment_duration * self.sample_rate)
                    if self.pad:
                        out = F.pad(out, (0, target_frames - n_frames))
                    segment_info = SegmentInfo(file_meta, seek_time, n_frames=n_frames, total_frames=target_frames,
                                               sample_rate=self.sample_rate, channels=out.shape[0])
                except Exception as exc:
                    logger.warning("Error opening file %s: %r", file_meta.path, exc)
                    if retry == self.max_read_retry - 1:
                        raise
                else:
                    break

        if self.return_info:
            # Returns the wav and additional information on the wave segment
            return out, segment_info
        else:
            return out

    def collater(self, samples):
        """The collater function has to be provided to the dataloader
        if AudioDataset has return_info=True in order to properly collate
        the samples of a batch.
        """
        if self.segment_duration is None and len(samples) > 1:
            assert self.pad, "Must allow padding when batching examples of different durations."

        # In this case the audio reaching the collater is of variable length as segment_duration=None.
        to_pad = self.segment_duration is None and self.pad
        if to_pad:
            # Samples are `(wav, info)` pairs only when `return_info` is set, bare tensors otherwise.
            max_len = max([(sample[0] if self.return_info else sample).shape[-1] for sample in samples])

            def _pad_wav(wav):
                return F.pad(wav, (0, max_len - wav.shape[-1]))

        if self.return_info:
            if len(samples) > 0:
                assert len(samples[0]) == 2
                assert isinstance(samples[0][0], torch.Tensor)
                assert isinstance(samples[0][1], SegmentInfo)

            wavs = [wav for wav, _ in samples]
            # Copy the infos so that updating `total_frames` does not touch the caller's objects.
            segment_infos = [copy.deepcopy(info) for _, info in samples]

            if to_pad:
                # Each wav could be of a different duration as they are not segmented.
                for i in range(len(samples)):
                    # Determines the total length of the signal with padding, so we update here as we pad.
                    segment_infos[i].total_frames = max_len
                    wavs[i] = _pad_wav(wavs[i])

            wav = torch.stack(wavs)
            return wav, segment_infos
        else:
            assert isinstance(samples[0], torch.Tensor)
            if to_pad:
                samples = [_pad_wav(s) for s in samples]
            return torch.stack(samples)

    def _filter_duration(self, meta: tp.List[AudioMeta]) -> tp.List[AudioMeta]:
        """Filters out audio files with audio durations that will not allow to sample examples from them."""
        orig_len = len(meta)

        # Filter data that is too short.
        if self.min_audio_duration is not None:
            meta = [m for m in meta if m.duration >= self.min_audio_duration]

        # Filter data that is too long.
        if self.max_audio_duration is not None:
            meta = [m for m in meta if m.duration <= self.max_audio_duration]

        filtered_len = len(meta)
        removed_percentage = 100*(1-float(filtered_len)/orig_len)
        msg = 'Removed %.2f percent of the data because it was too short or too long.' % removed_percentage
        # Report through the module logger, like the rest of this class.
        if removed_percentage < 10:
            logger.debug(msg)
        else:
            logger.warning(msg)
        return meta

    @classmethod
    def from_meta(cls, root: tp.Union[str, Path], **kwargs):
        """Instantiate AudioDataset from a path to a directory containing a manifest as a jsonl file.

        Args:
            root (str or Path): Path to root folder containing audio files.
            kwargs: Additional keyword arguments for the AudioDataset.
        """
        root = Path(root)
        if root.is_dir():
            if (root / 'data.jsonl').exists():
                root = root / 'data.jsonl'
            elif (root / 'data.jsonl.gz').exists():
                root = root / 'data.jsonl.gz'
            else:
                raise ValueError("Don't know where to read metadata from in the dir. "
                                 "Expecting either a data.jsonl or data.jsonl.gz file but none found.")
        meta = load_audio_meta(root)
        return cls(meta, **kwargs)

    @classmethod
    def from_path(cls, root: tp.Union[str, Path], minimal_meta: bool = True,
                  exts: tp.List[str] = DEFAULT_EXTS, **kwargs):
        """Instantiate AudioDataset from a path containing (possibly nested) audio files.

        Args:
            root (str or Path): Path to root folder containing audio files.
            minimal_meta (bool): Whether to only load minimal metadata or not.
            exts (list of str): Extensions for audio files.
            kwargs: Additional keyword arguments for the AudioDataset.
        """
        root = Path(root)
        if root.is_file():
            meta = load_audio_meta(root, resolve=True)
        else:
            meta = find_audio_files(root, exts, minimal=minimal_meta, resolve=True)
        return cls(meta, **kwargs)
def main():
    """Command-line entry point: scan a folder for audio files and save their metadata manifest."""
    logging.basicConfig(stream=sys.stderr, level=logging.INFO)
    cli = argparse.ArgumentParser(
        prog='audio_dataset',
        description='Generate .jsonl files by scanning a folder.')
    # (positional names / flags, keyword options) for each command-line argument, in declaration order.
    argument_specs = [
        (('root',), {'help': 'Root folder with all the audio files'}),
        (('output_meta_file',), {'help': 'Output file to store the metadata, '}),
        (('--complete',), {'action': 'store_false', 'dest': 'minimal', 'default': True,
                           'help': 'Retrieve all metadata, even the one that are expansive '
                                   'to compute (e.g. normalization).'}),
        (('--resolve',), {'action': 'store_true', 'default': False,
                          'help': 'Resolve the paths to be absolute and with no symlinks.'}),
        (('--workers',), {'default': 10, 'type': int, 'help': 'Number of workers.'}),
    ]
    for names, options in argument_specs:
        cli.add_argument(*names, **options)
    parsed = cli.parse_args()

    collected = find_audio_files(parsed.root, DEFAULT_EXTS, progress=True,
                                 resolve=parsed.resolve, minimal=parsed.minimal,
                                 workers=parsed.workers)
    save_audio_meta(parsed.output_meta_file, collected)
if __name__ == '__main__':
main()
================================================
FILE: audiocraft/data/audio_utils.py
================================================
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
"""Various utilities for audio convertion (pcm format, sample rate and channels),
and volume normalization."""
import io
import logging
import re
import sys
import typing as tp
import julius
import torch
import torchaudio
logger = logging.getLogger(__name__)
def convert_audio_channels(wav: torch.Tensor, channels: int = 2) -> torch.Tensor:
    """Bring an audio tensor to the requested number of channels.

    The channel axis is the second to last one and time is the last one;
    any number of leading dimensions is accepted.

    Args:
        wav (torch.Tensor): Audio wave of shape [..., C, T].
        channels (int): Expected number of channels as output.
    Returns:
        torch.Tensor: Audio wave of shape [..., channels, T].
    Raises:
        ValueError: If the input has more than one channel but fewer than `channels`.
    """
    *leading_dims, n_source, n_frames = wav.shape
    if n_source == channels:
        # Nothing to do.
        return wav
    if channels == 1:
        # Mono requested from a multi-channel stream: average all channels.
        return wav.mean(dim=-2, keepdim=True)
    if n_source == 1:
        # Multiple channels requested from a mono stream: replicate the single channel.
        return wav.expand(*leading_dims, channels, n_frames)
    if n_source < channels:
        # No obvious way to invent the missing channels from a non-mono stream.
        raise ValueError('The audio file has less channels than requested but is not mono.')
    # More channels available than requested: keep only the first ones.
    return wav[..., :channels, :]
def convert_audio(wav: torch.Tensor, from_rate: float,
                  to_rate: float, to_channels: int) -> torch.Tensor:
    """Resample audio to a new sample rate, then convert its number of channels.

    Args:
        wav (torch.Tensor): Input audio.
        from_rate (float): Sample rate of `wav`; truncated to an int before resampling.
        to_rate (float): Target sample rate; truncated to an int before resampling.
        to_channels (int): Target number of channels, see `convert_audio_channels`.
    Returns:
        torch.Tensor: The resampled audio with `to_channels` channels.
    """
    source_rate, target_rate = int(from_rate), int(to_rate)
    resampled = julius.resample_frac(wav, source_rate, target_rate)
    return convert_audio_channels(resampled, to_channels)
def normalize_loudness(wav: torch.Tensor, sample_rate: int, loudness_headroom_db: float = 14,
                       loudness_compressor: bool = False, energy_floor: float = 2e-3):
    """Scale a signal so that its measured loudness becomes `-loudness_headroom_db` dB.

    Loudness is measured with `torchaudio.transforms.Loudness`
    (audio loudness as defined by the ITU-R BS.1770-4 recommendation).

    Args:
        wav (torch.Tensor): Input multichannel audio data.
        sample_rate (int): Sample rate.
        loudness_headroom_db (float): Headroom below 0 dB; the target loudness is its negation.
        loudness_compressor (bool): Uses tanh for soft clipping.
        energy_floor (float): anything below that RMS level will not be rescaled.
    Returns:
        torch.Tensor: Loudness normalized output data (the input itself when it is below `energy_floor`).
    """
    rms = wav.pow(2).mean().sqrt().item()
    if rms < energy_floor:
        # Too quiet to be rescaled.
        return wav

    loudness_meter = torchaudio.transforms.Loudness(sample_rate)
    measured_db = loudness_meter(wav).item()
    # Gain (in dB) needed to move from the measured loudness to the target one.
    gain_db = -loudness_headroom_db - measured_db
    linear_gain = 10.0 ** (gain_db / 20.0)
    scaled = linear_gain * wav
    result = torch.tanh(scaled) if loudness_compressor else scaled
    assert result.isfinite().all(), (measured_db, wav.pow(2).mean().sqrt())
    return result
def _clip_wav(wav: torch.Tensor, log_clipping: bool = False, stem_name: tp.Optional[str] = None) -> None:
"""
Utility function to clip the audio with logging if specified.
"""
max_scale = wav.abs().max()
if log_clipping and max_scale > 1:
clamp_prob = (wav.abs() > 1).float().mean().item()
print(f"CLIPPING {stem_name or ''} happening with proba (a bit of clipping is okay):",
clamp_prob, "maximum scale: ", max_scale.item(), file=sys.stderr)
wav.clamp_(-1, 1)
def normalize_audio(wav: torch.Tensor, normalize: bool = True,
                    strategy: str = 'peak', peak_clip_headroom_db: float = 1,
                    rms_headroom_db: float = 18, loudness_headroom_db: float = 14,
                    loudness_compressor: bool = False, log_clipping: bool = False,
                    sample_rate: tp.Optional[int] = None,
                    stem_name: tp.Optional[str] = None) -> torch.Tensor:
    """Normalize the audio according to the prescribed strategy (see after).

    Args:
        wav (torch.Tensor): Audio data.
        normalize (bool): if `True` (default), normalizes according to the prescribed
            strategy (see after). If `False`, the strategy is only used in case clipping
            would happen.
        strategy (str): One of 'peak', 'clip', 'rms', 'loudness', 'none' or ''. Default is 'peak',
            i.e. audio is normalized by its largest value. RMS normalizes by root-mean-square
            with extra headroom to avoid clipping. 'clip' just clips. 'loudness' delegates to
            `normalize_loudness`. 'none' (or '') leaves the audio untouched and only asserts
            that its peak is below 1.
        peak_clip_headroom_db (float): Headroom in dB when doing 'peak' or 'clip' strategy.
        rms_headroom_db (float): Headroom in dB when doing 'rms' strategy. This must be much larger
            than the `peak_clip` one to avoid further clipping.
        loudness_headroom_db (float): Target loudness for loudness normalization.
        loudness_compressor (bool): If True, uses tanh based soft clipping.
        log_clipping (bool): If True, basic logging on stderr when clipping still
            occurs despite strategy (only for 'rms' and 'loudness').
        sample_rate (int): Sample rate for the audio data (required for loudness).
        stem_name (str, optional): Stem name for clipping logging.
    Returns:
        torch.Tensor: Normalized audio.
    """
    scale_peak = 10 ** (-peak_clip_headroom_db / 20)
    scale_rms = 10 ** (-rms_headroom_db / 20)
    if strategy == 'peak':
        peak = wav.abs().max()
        # A silent signal has nothing to rescale: dividing by a zero peak would
        # give an infinite gain and turn every sample into NaN (0 * inf).
        if peak > 0:
            rescaling = scale_peak / peak
            if normalize or rescaling < 1:
                wav = wav * rescaling
    elif strategy == 'clip':
        wav = wav.clamp(-scale_peak, scale_peak)
    elif strategy == 'rms':
        mono = wav.mean(dim=0)
        rms = mono.pow(2).mean().sqrt()
        # Same guard as for 'peak': a zero RMS would yield an infinite gain.
        if rms > 0:
            rescaling = scale_rms / rms
            if normalize or rescaling < 1:
                wav = wav * rescaling
        _clip_wav(wav, log_clipping=log_clipping, stem_name=stem_name)
    elif strategy == 'loudness':
        assert sample_rate is not None, "Loudness normalization requires sample rate."
        wav = normalize_loudness(wav, sample_rate, loudness_headroom_db, loudness_compressor)
        _clip_wav(wav, log_clipping=log_clipping, stem_name=stem_name)
    else:
        # Validate the strategy name first so that a typo is reported as such,
        # rather than as an unexplained amplitude assertion.
        assert strategy == '' or strategy == 'none', f"Unexpected strategy: '{strategy}'"
        assert wav.abs().max() < 1
    return wav
def f32_pcm(wav: torch.Tensor) -> torch.Tensor:
    """Convert integer PCM audio to floating point PCM.

    Floating point input is returned unchanged (whatever its precision); int16 and
    int32 input is converted to float32 and divided by its full scale value.

    Args:
        wav (torch.tensor): Input wav tensor
    Returns:
        same wav in floating point PCM format
    Raises:
        ValueError: If the dtype is neither floating point, int16 nor int32.
    """
    if wav.dtype.is_floating_point:
        return wav
    # Full scale value for each supported integer format.
    full_scale_by_dtype = {torch.int16: 2**15, torch.int32: 2**31}
    full_scale = full_scale_by_dtype.get(wav.dtype)
    if full_scale is None:
        raise ValueError(f"Unsupported wav dtype: {wav.dtype}")
    return wav.float() / full_scale
def i16_pcm(wav: torch.Tensor) -> torch.Tensor:
    """Convert audio to int 16 bits PCM format.

    ..Warning:: There exist many formula for doing this conversion. None are perfect
    due to the asymmetry of the int16 range. One either have possible clipping, DC offset,
    or inconsistencies with f32_pcm. If the given wav doesn't have enough headroom,
    it is possible that `i16_pcm(f32_pcm)) != Identity`.

    Args:
        wav (torch.tensor): Input wav tensor, either floating point with values
            in [-1, 1] or already int16.
    Returns:
        same wav in int16 PCM format
    """
    if not wav.dtype.is_floating_point:
        # Only int16 is accepted among the non floating point formats, and it is returned as is.
        assert wav.dtype == torch.int16
        return wav

    assert wav.abs().max() <= 1
    full_scale = 2 ** 15
    quantized = torch.round(wav * full_scale)
    if quantized.max() >= full_scale:  # clipping would occur
        # Fall back to the largest representable positive value as the scale.
        quantized = torch.round(wav * (full_scale - 1))
    return quantized.to(torch.int16)
def compress(wav: torch.Tensor, sr: int,
             target_format: tp.Literal["mp3", "ogg", "flac"] = "mp3",
             bitrate: str = "128k") -> tp.Tuple[torch.Tensor, int]:
    """Round-trip an audio waveform through a compressed format: mp3, ogg, flac.

    The waveform is encoded to an in-memory buffer with `torchaudio.save` and decoded
    back with `torchaudio.load`. If that raises a `RuntimeError`, a warning is logged
    and the input is returned unchanged.

    Args:
        wav (torch.Tensor): Input wav tensor.
        sr (int): Sampling rate.
        target_format (str): Compression format (e.g., 'mp3').
        bitrate (str): Bitrate for compression; the first number found in the string is used.
    Returns:
        Tuple of compressed WAV tensor and sampling rate.
    Raises:
        AssertionError: If no (non-zero) number can be extracted from `bitrate`.
    """
    # Extract the bit rate from string (e.g., '128k')
    match = re.search(r"\d+(\.\d+)?", str(bitrate))
    parsed_bitrate = float(match.group()) if match else None
    # Report the offending input rather than the (always falsy) parsed value.
    assert parsed_bitrate, f"Invalid bitrate specified (got {bitrate!r})"
    try:
        # Create a virtual file instead of saving to disk
        buffer = io.BytesIO()
        # NOTE(review): the parsed bitrate is forwarded as `bits_per_sample`; confirm this is
        # the intended `torchaudio.save` argument for the torchaudio version in use.
        torchaudio.save(
            buffer, wav, sr, format=target_format, bits_per_sample=parsed_bitrate,
        )
        # Move to the beginning of the file
        buffer.seek(0)
        compressed_wav, compressed_sr = torchaudio.load(buffer)
        return compressed_wav, compressed_sr
    except RuntimeError:
        # Log the requested format (the original interpolated the builtin `format` function).
        logger.warning(
            "compression failed skipping compression: %s %s", target_format, parsed_bitrate
        )
        return wav, sr
def get_mp3(wav_tensor: torch.Tensor, sr: int, bitrate: str = "128k") -> torch.Tensor:
    """Convert a batch of audio files to MP3 format, maintaining the original shape.

    The batch is flattened into a single one-channel signal, passed through `compress`
    and then brought back to the flattened input length (trimmed or zero-padded at the end)
    before being reshaped, so that every sample stays aligned with its original position.

    Args:
        wav_tensor (torch.Tensor): Batch of audio files represented as a tensor.
            Shape should be (batch_size, channels, length).
        sr (int): Sampling rate of the audio.
        bitrate (str): Bitrate for MP3 conversion, default is '128k'.
    Returns:
        torch.Tensor: Batch of audio files converted to MP3 format, with the same
            shape and device as the input tensor.
    """
    device = wav_tensor.device
    batch_size, channels, original_length = wav_tensor.shape
    # Flatten tensor for conversion and move to CPU (`reshape` also accepts non-contiguous input).
    wav_tensor_flat = wav_tensor.reshape(1, -1).cpu()
    # Convert to MP3 format with specified bitrate
    wav_tensor_flat, _ = compress(wav_tensor_flat, sr, bitrate=bitrate)

    # Trim or pad on the flattened signal, *before* reshaping: the codec may change the
    # total length by an amount that is not a multiple of `batch_size * channels`.
    original_length_flat = batch_size * channels * original_length
    compressed_length_flat = wav_tensor_flat.shape[-1]
    if compressed_length_flat > original_length_flat:
        wav_tensor_flat = wav_tensor_flat[..., :original_length_flat]  # Trim excess frames
    elif compressed_length_flat < original_length_flat:
        # Pad with zeros, on the device the flattened tensor already lives on.
        wav_tensor_flat = torch.nn.functional.pad(
            wav_tensor_flat, (0, original_length_flat - compressed_length_flat)
        )

    # Reshape back to original batch format
    wav_tensor = wav_tensor_flat.reshape(batch_size, channels, original_length)

    # Move tensor back to the original device
    return wav_tensor.to(device)
def get_aac(
    wav_tensor: torch.Tensor,
    sr: int,
    bitrate: str = "128k",
    lowpass_freq: tp.Optional[int] = None,
) -> torch.Tensor:
    """Converts a batch of audio tensors to AAC format and then back to tensors.

    This function first saves the flattened input batch as a WAV file, then uses FFmpeg to
    convert it to AAC format. Finally, it loads the AAC file back into a tensor, which is
    trimmed or zero-padded to the flattened input length and reshaped to the input shape.

    Args:
        wav_tensor (torch.Tensor): A batch of audio files represented as a tensor.
            Shape should be (batch_size, channels, length).
        sr (int): Sampling rate of the audio.
        bitrate (str): Bitrate for AAC conversion, default is '128k'.
        lowpass_freq (Optional[int]): Frequency for a low-pass filter. If None, no filter is applied.
    Returns:
        torch.Tensor: Batch of audio files converted to AAC and back, with the same
            shape as the input tensor.
    Raises:
        RuntimeError: If running FFmpeg or loading its output fails.
    """
    import tempfile
    import subprocess

    device = wav_tensor.device
    batch_size, channels, original_length = wav_tensor.shape

    # Parse the bitrate value from the string
    match = re.search(r"\d+(\.\d+)?", bitrate)
    parsed_bitrate = (
        match.group() if match else "128"
    )  # Default to 128 if parsing fails

    # Flatten tensor for conversion and move to CPU (`reshape` also accepts non-contiguous input).
    wav_tensor_flat = wav_tensor.reshape(1, -1).cpu()

    with tempfile.NamedTemporaryFile(
        suffix=".wav"
    ) as f_in, tempfile.NamedTemporaryFile(suffix=".aac") as f_out:
        input_path, output_path = f_in.name, f_out.name

        # Save the tensor as a WAV file
        torchaudio.save(input_path, wav_tensor_flat, sr, backend="ffmpeg")

        # Prepare FFmpeg command for AAC conversion
        command = [
            "ffmpeg",
            "-y",
            "-i",
            input_path,
            "-ar",
            str(sr),
            "-b:a",
            f"{parsed_bitrate}k",
            "-c:a",
            "aac",
        ]
        if lowpass_freq is not None:
            command += ["-cutoff", str(lowpass_freq)]
        command.append(output_path)

        try:
            # Run FFmpeg and suppress output; `check=True` surfaces a failing exit status
            # here instead of as an obscure loading error afterwards.
            subprocess.run(command, stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL, check=True)

            # Load the AAC audio back into a tensor
            aac_tensor, _ = torchaudio.load(output_path, backend="ffmpeg")
        except Exception as exc:
            raise RuntimeError(
                f"Failed to run command {' '.join(command)} "
                "(Often this means ffmpeg is not installed or the encoder is not supported, "
                "make sure you installed an older version ffmpeg<5)"
            ) from exc

    original_length_flat = batch_size * channels * original_length
    compressed_length_flat = aac_tensor.shape[-1]

    # Trim excess frames
    if compressed_length_flat > original_length_flat:
        aac_tensor = aac_tensor[:, :original_length_flat]

    # Pad the shortened frames, on the device the loaded tensor already lives on
    elif compressed_length_flat < original_length_flat:
        aac_tensor = torch.nn.functional.pad(
            aac_tensor, (0, original_length_flat - compressed_length_flat)
        )

    # Reshape and adjust length to match original tensor
    wav_tensor = aac_tensor.reshape(batch_size, channels, -1)
    compressed_length = wav_tensor.shape[-1]

    assert compressed_length == original_length, (
        "AAC-compressed audio does not have the same frames as original one. "
        "One reason can be ffmpeg is not installed and used as proper backed "
        "for torchaudio, or the AAC encoder is not correct. Run "
        "`torchaudio.utils.ffmpeg_utils.get_audio_encoders()` and make sure we see entry for "
        "AAC in the output."
    )
    return wav_tensor.to(device)
================================================
FILE: audiocraft/data/info_audio_dataset.py
================================================
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
"""Base classes for the datasets that also provide non-audio metadata,
e.g. description, text transcription etc.
"""
from dataclasses import dataclass
import logging
import math
import re
import typing as tp
import torch
from .audio_dataset import AudioDataset, AudioMeta
from ..environment import AudioCraftEnvironment
from ..modules.conditioners import SegmentWithAttributes, ConditioningAttributes
logger = logging.getLogger(__name__)
def _clusterify_meta(meta: AudioMeta) -> AudioMeta:
    """Rewrite, in place, the paths of `meta` with the dataset mappers of the environment.

    Both `meta.path` and, when `meta.info_path` is set, `meta.info_path.zip_path`
    go through `AudioCraftEnvironment.apply_dataset_mappers`. The same object is returned.
    """
    remap = AudioCraftEnvironment.apply_dataset_mappers
    meta.path = remap(meta.path)
    info_path = meta.info_path
    if info_path is not None:
        info_path.zip_path = remap(info_path.zip_path)
    return meta
def clusterify_all_meta(meta: tp.List[AudioMeta]) -> tp.List[AudioMeta]:
    """Apply `_clusterify_meta` to every entry and return the results as a new list."""
    return list(map(_clusterify_meta, meta))
@dataclass
class AudioInfo(SegmentWithAttributes):
    """Segment info whose conditioning attributes are always empty.

    The InfoAudioDataset is expected to return metadata that inherits
    from SegmentWithAttributes class and can return conditioning attributes.
    This basically guarantees all datasets will be compatible with current
    solver that contain conditioners requiring this.
    """
    # Populated when using cached batch for training a LM.
    audio_tokens: tp.Optional[torch.Tensor] = None

    def to_condition_attributes(self) -> ConditioningAttributes:
        """Return a newly constructed `ConditioningAttributes` with nothing filled in."""
        empty_attributes = ConditioningAttributes()
        return empty_attributes
class InfoAudioDataset(AudioDataset):
    """AudioDataset whose returned metadata, if any, is an `AudioInfo` (a SegmentWithAttributes).

    The provided metadata goes through `clusterify_all_meta` before being handed to the parent class.
    See `audiocraft.data.audio_dataset.AudioDataset` for initialization arguments.
    """
    def __init__(self, meta: tp.List[AudioMeta], **kwargs):
        clustered_meta = clusterify_all_meta(meta)
        super().__init__(clustered_meta, **kwargs)

    def __getitem__(self, index: int) -> tp.Union[torch.Tensor, tp.Tuple[torch.Tensor, SegmentWithAttributes]]:
        item = super().__getitem__(index)
        if self.return_info:
            # The parent returned `(wav, segment info)`: rebuild the info as an `AudioInfo`.
            wav, segment_info = item
            return wav, AudioInfo(**segment_info.to_dict())
        assert isinstance(item, torch.Tensor)
        return item
def get_keyword_or_keyword_list(value: tp.Optional[str]) -> tp.Union[tp.Optional[str], tp.Optional[tp.List[str]]]:
    """Preprocess a value that is either a single keyword or a list of keywords.

    Lists are handled by `get_keyword_list`, anything else by `get_keyword`.
    """
    preprocess = get_keyword_list if isinstance(value, list) else get_keyword
    return preprocess(value)
def get_string(value: tp.Optional[str]) -> tp.Optional[str]:
    """Preprocess a single string: strip surrounding whitespace, keeping the case.

    Returns None for anything that is not a string, for the empty string
    and for the literal string 'None'.
    """
    if not isinstance(value, str):
        # Covers None as well as any other non-string value.
        return None
    if value in ('', 'None'):
        return None
    return value.strip()
def get_keyword(value: tp.Optional[str]) -> tp.Optional[str]:
    """Preprocess a single keyword: strip surrounding whitespace and lowercase it.

    Returns None for anything that is not a string, for the empty string
    and for the literal string 'None'.
    """
    if not isinstance(value, str):
        # Covers None as well as any other non-string value.
        return None
    if value in ('', 'None'):
        return None
    return value.strip().lower()
def get_keyword_list(values: tp.Union[str, tp.List[str]]) -> tp.Optional[tp.List[str]]:
    """Preprocess a list of keywords.

    A string is split on commas and whitespace, a NaN float is treated as an empty list,
    and any other non-list value is wrapped as a single stringified element.
    Each element then goes through `get_keyword`; elements mapped to None are dropped.

    Returns:
        list of str, optional: The cleaned keywords, or None when none is left.
    """
    if isinstance(values, str):
        candidates = [token.strip() for token in re.split(r'[,\s]', values)]
    elif isinstance(values, float) and math.isnan(values):
        candidates = []
    else:
        candidates = values
    if not isinstance(candidates, list):
        logger.debug(f"Unexpected keyword list {candidates}")
        candidates = [str(candidates)]

    cleaned = []
    for candidate in candidates:
        keyword = get_keyword(candidate)
        if keyword is not None:
            cleaned.append(keyword)
    return cleaned or None
================================================
FILE: audiocraft/data/jasco_dataset.py
================================================
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import bisect
import pickle
import math
import os
import torch
import typing as tp
from pathlib import Path
from dataclasses import dataclass, fields
from ..utils.utils import construct_frame_chords
from .music_dataset import MusicDataset, MusicInfo
from .audio_dataset import load_audio_meta
from ..modules.conditioners import (ConditioningAttributes, SymbolicCondition)
import librosa
import numpy as np
@dataclass
class JascoInfo(MusicInfo):
    """
    A data class extending MusicInfo for JASCO. The following attributes are added:

    Attributes:
        chords (Optional[SymbolicCondition]): Symbolic condition stored under the `chords` key.
        melody (Optional[SymbolicCondition]): Symbolic condition stored under the `melody` key.
    """
    chords: tp.Optional[SymbolicCondition] = None
    melody: tp.Optional[SymbolicCondition] = None

    def to_condition_attributes(self) -> ConditioningAttributes:
        """Dispatch every dataclass field into the matching slot of a `ConditioningAttributes`.

        `self_wav` goes to `wav`, `chords` and `melody` go to `symbolic`, the items of
        `joint_embed` go to `joint_embed`, and every other field goes to `text`
        (lists being joined with spaces first).
        """
        attributes = ConditioningAttributes()
        symbolic_keys = {'chords', 'melody'}
        for field_spec in fields(self):
            name = field_spec.name
            content = getattr(self, name)
            if name == 'self_wav':
                attributes.wav[name] = content
                continue
            if name in symbolic_keys:
                attributes.symbolic[name] = content
                continue
            if name == 'joint_embed':
                for embed_attribute, embed_cond in content.items():
                    attributes.joint_embed[embed_attribute] = embed_cond
                continue
            # Everything else is a text condition.
            attributes.text[name] = ' '.join(content) if isinstance(content, list) else content
        return attributes
class MelodyData:
    """Callable that returns the melody salience matrix matching a music segment.

    Salience matrices are read from `<stem>_multif0_salience.npz` files located in
    `chroma_root`. When no root is given (or a track is unknown) an all-zero
    matrix is returned instead.
    """
    SALIENCE_MODEL_EXPECTED_SAMPLE_RATE = 22050
    SALIENCE_MODEL_EXPECTED_HOP_SIZE = 256

    def __init__(self,
                 latent_fr: int,
                 segment_duration: float,
                 melody_fr: int = 86,
                 melody_salience_dim: int = 53,
                 chroma_root: tp.Optional[str] = None,
                 override_cache: bool = False,
                 do_argmax: bool = True):
        """Module to load salience matrix for a given info.

        Args:
            latent_fr (int): latent frame rate to match (interpolates model frame rate accordingly).
            segment_duration (float): expected segment duration.
            melody_fr (int, optional): extracted salience frame rate. Defaults to 86.
            melody_salience_dim (int, optional): salience dim. Defaults to 53.
            chroma_root (str, optional): path to root containing salience cache. Defaults to None.
            override_cache (bool, optional): rewrite cache. Defaults to False.
            do_argmax (bool, optional): argmax the melody matrix. Defaults to True.
        """
        self.segment_duration = segment_duration
        self.melody_fr = melody_fr
        self.latent_fr = latent_fr
        self.melody_salience_dim = melody_salience_dim
        self.do_argmax = do_argmax
        self.tgt_chunk_len = int(latent_fr * segment_duration)

        # Always define the lookup tables: without a salience root the module is a
        # "null op" and __call__ must still be able to run (it used to raise an
        # AttributeError on `self.trk2idx` in that case).
        self.tracks: tp.List[str] = []
        self.saliency_files: tp.List[str] = []
        self.trk2idx: tp.Dict[str, int] = {}
        self.null_op = chroma_root is None

        if not self.null_op:
            cache_path = f"{chroma_root}/cache.pkl"
            if override_cache or not os.path.exists(cache_path):
                # every *.txt file found under the root lists one track per line
                for file in librosa.util.find_files(chroma_root, ext='txt'):
                    with open(file, 'r') as f:
                        for line in f.readlines():
                            self.tracks.append(line.strip())

                # go over tracks and add the corresponding saliency file to self.saliency_files
                for track in self.tracks:
                    salience_file = f"{chroma_root}/{track.split('/')[-1].split('.')[0]}_multif0_salience.npz"
                    assert os.path.exists(salience_file), f"File {salience_file} does not exist"
                    self.saliency_files.append(salience_file)

                self.trk2idx = {trk.split('/')[-1].split('.')[0]: i for i, trk in enumerate(self.tracks)}
                torch.save({'tracks': self.tracks,
                            'saliency_files': self.saliency_files,
                            'trk2idx': self.trk2idx}, cache_path)
            else:
                tmp = torch.load(cache_path)
                self.tracks = tmp['tracks']
                self.saliency_files = tmp['saliency_files']
                self.trk2idx = tmp['trk2idx']

        self.model_frame_rate = int(self.SALIENCE_MODEL_EXPECTED_SAMPLE_RATE / self.SALIENCE_MODEL_EXPECTED_HOP_SIZE)

    def load_saliency_from_saliency_dict(self,
                                         saliency_dict: tp.Dict[str, tp.Any],
                                         offset: float) -> torch.Tensor:
        """Cut the salience matrix of the segment and resample it to the latent frame rate.

        The segment starting at `offset` seconds is extracted, linearly interpolated
        along the time axis, then zero-padded or trimmed to `tgt_chunk_len` frames.

        Args:
            saliency_dict (dict): mapping whose 'salience' entry is indexed as [C, T].
            offset (float): start of the segment, in seconds.
        Returns:
            torch.Tensor: matrix of shape [C, tgt_chunk_len]; one-hot per frame
                (all-zero frames stay zero) when `do_argmax` is set.
        """
        # get saliency map for the segment
        start = int(offset * self.model_frame_rate)
        stop = int((offset + self.segment_duration) * self.model_frame_rate)
        salience = torch.Tensor(saliency_dict['salience'][:, start:stop]).float()  # C, T

        if salience.shape[-1] <= int(self.model_frame_rate) / self.latent_fr:  # empty chroma
            salience = torch.zeros((salience.shape[0], self.tgt_chunk_len))
        else:
            salience = torch.nn.functional.interpolate(salience.unsqueeze(0),
                                                       scale_factor=self.latent_fr/int(self.model_frame_rate),
                                                       mode='linear').squeeze(0)
            if salience.shape[-1] < self.tgt_chunk_len:
                salience = torch.nn.functional.pad(salience,
                                                   (0, self.tgt_chunk_len - salience.shape[-1]),
                                                   mode='constant',
                                                   value=0)
            elif salience.shape[-1] > self.tgt_chunk_len:
                salience = salience[..., :self.tgt_chunk_len]

        if self.do_argmax:
            binary_mask = torch.zeros_like(salience)
            binary_mask[torch.argmax(salience, dim=0), torch.arange(salience.shape[-1])] = 1
            # frames without any salience must stay empty (argmax would pick row 0)
            binary_mask *= (salience != 0).float()
            salience = binary_mask
        return salience

    def get_null_salience(self) -> torch.Tensor:
        """Return an all-zero matrix of shape [melody_salience_dim, tgt_chunk_len]."""
        return torch.zeros((self.melody_salience_dim, self.tgt_chunk_len))

    def __call__(self, x: 'MusicInfo') -> torch.Tensor:
        """Reads salience matrix from memory, shifted by seek time

        Args:
            x (MusicInfo): Music info of a single sample
        Returns:
            torch.Tensor: salience matrix matching the target info; all zeros when
                no salience root was configured, the sample has no path, or the
                track is not part of the cache.
        """
        path = x.meta.path
        if self.null_op or path is None or path == "":
            return self.get_null_salience()
        fname: str = path.split("/")[-1].split(".")[0]
        if fname not in self.trk2idx:
            return self.get_null_salience()
        idx = self.trk2idx[fname]
        # NOTE(security): allow_pickle=True can execute arbitrary code while loading;
        # only point `chroma_root` at trusted salience files.
        saliency_dict = np.load(self.saliency_files[idx], allow_pickle=True)
        return self.load_saliency_from_saliency_dict(saliency_dict, x.seek_time)
class JascoDataset(MusicDataset):
    """JASCO dataset is a MusicDataset with jasco-related symbolic data (chords, melody).

    Args:
        chords_card (int): The cardinality of the chords, default is 194.
        compression_model_framerate (float): The framerate for the compression model, default is 50.

    See `audiocraft.data.music_dataset.MusicDataset` for full initialization arguments.
    """
    @classmethod
    def from_meta(cls, root: tp.Union[str, Path], **kwargs):
        """Instantiate the dataset from a directory containing a manifest, or from the manifest itself.

        Args:
            root (str or Path): Either a folder containing `data.jsonl` / `data.jsonl.gz`,
                or the path to a `.jsonl` / `.jsonl.gz` manifest. The folder holding the
                manifest is forwarded to the constructor as `root`.
            kwargs: Additional keyword arguments forwarded to the constructor.
        """
        root = Path(root)
        # a directory is given
        if root.is_dir():
            if (root / 'data.jsonl').exists():
                meta_json = root / 'data.jsonl'
            elif (root / 'data.jsonl.gz').exists():
                meta_json = root / 'data.jsonl.gz'
            else:
                raise ValueError("Don't know where to read metadata from in the dir. "
                                 "Expecting either a data.jsonl or data.jsonl.gz file but none found.")
        # jsonl file was specified
        else:
            # gzipped manifests are accepted here too, as they already are in the directory case
            is_jsonl = root.suffix == '.jsonl' or root.name.endswith('.jsonl.gz')
            assert root.exists() and is_jsonl, \
                "Either specified path not exist or it is not a jsonl format"
            meta_json = root
            root = root.parent
        meta = load_audio_meta(meta_json)
        kwargs['root'] = root
        return cls(meta, **kwargs)

    def __init__(self, *args,
                 chords_card: int = 194,
                 compression_model_framerate: float = 50.,
                 melody_kwargs: tp.Optional[tp.Dict[str, tp.Any]] = {},
                 **kwargs):
        """Dataset class for text-to-music generation with temporal controls as in
        (JASCO)[https://arxiv.org/pdf/2406.10970]

        Args:
            chords_card (int, optional): Number of chord embeddings. Defaults to 194.
            compression_model_framerate (float, optional): Expected frame rate of the resulted latent. Defaults to 50.
            melody_kwargs (tp.Optional[tp.Dict[str, tp.Any]], optional): See MelodyData class. Defaults to {}.
            root (str or Path, keyword only, required): folder searched for
                `chord_to_index_mapping.pkl` and `chords_per_track.pkl`.
        """
        root = Path(kwargs.pop('root'))
        super().__init__(*args, **kwargs)

        # both files are optional: a missing file leaves the attribute set to None
        self.mapping_dict = self._load_pickle_if_exists(root / 'chord_to_index_mapping.pkl')
        self.chords_per_track = self._load_pickle_if_exists(root / 'chords_per_track.pkl')

        self.compression_model_framerate = compression_model_framerate
        self.null_chord_idx = chords_card

        # tolerate an explicit None, which the Optional annotation allows
        self.melody_module = MelodyData(**(melody_kwargs or {}))  # type: ignore

    @staticmethod
    def _load_pickle_if_exists(path: Path) -> tp.Any:
        """Return the unpickled content of `path`, or None when the file does not exist."""
        if not os.path.exists(path):
            return None
        # NOTE(security): unpickling can execute arbitrary code; only use trusted dataset roots.
        with open(path, "rb") as f:
            return pickle.load(f)

    def _get_relevant_sublist(self, chords, timestamp):
        """
        Returns the sublist of chords within the specified timestamp and segment length.

        Args:
            chords (list): A sorted list of tuples containing (time changed, chord).
            timestamp (float): The timestamp at which to start the sublist.

        Returns:
            tuple: (relevant_chords, prev_chord) where relevant_chords is the list of
                (time changed, chord) falling in [timestamp, timestamp + segment_duration)
                and prev_chord is the entry right before it ((0.0, "N") if there is none).
        """
        end_time = timestamp + self.segment_duration

        # Use binary search to find the starting index of the relevant sublist
        start_index = bisect.bisect_left(chords, (timestamp,))

        if start_index != 0:
            prev_chord = chords[start_index - 1]
        else:
            prev_chord = (0.0, "N")

        relevant_chords = []

        for time_changed, chord in chords[start_index:]:
            if time_changed >= end_time:
                break
            relevant_chords.append((time_changed, chord))

        return relevant_chords, prev_chord

    def _get_chords(self, music_info: MusicInfo, effective_segment_dur: float) -> torch.Tensor:
        """Build the per-frame chord indices of the segment described by `music_info`.

        Without `chords_per_track`, a sequence filled with the null chord index is returned.
        """
        if self.chords_per_track is None:
            # use null chord when there's no chords in dataset
            seq_len = math.ceil(self.compression_model_framerate * effective_segment_dur)
            return torch.ones(seq_len, dtype=int) * self.null_chord_idx  # type: ignore

        fr = self.compression_model_framerate

        idx = music_info.meta.path.split("/")[-1].split(".")[0]
        chords = self.chords_per_track[idx]

        min_timestamp = music_info.seek_time

        # stored entries are swapped so that the first item is the one compared with timestamps
        chords = [(item[1], item[0]) for item in chords]
        chords, prev_chord = self._get_relevant_sublist(
            chords, min_timestamp
        )

        iter_min_timestamp = int(min_timestamp * fr) + 1

        frame_chords = construct_frame_chords(
            iter_min_timestamp, chords, self.mapping_dict, prev_chord[1],  # type: ignore
            fr, self.segment_duration  # type: ignore
        )

        return torch.tensor(frame_chords)

    def __getitem__(self, index):
        """Return `(wav, JascoInfo)` where the info carries chords and melody conditions."""
        wav, music_info = super().__getitem__(index)
        assert not wav.isinfinite().any(), f"inf detected in wav file: {music_info}"
        wav = wav.float()

        # downcast music info to jasco info: the values must be passed as keyword
        # arguments, a single positional dict would end up in the first field only.
        jasco_info = JascoInfo(**{f.name: getattr(music_info, f.name)
                                  for f in fields(music_info) if f.init})

        # get chords
        effective_segment_dur = (wav.shape[-1] / self.sample_rate) if \
            self.segment_duration is None else self.segment_duration
        frame_chords = self._get_chords(music_info, effective_segment_dur)
        jasco_info.chords = SymbolicCondition(frame_chords=frame_chords)

        # get melody
        jasco_info.melody = SymbolicCondition(melody=self.melody_module(music_info))
        return wav, jasco_info
================================================
FILE: audiocraft/data/music_dataset.py
================================================
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
"""Dataset of music tracks with rich metadata.
"""
from dataclasses import dataclass, field, fields, replace
import gzip
import json
import logging
from pathlib import Path
import random
import typing as tp
import torch
from .info_audio_dataset import (
InfoAudioDataset,
AudioInfo,
get_keyword_list,
get_keyword,
get_string
)
from ..modules.conditioners import (
ConditioningAttributes,
JointEmbedCondition,
WavCondition,
)
from ..utils.utils import warn_once
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)
@dataclass
class MusicInfo(AudioInfo):
    """Segment info carrying music-specific metadata and conditioning inputs.
    """
    # music-specific metadata
    title: tp.Optional[str] = None
    artist: tp.Optional[str] = None  # anonymized artist id, used to ensure no overlap between splits
    key: tp.Optional[str] = None
    bpm: tp.Optional[float] = None
    genre: tp.Optional[str] = None
    moods: tp.Optional[list] = None
    keywords: tp.Optional[list] = None
    description: tp.Optional[str] = None
    name: tp.Optional[str] = None
    instrument: tp.Optional[str] = None
    # original wav accompanying the metadata
    self_wav: tp.Optional[WavCondition] = None
    # dict mapping attributes names to tuple of wav, text and metadata
    joint_embed: tp.Dict[str, JointEmbedCondition] = field(default_factory=dict)

    @property
    def has_music_meta(self) -> bool:
        """True when a `name` is set."""
        return self.name is not None

    def to_condition_attributes(self) -> ConditioningAttributes:
        """Dispatch every dataclass field to the matching ConditioningAttributes bucket.

        `self_wav` goes to `wav`, the entries of `joint_embed` are copied to
        `joint_embed`, all other fields go to `text` (lists are space-joined).
        """
        attributes = ConditioningAttributes()
        for entry in fields(self):
            name = entry.name
            content = getattr(self, name)
            if name == 'self_wav':
                attributes.wav[name] = content
                continue
            if name == 'joint_embed':
                for embed_name, embed_value in content.items():
                    attributes.joint_embed[embed_name] = embed_value
                continue
            if isinstance(content, list):
                content = ' '.join(content)
            attributes.text[name] = content
        return attributes

    @staticmethod
    def attribute_getter(attribute):
        """Return the preprocessing function for `attribute`, or None if it needs none."""
        if attribute == 'bpm':
            return get_bpm
        if attribute == 'key':
            return get_musical_key
        if attribute in ['moods', 'keywords']:
            return get_keyword_list
        if attribute in ['genre', 'name', 'instrument']:
            return get_keyword
        if attribute in ['title', 'artist', 'description']:
            return get_string
        return None

    @classmethod
    def from_dict(cls, dictionary: dict, fields_required: bool = False):
        """Build an instance from `dictionary`, preprocessing every known attribute.

        Args:
            dictionary (dict): raw values, keyed by field name.
            fields_required (bool): when True, a missing field raises a KeyError
                (except for `keywords` and the attributes populated later).
        """
        # attributes that are populated after construction and never read from the dictionary
        deferred = ['self_wav', 'joint_embed']
        # attributes allowed to be missing even when fields are required
        may_be_missing = ['keywords']

        collected: tp.Dict[str, tp.Any] = {}
        for entry in fields(cls):
            name = entry.name
            if name in deferred:
                continue
            if name not in dictionary:
                if fields_required and name not in may_be_missing:
                    raise KeyError(f"Unexpected missing key: {name}")
                continue
            converter: tp.Optional[tp.Callable] = cls.attribute_getter(name)
            raw = dictionary[name]
            collected[name] = converter(raw) if converter else raw
        return cls(**collected)
def augment_music_info_description(music_info: MusicInfo, merge_text_p: float = 0.,
                                   drop_desc_p: float = 0., drop_other_p: float = 0.) -> MusicInfo:
    """Augment MusicInfo description with additional metadata fields and potential dropout.
    Additional textual attributes are added given probability 'merge_text_p' and
    the original textual description is dropped from the augmented description given probability drop_desc_p.

    Args:
        music_info (MusicInfo): The music metadata to augment.
        merge_text_p (float): Probability of merging additional metadata to the description.
            If provided value is 0, then no merging is performed.
        drop_desc_p (float): Probability of dropping the original description on text merge.
            if provided value is 0, then no drop out is performed.
        drop_other_p (float): Probability of dropping the other fields used for text augmentation.
    Returns:
        MusicInfo: A copy of the MusicInfo with augmented textual description.
    """
    def is_valid_field(field_name: str, field_value: tp.Any) -> bool:
        valid_field_name = field_name in ['key', 'bpm', 'genre', 'moods', 'instrument', 'keywords']
        valid_field_value = field_value is not None and isinstance(field_value, (int, float, str, list))
        # `drop_other_p` is the probability of DROPPING a field, so the field is kept
        # when the draw is not below it (the comparison used to be inverted, which
        # discarded every field with the default drop_other_p=0).
        keep_field = not (random.uniform(0, 1) < drop_other_p)
        return valid_field_name and valid_field_value and keep_field

    def process_value(v: tp.Any) -> str:
        if isinstance(v, (int, float, str)):
            return str(v)
        if isinstance(v, list):
            # items are stringified so that non-str entries do not break the join
            return ", ".join(str(item) for item in v)
        else:
            raise ValueError(f"Unknown type for text value! ({type(v), v})")

    description = music_info.description

    metadata_text = ""
    if random.uniform(0, 1) < merge_text_p:
        meta_pairs = [f'{_field.name}: {process_value(getattr(music_info, _field.name))}'
                      for _field in fields(music_info) if is_valid_field(_field.name, getattr(music_info, _field.name))]
        random.shuffle(meta_pairs)
        metadata_text = ". ".join(meta_pairs)
        description = description if not random.uniform(0, 1) < drop_desc_p else None
        logger.debug(f"Applying text augmentation on MMI info. description: {description}, metadata: {metadata_text}")

    if description is None:
        description = metadata_text if len(metadata_text) > 1 else None
    else:
        description = ". ".join([description.rstrip('.'), metadata_text])
    description = description.strip() if description else None

    # work on a copy so that the caller's info is left untouched
    music_info = replace(music_info)
    music_info.description = description
    return music_info
class Paraphraser:
    """Randomly replaces a description with one of its paraphrases.

    Args:
        paraphrase_source (str or Path): Path to a .json or .json.gz file holding a dict
            that maps an info path (audio path with a `.json` suffix) to a list of paraphrases.
        paraphrase_p (float): Probability of returning a paraphrase instead of the description.
    """
    def __init__(self, paraphrase_source: tp.Union[str, Path], paraphrase_p: float = 0.):
        self.paraphrase_p = paraphrase_p
        open_fn = gzip.open if str(paraphrase_source).lower().endswith('.gz') else open
        with open_fn(paraphrase_source, 'rb') as f:  # type: ignore
            self.paraphrase_source = json.loads(f.read())
        logger.info(f"loaded paraphrasing source from: {paraphrase_source}")

    def sample_paraphrase(self, audio_path: str, description: str):
        """Return a random paraphrase for `audio_path` with probability `paraphrase_p`.

        The original `description` is returned when no paraphrase is drawn or available.
        """
        if random.random() >= self.paraphrase_p:
            return description
        # JSON object keys are always strings: looking up a Path object would never match.
        info_path = str(Path(audio_path).with_suffix('.json'))
        if info_path not in self.paraphrase_source:
            warn_once(logger, f"{info_path} not in paraphrase source!")
            return description
        candidates = self.paraphrase_source[info_path]
        if not candidates:
            # nothing to pick from, random.choice would raise on an empty list
            return description
        new_desc = random.choice(candidates)
        logger.debug(f"{description} -> {new_desc}")
        return new_desc
class MusicDataset(InfoAudioDataset):
    """Music dataset is an AudioDataset with music-related metadata.

    Args:
        info_fields_required (bool): Whether to enforce having required fields.
        merge_text_p (float): Probability of merging additional metadata to the description.
        drop_desc_p (float): Probability of dropping the original description on text merge.
        drop_other_p (float): Probability of dropping the other fields used for text augmentation.
        joint_embed_attributes (list[str]): A list of attributes for which joint embedding metadata is returned.
        paraphrase_source (str, optional): Path to the .json or .json.gz file containing the
            paraphrases for the description. The json should be a dict whose keys are the
            original info path (e.g. track_path.json) and each value is a list of possible
            paraphrases.
        paraphrase_p (float): probability of taking a paraphrase.

    See `audiocraft.data.info_audio_dataset.InfoAudioDataset` for full initialization arguments.
    """
    def __init__(self, *args, info_fields_required: bool = True,
                 merge_text_p: float = 0., drop_desc_p: float = 0., drop_other_p: float = 0.,
                 joint_embed_attributes: tp.List[str] = [],
                 paraphrase_source: tp.Optional[str] = None, paraphrase_p: float = 0,
                 **kwargs):
        kwargs['return_info'] = True  # We require the info for each song of the dataset.
        super().__init__(*args, **kwargs)
        self.info_fields_required = info_fields_required
        self.merge_text_p = merge_text_p
        self.drop_desc_p = drop_desc_p
        self.drop_other_p = drop_other_p
        # copy so that the shared default list is never aliased by an instance
        self.joint_embed_attributes = list(joint_embed_attributes)
        self.paraphraser = None
        if paraphrase_source is not None:
            self.paraphraser = Paraphraser(paraphrase_source, paraphrase_p)

    def __getitem__(self, index):
        """Return `(wav, MusicInfo)` for the segment at `index`.

        Metadata is read from the `.json` file next to the audio file when it exists;
        otherwise the info only contains the segment data.
        """
        wav, info = super().__getitem__(index)
        info_data = info.to_dict()
        music_info_path = Path(info.meta.path).with_suffix('.json')

        if Path(music_info_path).exists():
            with open(music_info_path, 'r') as json_file:
                music_data = json.load(json_file)
                # segment data takes precedence over what is stored in the json file
                music_data.update(info_data)
                music_info = MusicInfo.from_dict(music_data, fields_required=self.info_fields_required)
            if self.paraphraser is not None:
                # the Paraphraser API is `sample_paraphrase` (there is no `sample` method)
                music_info.description = self.paraphraser.sample_paraphrase(
                    music_info.meta.path, music_info.description)
            if self.merge_text_p:
                music_info = augment_music_info_description(
                    music_info, self.merge_text_p, self.drop_desc_p, self.drop_other_p)
        else:
            music_info = MusicInfo.from_dict(info_data, fields_required=False)

        music_info.self_wav = WavCondition(
            wav=wav[None], length=torch.tensor([info.n_frames]),
            sample_rate=[info.sample_rate], path=[info.meta.path], seek_time=[info.seek_time])

        for att in self.joint_embed_attributes:
            att_value = getattr(music_info, att)
            joint_embed_cond = JointEmbedCondition(
                wav[None], [att_value], torch.tensor([info.n_frames]),
                sample_rate=[info.sample_rate], path=[info.meta.path], seek_time=[info.seek_time])
            music_info.joint_embed[att] = joint_embed_cond

        return wav, music_info
def get_musical_key(value: tp.Optional[str]) -> tp.Optional[str]:
    """Preprocess key keywords, discarding them if there are multiple keys defined.

    Args:
        value (str, optional): raw key, e.g. " C Major ".
    Returns:
        str or None: the stripped, lower-cased key, or None when the value is not a
            string, is empty / blank, is the literal 'None', or lists several keys.
    """
    if not isinstance(value, str):
        return None
    # strip first so that blank strings are discarded instead of returned as ''
    key = value.strip()
    if len(key) == 0 or key == 'None':
        return None
    if ',' in key:
        # For now, we discard when multiple keys are defined separated with commas
        return None
    return key.lower()
def get_bpm(value: tp.Optional[str]) -> tp.Optional[float]:
    """Preprocess to a float.

    Args:
        value (str, optional): raw bpm; numbers are accepted as well.
    Returns:
        float or None: the bpm, or None when the value is missing, cannot be
            converted to a float, or is NaN.
    """
    import math

    if value is None:
        return None
    try:
        bpm = float(value)
    except (TypeError, ValueError):
        # TypeError covers non-scalar values (e.g. a list found in the metadata)
        return None
    # NaN stands for a missing value and must not leak into text conditions
    return None if math.isnan(bpm) else bpm
================================================
FILE: audiocraft/data/sound_dataset.py
================================================
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
"""Dataset of audio with a simple description.
"""
from dataclasses import dataclass, fields, replace
import json
from pathlib import Path
import random
import typing as tp
import numpy as np
import torch
from .info_audio_dataset import (
InfoAudioDataset,
get_keyword_or_keyword_list
)
from ..modules.conditioners import (
ConditioningAttributes,
SegmentWithAttributes,
WavCondition,
)
# float32 machine epsilon, added to denominators in this module to avoid dividing by zero.
EPS = torch.finfo(torch.float32).eps
# Bounds (in dB) of the RMS level randomly drawn by snr_mixer for the final mix.
TARGET_LEVEL_LOWER = -35
TARGET_LEVEL_UPPER = -15
@dataclass
class SoundInfo(SegmentWithAttributes):
    """Segment info carrying a sound description and the accompanying waveform condition.
    """
    description: tp.Optional[str] = None
    self_wav: tp.Optional[torch.Tensor] = None

    @property
    def has_sound_meta(self) -> bool:
        """True when a description is set."""
        return self.description is not None

    def to_condition_attributes(self) -> ConditioningAttributes:
        """Put `self_wav` in the `wav` conditions and every other field in `text`."""
        attributes = ConditioningAttributes()
        for entry in fields(self):
            name = entry.name
            content = getattr(self, name)
            bucket = attributes.wav if name == 'self_wav' else attributes.text
            bucket[name] = content
        return attributes

    @staticmethod
    def attribute_getter(attribute):
        """Return the preprocessing function for `attribute`, or None if it needs none."""
        return get_keyword_or_keyword_list if attribute == 'description' else None

    @classmethod
    def from_dict(cls, dictionary: dict, fields_required: bool = False):
        """Build an instance from `dictionary`, preprocessing every known attribute.

        Args:
            dictionary (dict): raw values, keyed by field name.
            fields_required (bool): when True, a missing field raises a KeyError
                (except for the attributes populated later).
        """
        # attributes that are populated after construction and never read from the dictionary
        deferred = ['self_wav']

        collected: tp.Dict[str, tp.Any] = {}
        for entry in fields(cls):
            name = entry.name
            if name in deferred:
                continue
            if name not in dictionary:
                if fields_required:
                    raise KeyError(f"Unexpected missing key: {name}")
                continue
            converter: tp.Optional[tp.Callable] = cls.attribute_getter(name)
            raw = dictionary[name]
            collected[name] = converter(raw) if converter else raw
        return cls(**collected)
class SoundDataset(InfoAudioDataset):
    """Audio dataset returning environmental sound metadata along with each segment.

    Args:
        info_fields_required (bool): Whether all the mandatory metadata fields should be in the loaded metadata.
        external_metadata_source (tp.Optional[str]): Folder containing JSON metadata for the corresponding dataset.
            The metadata files contained in this folder are expected to match the stem of the audio file with
            a json extension.
        aug_p (float): Probability of performing audio mixing augmentation on the batch.
        mix_p (float): Proportion of batch items that are mixed together when applying audio mixing augmentation.
        mix_snr_low (int): Lowerbound for SNR value sampled for mixing augmentation.
        mix_snr_high (int): Upperbound for SNR value sampled for mixing augmentation.
        mix_min_overlap (float): Minimum overlap between audio files when performing mixing augmentation.
        kwargs: Additional arguments for AudioDataset.

    See `audiocraft.data.info_audio_dataset.InfoAudioDataset` for full initialization arguments.
    """
    def __init__(
        self,
        *args,
        info_fields_required: bool = True,
        external_metadata_source: tp.Optional[str] = None,
        aug_p: float = 0.,
        mix_p: float = 0.,
        mix_snr_low: int = -5,
        mix_snr_high: int = 5,
        mix_min_overlap: float = 0.5,
        **kwargs
    ):
        # The info of each item is always needed, whatever the caller asked for.
        kwargs['return_info'] = True
        super().__init__(*args, **kwargs)
        self.info_fields_required = info_fields_required
        self.external_metadata_source = external_metadata_source
        self.aug_p = aug_p
        self.mix_p = mix_p
        mixing_enabled = self.aug_p > 0
        if mixing_enabled:
            assert self.mix_p > 0, "Expecting some mixing proportion mix_p if aug_p > 0"
            assert self.channels == 1, "SoundDataset with audio mixing considers only monophonic audio"
        self.mix_snr_low = mix_snr_low
        self.mix_snr_high = mix_snr_high
        self.mix_min_overlap = mix_min_overlap

    def _get_info_path(self, path: tp.Union[str, Path]) -> Path:
        """Locate the JSON metadata file (description, etc.) of an audio file.

        The JSON sitting next to the audio file is used when it exists; otherwise a file
        with the same name is looked up in the external metadata folder, if one is set.

        Raises:
            Exception: when no metadata file can be found.
        """
        sibling_json = Path(path).with_suffix('.json')
        if Path(sibling_json).exists():
            return sibling_json
        if self.external_metadata_source:
            external_json = Path(self.external_metadata_source) / sibling_json.name
            if external_json.exists():
                return external_json
        raise Exception(f"Unable to find a metadata JSON for path: {path}")

    def __getitem__(self, index):
        """Return `(wav, SoundInfo)` for the segment at `index`."""
        wav, info = super().__getitem__(index)
        segment_data = info.to_dict()
        metadata_path = self._get_info_path(info.meta.path)
        if not Path(metadata_path).exists():
            sound_info = SoundInfo.from_dict(segment_data, fields_required=False)
        else:
            with open(metadata_path, 'r') as handle:
                sound_data = json.load(handle)
                # segment data takes precedence over what is stored in the json file
                sound_data.update(segment_data)
                sound_info = SoundInfo.from_dict(sound_data, fields_required=self.info_fields_required)
            # if there are multiple descriptions, sample one randomly
            if isinstance(sound_info.description, list):
                sound_info.description = random.choice(sound_info.description)

        sound_info.self_wav = WavCondition(
            wav=wav[None], length=torch.tensor([info.n_frames]),
            sample_rate=[sound_info.sample_rate], path=[info.meta.path], seek_time=[info.seek_time])

        return wav, sound_info

    def collater(self, samples):
        """Collate samples, then apply the audio mixing augmentation when `aug_p > 0`."""
        # SoundDataset always returns infos
        wav, sound_info = super().collater(samples)
        if self.aug_p <= 0:
            return wav, sound_info
        # when training, audio mixing is performed in the collate function
        wav, sound_info = mix_samples(wav, sound_info, self.aug_p, self.mix_p,
                                      snr_low=self.mix_snr_low, snr_high=self.mix_snr_high,
                                      min_overlap=self.mix_min_overlap)
        return wav, sound_info
def rms_f(x: torch.Tensor) -> torch.Tensor:
    """Root mean square of `x` computed along dimension 1."""
    mean_square = x.pow(2).mean(dim=1)
    return mean_square.pow(0.5)
def normalize(audio: torch.Tensor, target_level: int = -25) -> torch.Tensor:
    """Scale each row of `audio` so that its RMS matches `target_level` (in dB)."""
    gain = 10 ** (target_level / 20) / (rms_f(audio) + EPS)
    return audio * gain.unsqueeze(1)
def is_clipped(audio: torch.Tensor, clipping_threshold: float = 0.99) -> torch.Tensor:
    """Tell, for each row, whether any absolute value exceeds `clipping_threshold`."""
    over_threshold = abs(audio) > clipping_threshold
    return over_threshold.any(1)
def mix_pair(src: torch.Tensor, dst: torch.Tensor, min_overlap: float) -> torch.Tensor:
    """Add `dst` onto `src` (in place), starting at a random position of `src`.

    The start is drawn in [0, int(T_src * (1 - min_overlap))]; whatever part of `dst`
    does not fit before the end of `src` is discarded. Returns `src`.
    """
    src_len = src.shape[1]
    offset = random.randint(0, int(src_len * (1 - min_overlap)))
    room = src_len - offset
    dst_len = dst.shape[1]
    if dst_len > room:
        # dst is longer than what is left of src: only its head is mixed in
        src[:, offset:] = src[:, offset:] + dst[:, :room]
        return src
    stop = offset + dst_len
    src[:, offset:stop] = src[:, offset:stop] + dst
    return src
def snr_mixer(clean: torch.Tensor, noise: torch.Tensor, snr: int, min_overlap: float,
              target_level: int = -25, clipping_threshold: float = 0.99) -> torch.Tensor:
    """Function to mix clean speech and noise at various SNR levels.

    Args:
        clean (torch.Tensor): Clean audio source to mix, of shape [B, T].
        noise (torch.Tensor): Noise audio source to mix, of shape [B, T].
        snr (int): SNR level when mixing.
        min_overlap (float): Minimum overlap between the two mixed sources.
        target_level (int): Gain level in dB.
        clipping_threshold (float): Threshold for clipping the audio.
    Returns:
        torch.Tensor: The mixed audio, of shape [B, T].
    """
    # bring the noise to the length of the clean signal
    if clean.shape[1] > noise.shape[1]:
        noise = torch.nn.functional.pad(noise, (0, clean.shape[1] - noise.shape[1]))
    else:
        noise = noise[:, :clean.shape[1]]

    # peak-normalize, then normalize to the target level (-25 dB FS by default);
    # the peak is the largest magnitude, hence abs() is taken before max()
    clean = clean / (clean.abs().max(1)[0].unsqueeze(1) + EPS)
    clean = normalize(clean, target_level)
    rmsclean = rms_f(clean)

    noise = noise / (noise.abs().max(1)[0].unsqueeze(1) + EPS)
    noise = normalize(noise, target_level)
    rmsnoise = rms_f(noise)

    # set the noise level for a given SNR
    noisescalar = (rmsclean / (10 ** (snr / 20)) / (rmsnoise + EPS)).unsqueeze(1)
    noisenewlevel = noise * noisescalar

    # mix noise and clean speech
    noisyspeech = mix_pair(clean, noisenewlevel, min_overlap)

    # randomly select RMS value between -15 dBFS and -35 dBFS and normalize noisyspeech with that value
    # there is a small chance of clipping, which is handled right below.
    noisy_rms_level = np.random.randint(TARGET_LEVEL_LOWER, TARGET_LEVEL_UPPER)
    rmsnoisy = rms_f(noisyspeech)
    scalarnoisy = (10 ** (noisy_rms_level / 20) / (rmsnoisy + EPS)).unsqueeze(1)
    noisyspeech = noisyspeech * scalarnoisy

    # final check to see if there are any amplitudes exceeding +/- clipping_threshold.
    # If so, rescale the offending rows by their largest magnitude: using the signed
    # maximum would amplify a row that only clips on the negative side.
    clipped = is_clipped(noisyspeech, clipping_threshold)
    if clipped.any():
        noisyspeech_maxamplevel = noisyspeech[clipped].abs().max(1)[0].unsqueeze(1) / (clipping_threshold - EPS)
        noisyspeech[clipped] = noisyspeech[clipped] / noisyspeech_maxamplevel

    return noisyspeech
def snr_mix(src: torch.Tensor, dst: torch.Tensor, snr_low: int, snr_high: int, min_overlap: float):
    """Mix `dst` into `src` with an SNR drawn by `np.random.randint(snr_low, snr_high)`.

    When both bounds are equal, that value is used directly without sampling.
    """
    snr = snr_low if snr_low == snr_high else np.random.randint(snr_low, snr_high)
    return snr_mixer(src, dst, snr, min_overlap)
def mix_text(src_text: str, dst_text: str):
    """Concatenate two descriptions with a space; identical texts are returned once."""
    if src_text != dst_text:
        return src_text + " " + dst_text
    return src_text
def mix_samples(wavs: torch.Tensor, infos: tp.List[SoundInfo], aug_p: float, mix_p: float,
                snr_low: int, snr_high: int, min_overlap: float):
    """Mix samples within a batch, summing the waveforms and concatenating the text infos.

    With probability `aug_p`, `int(mix_p * B)` random (source, target) pairs are mixed;
    otherwise the same number of items is randomly picked, unmixed, so that the
    returned batch size is the same in both cases.

    Args:
        wavs (torch.Tensor): Audio tensors of shape [B, C, T].
        infos (list[SoundInfo]): List of SoundInfo items corresponding to the audio.
        aug_p (float): Augmentation probability.
        mix_p (float): Proportion of items in the batch to mix (and merge) together.
        snr_low (int): Lowerbound for sampling SNR.
        snr_high (int): Upperbound for sampling SNR.
        min_overlap (float): Minimum overlap between mixed samples.
    Returns:
        tuple[torch.Tensor, list[SoundInfo]]: A tuple containing the mixed wavs
            and mixed SoundInfo for the given batch.
    """
    # no mixing to perform within the batch
    if mix_p == 0:
        return wavs, infos

    apply_mixing = random.uniform(0, 1) < aug_p
    if not apply_mixing:
        # randomly pick samples in the batch to match
        # the batch size when performing audio mixing
        B, C, T = wavs.shape
        n_kept = int(mix_p * B)
        kept_idx = torch.randperm(B)[:n_kept]
        wavs = wavs[kept_idx]
        infos = [infos[i] for i in kept_idx]
        assert wavs.shape[0] == len(infos), "Mismatch between number of wavs and infos in the batch"
        return wavs, infos  # [B, C, T]

    # perform all augmentations on waveforms as [B, T]
    # randomly picking pairs of audio to mix
    assert wavs.size(1) == 1, f"Mix samples requires monophonic audio but C={wavs.size(1)}"
    wavs = wavs.mean(dim=1, keepdim=False)
    B, T = wavs.shape
    n_mixed = int(mix_p * B)
    source_idx = torch.randperm(B)[:n_mixed]
    target_idx = torch.randperm(B)[:n_mixed]
    mixed_wavs = snr_mix(
        wavs[source_idx],
        wavs[target_idx],
        snr_low,
        snr_high,
        min_overlap,
    )

    # mixing textual descriptions in metadata
    all_descriptions = [info.description for info in infos]
    mixed_infos = []
    for src, tgt in zip(source_idx, target_idx):
        merged = replace(infos[src])
        merged.description = mix_text(all_descriptions[src], all_descriptions[tgt])
        mixed_infos.append(merged)

    # back to [B, C, T]
    mixed_wavs = mixed_wavs.unsqueeze(1)
    assert mixed_wavs.shape[0] > 0, "Samples mixing returned empty batch."
    assert mixed_wavs.dim() == 3, f"Returned wav should be [B, C, T] but dim = {mixed_wavs.dim()}"
    assert mixed_wavs.shape[0] == len(mixed_infos), "Mismatch between number of wavs and infos in the batch"
    return mixed_wavs, mixed_infos  # [B, C, T]
================================================
FILE: audiocraft/data/zip.py
================================================
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
"""Utility for reading some info from inside a zip file.
"""
import typing
import zipfile
from dataclasses import dataclass
from functools import lru_cache
from typing_extensions import Literal
# Default maximum number of opened zip files kept by the LRU cache of this module.
DEFAULT_SIZE = 32
# Modes accepted when opening a zip archive.
MODE = Literal['r', 'w', 'x', 'a']
@dataclass(order=True)
class PathInZip:
    """Location of a file stored inside a zip archive.

    Args:
        path (str): The convention is <path_to_zip>:<relative_path_inside_zip>.
            Let's assume there is a zip file /some/location/foo.zip
            and inside of it is a json file located at /data/file1.json,
            Then we expect path = "/some/location/foo.zip:/data/file1.json".
            The path must contain the separator exactly once.
    """

    INFO_PATH_SEP = ':'
    zip_path: str
    file_path: str

    def __init__(self, path: str) -> None:
        parts = path.split(self.INFO_PATH_SEP)
        assert len(parts) == 2
        self.zip_path = parts[0]
        self.file_path = parts[1]

    @classmethod
    def from_paths(cls, zip_path: str, file_path: str):
        """Build a PathInZip from the archive path and the path inside the archive."""
        combined = zip_path + cls.INFO_PATH_SEP + file_path
        return cls(combined)

    def __str__(self) -> str:
        """Return the `<zip_path>:<file_path>` representation."""
        separator = self.INFO_PATH_SEP
        return self.zip_path + separator + self.file_path
def _open_zip(path: str, mode: MODE = 'r'):
    """Opens the zip archive located at `path` with the given mode."""
    archive = zipfile.ZipFile(path, mode)
    return archive


# Memoized variant of `_open_zip`, bounded to DEFAULT_SIZE cached entries.
_cached_open_zip = lru_cache(maxsize=DEFAULT_SIZE)(_open_zip)
def set_zip_cache_size(max_size: int):
    """Replaces the zip opening cache by a new LRU cache of the requested size.

    Args:
        max_size (int): Maximal number of entries of the new LRU cache.
    """
    global _cached_open_zip
    # A brand new cached wrapper is created around `_open_zip`.
    resized_cache = lru_cache(maxsize=max_size)
    _cached_open_zip = resized_cache(_open_zip)
def open_file_in_zip(path_in_zip: PathInZip, mode: str = 'r') -> typing.IO:
    """Returns a file-like object for a file stored inside a zip archive.

    Args:
        path_in_zip (PathInZip): Zip archive and inner file path of the file to open.
        mode (str): Requested opening mode. Note that this value is currently
            not forwarded to the underlying `open` call on the archive.
    Returns:
        A file-like object for the file designated by `path_in_zip`.
    """
    # The archive itself is obtained through the LRU cached opener.
    archive = _cached_open_zip(path_in_zip.zip_path)
    member = archive.open(path_in_zip.file_path)
    return member
================================================
FILE: audiocraft/environment.py
================================================
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
"""
Provides cluster and tools configuration across clusters (slurm, dora, utilities).
"""
import logging
import os
from pathlib import Path
import re
import typing as tp
import omegaconf
from .utils.cluster import _guess_cluster_type
logger = logging.getLogger(__name__)
class AudioCraftEnvironment:
    """Environment configuration for teams and clusters.

    AudioCraftEnvironment picks compute cluster settings (slurm, dora) from the current running environment
    or declared variable and the loaded team configuration. Additionally, the AudioCraftEnvironment
    provides pointers to a reference folder resolved automatically across clusters that is shared across team members,
    allowing to share sigs or other files to run jobs. Finally, it provides dataset mappers to automatically
    map dataset file paths to new locations across clusters, allowing to use the same manifest of files across clusters.

    The cluster type is identified automatically and the base configuration file is read from
    config/teams/<team>.yaml.
    Use the following environment variables to specify the cluster, team or configuration:

        AUDIOCRAFT_CLUSTER (optional): Cluster type to enforce. Useful if the cluster type
            cannot be inferred automatically.
        AUDIOCRAFT_CONFIG (optional): Path to yaml config holding the teams configuration.
            If not set, configuration is read from config/teams/<team>.yaml.
        AUDIOCRAFT_TEAM (optional): Name of the team. Recommended to set to your own team.
            Cluster configuration are shared across teams to match compute allocation,
            specify your cluster configuration in the configuration file under a key mapping
            your team name.
    """
    # Lazily created singleton, see `instance()` and `reset()`.
    _instance = None
    DEFAULT_TEAM = "default"

    def __init__(self) -> None:
        """Loads configuration."""
        self.team: str = os.getenv("AUDIOCRAFT_TEAM", self.DEFAULT_TEAM)
        cluster_type = _guess_cluster_type()
        # The guessed cluster type is only a fallback for the env variable.
        cluster = os.getenv(
            "AUDIOCRAFT_CLUSTER", cluster_type.value
        )
        logger.info("Detecting cluster type %s", cluster_type)

        self.cluster: str = cluster

        config_path = os.getenv(
            "AUDIOCRAFT_CONFIG",
            Path(__file__)
            .parent.parent.joinpath("config/teams", self.team)
            .with_suffix(".yaml"),
        )
        self.config = omegaconf.OmegaConf.load(config_path)
        # List of (compiled regex, replacement) pairs, applied in order
        # by `apply_dataset_mappers`.
        self._dataset_mappers = []
        cluster_config = self._get_cluster_config()
        if "dataset_mappers" in cluster_config:
            for pattern, repl in cluster_config["dataset_mappers"].items():
                regex = re.compile(pattern)
                self._dataset_mappers.append((regex, repl))

    def _get_cluster_config(self) -> omegaconf.DictConfig:
        """Returns the section of the loaded configuration for the current cluster."""
        assert isinstance(self.config, omegaconf.DictConfig)
        return self.config[self.cluster]

    @classmethod
    def instance(cls):
        """Returns the shared environment instance, creating it on first use."""
        if cls._instance is None:
            cls._instance = cls()
        return cls._instance

    @classmethod
    def reset(cls):
        """Clears the environment and forces a reload on next invocation."""
        cls._instance = None

    @classmethod
    def get_team(cls) -> str:
        """Gets the selected team as dictated by the AUDIOCRAFT_TEAM env var.
        If not defined, defaults to `DEFAULT_TEAM` ("default").
        """
        return cls.instance().team

    @classmethod
    def get_cluster(cls) -> str:
        """Gets the detected cluster.
        This value can be overridden by the AUDIOCRAFT_CLUSTER env var.
        """
        return cls.instance().cluster

    @classmethod
    def get_dora_dir(cls) -> Path:
        """Gets the path to the dora directory for the current team and cluster.
        Value is overridden by the AUDIOCRAFT_DORA_DIR env var.
        """
        cluster_config = cls.instance()._get_cluster_config()
        dora_dir = os.getenv("AUDIOCRAFT_DORA_DIR", cluster_config["dora_dir"])
        logger.warning(f"Dora directory: {dora_dir}")
        return Path(dora_dir)

    @classmethod
    def get_reference_dir(cls) -> Path:
        """Gets the path to the reference directory for the current team and cluster.
        Value is overridden by the AUDIOCRAFT_REFERENCE_DIR env var.
        """
        cluster_config = cls.instance()._get_cluster_config()
        return Path(os.getenv("AUDIOCRAFT_REFERENCE_DIR", cluster_config["reference_dir"]))

    @classmethod
    def get_slurm_exclude(cls) -> tp.Optional[str]:
        """Get the list of nodes to exclude for that cluster."""
        cluster_config = cls.instance()._get_cluster_config()
        return cluster_config.get("slurm_exclude")

    @classmethod
    def get_slurm_partitions(cls, partition_types: tp.Optional[tp.List[str]] = None) -> str:
        """Gets the requested partitions for the current team and cluster as a comma-separated string.

        Args:
            partition_types (list[str], optional): partition types to retrieve. Values must be
                from ['global', 'team']. If not provided, the global partition is returned.
        """
        if not partition_types:
            partition_types = ["global"]

        cluster_config = cls.instance()._get_cluster_config()
        partitions = [
            cluster_config["partitions"][partition_type]
            for partition_type in partition_types
        ]
        return ",".join(partitions)

    @classmethod
    def resolve_reference_path(cls, path: tp.Union[str, Path]) -> Path:
        """Converts reference placeholder in path with configured reference dir to resolve paths.

        Args:
            path (str or Path): Path to resolve.
        Returns:
            Path: Resolved path.
        Raises:
            AssertionError: If the path uses the reference placeholder and the
                configured reference directory is not an existing directory.
        """
        path = str(path)
        prefix = "//reference"

        if path.startswith(prefix):
            reference_dir = cls.get_reference_dir()
            logger.warning(f"Reference directory: {reference_dir}")
            assert (
                reference_dir.exists() and reference_dir.is_dir()
            ), f"Reference directory does not exist: {reference_dir}."
            # Plain prefix substitution rather than `re.sub`: the reference directory
            # must be inserted verbatim, while a regex replacement template would
            # interpret any backslash it contains as an escape sequence.
            path = str(reference_dir) + path[len(prefix):]

        return Path(path)

    @classmethod
    def apply_dataset_mappers(cls, path: str) -> str:
        """Applies dataset mapping regex rules as defined in the configuration.
        If no rules are defined, the path is returned as-is.
        """
        instance = cls.instance()

        for pattern, repl in instance._dataset_mappers:
            path = pattern.sub(repl, path)

        return path
================================================
FILE: audiocraft/grids/__init__.py
================================================
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
"""Dora Grids."""
================================================
FILE: audiocraft/grids/_base_explorers.py
================================================
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
from abc import ABC, abstractmethod
import time
import typing as tp
from dora import Explorer
import treetable as tt
def get_sheep_ping(sheep) -> tp.Optional[str]:
    """Describe how long ago the Sheep last updated its log.

    Returns None when the Sheep has no log or the log does not exist,
    otherwise a str expressed in the most relevant time unit (d, h, m or s)."""
    if sheep.log is None or not sheep.log.exists():
        return None
    elapsed = time.time() - sheep.log.stat().st_mtime
    # Units from the largest to the smallest, the first one exceeded wins.
    units = ((3600 * 24, 'd'), (3600, 'h'), (60, 'm'))
    for span, suffix in units:
        if elapsed > span:
            return f'{elapsed / span:.1f}{suffix}'
    return f'{elapsed:.1f}s'
class BaseExplorer(ABC, Explorer):
    """Base explorer for AudioCraft grids.

    All task specific explorers are expected to implement the `get_grid_metrics`
    method to specify logic about metrics to display for a given task.

    If additional stages are used, the child explorer must define how to handle
    these new stages in the `process_history` and `process_sheep` methods.
    """
    def stages(self):
        """Returns the names of the stages handled by this explorer."""
        return ["train", "valid", "evaluate"]

    def get_grid_meta(self):
        """Returns the list of Meta information to display for each XP/job.
        """
        return [
            tt.leaf("index", align=">"),
            tt.leaf("name", wrap=140),
            tt.leaf("state"),
            tt.leaf("sig", align=">"),
            tt.leaf("sid", align="<"),
        ]

    @abstractmethod
    def get_grid_metrics(self):
        """Return the metrics that should be displayed in the tracking table.
        """
        ...

    def process_sheep(self, sheep, history):
        """Merges the history of a Sheep into one dict of metrics per stage.

        Args:
            sheep: Sheep being processed, only used to compute the ping
                through `get_sheep_ping`.
            history: Iterable of per-epoch entries, each one mapping
                a stage name to a dict of metrics. It is left unmodified.
        Returns:
            dict: Mapping from stage name to its metrics. For a given stage,
                values from later entries override those from earlier ones.
                'duration' values are divided by 60, and when a ping is
                available it is added to every stage listed by `stages()`.
        """
        train = {
            "epoch": len(history),
        }
        parts = {"train": train}
        for metrics in history:
            for key, sub in metrics.items():
                part = parts.get(key, {})
                if 'duration' in sub:
                    # Convert to minutes for readability. This is done on a copy
                    # so that the caller's history is not altered, otherwise
                    # processing the same history twice would convert twice.
                    sub = dict(sub)
                    sub['duration'] = sub['duration'] / 60.
                part.update(sub)
                parts[key] = part
        ping = get_sheep_ping(sheep)
        if ping is not None:
            for name in self.stages():
                if name not in parts:
                    parts[name] = {}
                # Add the ping to each part for convenience.
                parts[name]['ping'] = ping
        return parts
================================================
FILE: audiocraft/grids/audiogen/__init__.py
================================================
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
"""AudioGen grids."""
================================================
FILE: audiocraft/grids/audiogen/audiogen_base_16khz.py
================================================
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
from ..musicgen._explorers import LMExplorer
from ...environment import AudioCraftEnvironment
@LMExplorer
def explorer(launcher):
    """Grid for the AudioGen base 16 kHz solver with the medium model scale."""
    fsdp_overrides = {'autocast': False, 'fsdp.use': True}
    medium_scale = {'model/lm/model_scale': 'medium'}

    slurm_partitions = AudioCraftEnvironment.get_slurm_partitions(['team', 'global'])
    launcher.slurm_(gpus=64, partition=slurm_partitions)
    launcher.bind_(solver='audiogen/audiogen_base_16khz')
    # the dataset below is to be replaced by the desired environmental sound dataset
    launcher.bind_(dset='internal/sounds_16khz')

    launcher.bind_(fsdp_overrides)
    launcher(medium_scale)
================================================
FILE: audiocraft/grids/audiogen/audiogen_pretrained_16khz_eval.py
================================================
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
"""
Evaluation with objective metrics for the pretrained AudioGen models.
This grid takes signature from the training grid and runs evaluation-only stage.
When running the grid for the first time, please use:
REGEN=1 dora grid audiogen.audiogen_pretrained_16khz_eval
and re-use the REGEN=1 option when the grid is changed to force regenerating it.
Note that you need the proper metrics external libraries setup to use all
the objective metrics activated in this grid. Refer to the README for more information.
"""
import os
from ..musicgen._explorers import GenerationEvalExplorer
from ...environment import AudioCraftEnvironment
from ... import train
def eval(launcher, batch_size: int = 32):
    """Launches the evaluation-only XP with objective metrics from the given launcher.

    Args:
        launcher: Launcher to bind the evaluation options to.
        batch_size (int): Value used for '+dataset.evaluate.batch_size'.
    """
    evaluation = launcher.bind({
        'dset': 'audio/audiocaps_16khz',
        'solver/audiogen/evaluation': 'objective_eval',
        'execute_only': 'evaluate',
        '+dataset.evaluate.batch_size': batch_size,
        '+metrics.fad.tf.batch_size': 32,
    })
    # binary for FAD computation: replace this path with your own path
    evaluation.bind_({
        'metrics.fad.tf.bin': '/data/home/jadecopet/local/usr/opt/google-research'
    })

    sampling = {'generate.lm.use_sampling': True, 'generate.lm.top_k': 250, 'generate.lm.top_p': 0.}
    two_step_cfg = {'transformer_lm.two_step_cfg': True}
    # base objective metrics
    evaluation(sampling, two_step_cfg)
@GenerationEvalExplorer
def explorer(launcher):
    """Evaluation grid for the pretrained AudioGen medium model."""
    slurm_partitions = AudioCraftEnvironment.get_slurm_partitions(['team', 'global'])
    launcher.slurm_(gpus=4, partition=slurm_partitions)

    if 'REGEN' not in os.environ:
        # Without REGEN, launch again the XPs whose signature
        # is symlinked inside the folder of this grid.
        grid_dir = train.main.dora.dir / 'grids' / __name__.split('.', 2)[-1]
        with launcher.job_array():
            for entry in grid_dir.iterdir():
                if entry.is_symlink():
                    xp = train.main.get_xp_from_sig(entry.name)
                    launcher(xp.argv)
        return

    base_launcher = launcher.bind(solver="audiogen/audiogen_base_16khz")
    base_launcher.bind_({'autocast': False, 'fsdp.use': True})

    medium_launcher = base_launcher.bind({'continue_from': '//pretrained/facebook/audiogen-medium'})
    medium_launcher.bind_({'model/lm/model_scale': 'medium'})
    eval(medium_launcher, batch_size=128)
================================================
FILE: audiocraft/grids/compression/__init__.py
================================================
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
"""EnCodec grids."""
================================================
FILE: audiocraft/grids/compression/_explorers.py
================================================
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import treetable as tt
from .._base_explorers import BaseExplorer
class CompressionExplorer(BaseExplorer):
    """Explorer displaying the metrics of the compression grids."""
    eval_metrics = ["sisnr", "visqol"]

    def stages(self):
        return ["train", "valid", "evaluate"]

    def get_grid_meta(self):
        """Returns the list of Meta information to display for each XP/job.
        """
        meta_columns = [
            tt.leaf("index", align=">"),
            tt.leaf("name", wrap=140),
            tt.leaf("state"),
            tt.leaf("sig", align=">"),
        ]
        return meta_columns

    def get_grid_metrics(self):
        """Return the metrics that should be displayed in the tracking table.
        """
        train_leaves = [
            tt.leaf("epoch"),
            tt.leaf("bandwidth", ".2f"),
            tt.leaf("adv", ".4f"),
            tt.leaf("d_loss", ".4f"),
        ]
        valid_leaves = [
            tt.leaf("bandwidth", ".2f"),
            tt.leaf("adv", ".4f"),
            tt.leaf("msspec", ".4f"),
            tt.leaf("sisnr", ".2f"),
        ]
        # One column per evaluation metric declared on the class.
        evaluate_leaves = [tt.leaf(name, ".3f") for name in self.eval_metrics]
        return [
            tt.group("train", train_leaves, align=">"),
            tt.group("valid", valid_leaves, align=">"),
            tt.group("evaluate", evaluate_leaves, align=">"),
        ]
================================================
FILE: audiocraft/grids/compression/debug.py
================================================
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
"""
Grid search file, simply list all the exp you want in `explorer`.
Any new exp added there will be scheduled.
You can cancel an experiment by commenting out its line.
This grid is a minimal example for debugging compression task
and how to override parameters directly in a grid.
Learn more about dora grids: https://github.com/facebookresearch/dora
"""
from ._explorers import CompressionExplorer
from ...environment import AudioCraftEnvironment
@CompressionExplorer
def explorer(launcher):
    """Minimal debug grid for the compression task."""
    slurm_partitions = AudioCraftEnvironment.get_slurm_partitions(['team', 'global'])
    launcher.slurm_(gpus=2, partition=slurm_partitions)
    launcher.bind_(solver='compression/debug')

    rvq_overrides = {'rvq.bins': 2048, 'rvq.n_q': 4}
    with launcher.job_array():
        # first XP: debug task with the config from solver=compression/debug as-is
        launcher()
        # second XP: same task with parameters overridden directly from the grid
        launcher(rvq_overrides)
================================================
FILE: audiocraft/grids/compression/encodec_audiogen_16khz.py
================================================
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
"""
Grid search file, simply list all the exp you want in `explorer`.
Any new exp added there will be scheduled.
You can cancel an experiment by commenting out its line.
This grid shows how to train the new AudioGen EnCodec model at 16 kHz.
"""
from ._explorers import CompressionExplorer
from ...environment import AudioCraftEnvironment
@CompressionExplorer
def explorer(launcher):
    """Grid training the AudioGen EnCodec model at 16 kHz."""
    slurm_partitions = AudioCraftEnvironment.get_slurm_partitions(['team', 'global'])
    launcher.slurm_(gpus=8, partition=slurm_partitions)
    # configuration for AudioGen's EnCodec model trained on monophonic audio sampled at 16 kHz;
    # AudioGen's EnCodec is trained with a total stride of 320 leading to a frame rate of 50 hz
    launcher.bind_(solver='compression/encodec_audiogen_16khz')
    # the dataset below is to be replaced by the desired sound dataset
    launcher.bind_(dset='internal/sounds_16khz')
    # single XP with the options bound above
    launcher()
================================================
FILE: audiocraft/grids/compression/encodec_base_24khz.py
================================================
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
"""
Grid search file, simply list all the exp you want in `explorer`.
Any new exp added there will be scheduled.
You can cancel an experiment by commenting out its line.
This grid shows how to train a base causal EnCodec model at 24 kHz.
"""
from ._explorers import CompressionExplorer
from ...environment import AudioCraftEnvironment
@CompressionExplorer
def explorer(launcher):
    """Grid training a base causal EnCodec model at 24 kHz."""
    slurm_partitions = AudioCraftEnvironment.get_slurm_partitions(['team', 'global'])
    launcher.slurm_(gpus=8, partition=slurm_partitions)
    # solver for the base causal EnCodec trained on monophonic audio sampled at 24 kHz
    launcher.bind_(solver='compression/encodec_base_24khz')
    # the dataset below is to be replaced by the desired dataset
    launcher.bind_(dset='audio/example')
    # single XP with the options bound above
    launcher()
================================================
FILE: audiocraft/grids/compression/encodec_musicgen_32khz.py
================================================
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
"""
Grid search file, simply list all the exp you want in `explorer`.
Any new exp added there will be scheduled.
You can cancel an experiment by commenting out its line.
This grid shows how to train a MusicGen EnCodec model at 32 kHz.
"""
from ._explorers import CompressionExplorer
from ...environment import AudioCraftEnvironment
@CompressionExplorer
def explorer(launcher):
    """Grid training a MusicGen EnCodec model at 32 kHz, with and without ViSQOL."""
    slurm_partitions = AudioCraftEnvironment.get_slurm_partitions(['team', 'global'])
    launcher.slurm_(gpus=8, partition=slurm_partitions)
    # configuration for MusicGen's EnCodec model trained on monophonic audio sampled at 32 kHz;
    # MusicGen's EnCodec is trained with a total stride of 640 leading to a frame rate of 50 hz
    launcher.bind_(solver='compression/encodec_musicgen_32khz')
    # the dataset below is to be replaced by the desired music dataset
    launcher.bind_(dset='internal/music_400k_32khz')

    visqol_overrides = {
        'metrics.visqol.bin': '/data/home/jadecopet/local/usr/opt/visqol',
        'label': 'visqol',
        'evaluate.metrics.visqol': True
    }
    # first XP with the options bound above, second one with the ViSQOL overrides
    launcher()
    launcher(visqol_overrides)
================================================
FILE: audiocraft/grids/diffusion/4_bands_base_32khz.py
================================================
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
"""
Training of the 4 diffusion models described in
"From Discrete Tokens to High-Fidelity Audio Using Multi-Band Diffusion"
(paper link).
"""
from ._explorers import DiffusionExplorer
@DiffusionExplorer
def explorer(launcher):
    """Grid with one XP per frequency band index, from 0 to 3."""
    launcher.slurm_(gpus=4, partition='learnfair')

    launcher.bind_({'solver': 'diffusion/default',
                    'dset': 'internal/music_10k_32khz'})

    # (filter.idx_band, processor.use, processor.power_std) for each band
    band_settings = [
        (0, False, 0.4),
        (1, False, 0.4),
        (2, True, 0.4),
        (3, True, 0.75),
    ]
    with launcher.job_array():
        for idx_band, use_processor, power_std in band_settings:
            launcher({'filter.use': True, 'filter.idx_band': idx_band,
                      "processor.use": use_processor, 'processor.power_std': power_std})
================================================
FILE: audiocraft/grids/diffusion/__init__.py
================================================
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
"""Diffusion grids."""
================================================
FILE: audiocraft/grids/diffusion/_explorers.py
================================================
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import treetable as tt
from .._base_explorers import BaseExplorer
class DiffusionExplorer(BaseExplorer):
    """Explorer displaying the metrics of the diffusion grids."""
    eval_metrics = ["sisnr", "visqol"]

    def stages(self):
        return ["train", "valid", "valid_ema", "evaluate", "evaluate_ema"]

    def get_grid_meta(self):
        """Returns the list of Meta information to display for each XP/job.
        """
        meta_columns = [
            tt.leaf("index", align=">"),
            tt.leaf("name", wrap=140),
            tt.leaf("state"),
            tt.leaf("sig", align=">"),
        ]
        return meta_columns

    def get_grid_metrics(self):
        """Return the metrics that should be displayed in the tracking table.
        """
        rvm_names = ("rvm", "rvm_0", "rvm_1", "rvm_2", "rvm_3")

        def rvm_leaves():
            # Fresh leaves on each call so that no leaf is shared between groups.
            return [tt.leaf(name, ".4f") for name in rvm_names]

        train_group = tt.group(
            "train",
            [tt.leaf("epoch"), tt.leaf("loss", ".3%")],
            align=">",
        )
        valid_group = tt.group("valid", [tt.leaf("loss", ".3%")], align=">")
        valid_ema_group = tt.group("valid_ema", [tt.leaf("loss", ".3%")], align=">")
        evaluate_group = tt.group("evaluate", rvm_leaves(), align=">")
        evaluate_ema_group = tt.group("evaluate_ema", rvm_leaves(), align=">")
        return [
            train_group,
            valid_group,
            valid_ema_group,
            evaluate_group,
            evaluate_ema_group,
        ]
================================================
FILE: audiocraft/grids/magnet/__init__.py
================================================
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
"""MAGNeT grids."""
================================================
FILE: audiocraft/grids/magnet/audio_magnet_16khz.py
================================================
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
from ..musicgen._explorers import LMExplorer
from ...environment import AudioCraftEnvironment
@LMExplorer
def explorer(launcher):
    """Grid for the audio-MAGNeT 16 kHz solver, with the small and medium models."""
    fsdp_overrides = {'autocast': False, 'fsdp.use': True}
    medium_scale = {'model/lm/model_scale': 'medium'}

    slurm_partitions = AudioCraftEnvironment.get_slurm_partitions(['team', 'global'])
    launcher.slurm_(gpus=32, partition=slurm_partitions)
    launcher.bind_(solver='magnet/audio_magnet_16khz')
    # the dataset below is to be replaced by the desired environmental sound dataset
    launcher.bind_(dset='internal/sounds_16khz')

    # Small model (300M)
    launcher.slurm_(gpus=32).bind_(label='32gpus')
    with launcher.job_array():
        small_xp = launcher.bind()
        small_xp()

    # Medium model (1.5B)
    launcher.slurm_(gpus=64).bind_(label='64gpus')
    with launcher.job_array():
        medium_xp = launcher.bind()
        medium_xp(medium_scale, fsdp_overrides)
================================================
FILE: audiocraft/grids/magnet/audio_magnet_pretrained_16khz_eval.py
================================================
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
"""
Evaluation with objective metrics for the pretrained audio-MAGNeT models.
This grid takes signature from the training grid and runs evaluation-only stage.
When running the grid for the first time, please use:
REGEN=1 dora grid magnet.audio_magnet_pretrained_16khz_eval
and re-use the REGEN=1 option when the grid is changed to force regenerating it.
Note that you need the proper metrics external libraries setup to use all
the objective metrics activated in this grid. Refer to the README for more information.
"""
import os
from ..musicgen._explorers import GenerationEvalExplorer
from ...environment import AudioCraftEnvironment
from ... import train
def eval(launcher, batch_size: int = 32):
    """Launches the evaluation-only XP with objective metrics from the given launcher.

    Args:
        launcher: Launcher to bind the evaluation options to.
        batch_size (int): Value used for '+dataset.evaluate.batch_size'.
    """
    evaluation = launcher.bind({
        'dset': 'audio/audiocaps_16khz',
        'solver/audiogen/evaluation': 'objective_eval',
        'execute_only': 'evaluate',
        '+dataset.evaluate.batch_size': batch_size,
        '+metrics.fad.tf.batch_size': 32,
    })
    # binary for FAD computation: replace this path with your own path
    evaluation.bind_({
        'metrics.fad.tf.bin': '/data/home/jadecopet/local/usr/opt/google-research'
    })

    # base objective metrics
    evaluation()
@GenerationEvalExplorer
def explorer(launcher):
partitions = AudioCraftEnvironment.get_slurm_partitions(['team', 'global'])
launcher.slurm_(gpus=4, partition=partitions)
if 'REGEN' not in os.environ:
folder = train.main.dora.dir / 'grids' / __name__.split('.', 2)[-1]
with launcher.job_array():
for sig in folder.iterdir():
if not sig.is_symlink():
continue
xp = train.main.get_xp_from_sig(sig.name)
launcher(xp.argv)
return
with launcher.job_array():
audio_magnet = launcher.bind(solver="magnet/audio_magnet_16khz")
fsdp = {'autocast': False, 'fsdp.use': True}
# Small audio-MAGNeT model (300M)
audio_magnet_small = audio_magnet.bind({'continue_from': '//pretrained/facebook/audio-magnet-small'})
eval(audio_magnet_small, batch_size=128)
# Medium audio-MAGNeT model (1.5B)
audio_magnet_medium = audio_magnet.bind({'continue_from': '//pretrained/facebook/audio-magnet-medium'})
audio_magnet_medium.bind_({'model/lm/model_
gitextract_1bqvithb/
├── .github/
│ ├── actions/
│ │ └── audiocraft_build/
│ │ └── action.yml
│ └── workflows/
│ ├── audiocraft_docs.yml
│ ├── audiocraft_linter.yml
│ └── audiocraft_tests.yml
├── .gitignore
├── CHANGELOG.md
├── CODE_OF_CONDUCT.md
├── CONTRIBUTING.md
├── LICENSE
├── LICENSE_weights
├── MANIFEST.in
├── Makefile
├── README.md
├── assets/
│ ├── chord_to_index_mapping.pkl
│ ├── salience_1.th
│ └── salience_2.th
├── audiocraft/
│ ├── __init__.py
│ ├── adversarial/
│ │ ├── __init__.py
│ │ ├── discriminators/
│ │ │ ├── __init__.py
│ │ │ ├── base.py
│ │ │ ├── mpd.py
│ │ │ ├── msd.py
│ │ │ └── msstftd.py
│ │ └── losses.py
│ ├── data/
│ │ ├── __init__.py
│ │ ├── audio.py
│ │ ├── audio_dataset.py
│ │ ├── audio_utils.py
│ │ ├── info_audio_dataset.py
│ │ ├── jasco_dataset.py
│ │ ├── music_dataset.py
│ │ ├── sound_dataset.py
│ │ └── zip.py
│ ├── environment.py
│ ├── grids/
│ │ ├── __init__.py
│ │ ├── _base_explorers.py
│ │ ├── audiogen/
│ │ │ ├── __init__.py
│ │ │ ├── audiogen_base_16khz.py
│ │ │ └── audiogen_pretrained_16khz_eval.py
│ │ ├── compression/
│ │ │ ├── __init__.py
│ │ │ ├── _explorers.py
│ │ │ ├── debug.py
│ │ │ ├── encodec_audiogen_16khz.py
│ │ │ ├── encodec_base_24khz.py
│ │ │ └── encodec_musicgen_32khz.py
│ │ ├── diffusion/
│ │ │ ├── 4_bands_base_32khz.py
│ │ │ ├── __init__.py
│ │ │ └── _explorers.py
│ │ ├── magnet/
│ │ │ ├── __init__.py
│ │ │ ├── audio_magnet_16khz.py
│ │ │ ├── audio_magnet_pretrained_16khz_eval.py
│ │ │ ├── magnet_32khz.py
│ │ │ └── magnet_pretrained_32khz_eval.py
│ │ ├── musicgen/
│ │ │ ├── __init__.py
│ │ │ ├── _explorers.py
│ │ │ ├── musicgen_base_32khz.py
│ │ │ ├── musicgen_base_cached_32khz.py
│ │ │ ├── musicgen_clapemb_32khz.py
│ │ │ ├── musicgen_melody_32khz.py
│ │ │ ├── musicgen_pretrained_32khz_eval.py
│ │ │ ├── musicgen_stereo_finetune_32khz.py
│ │ │ └── musicgen_style_32khz.py
│ │ └── watermarking/
│ │ ├── __init__.py
│ │ ├── _explorers.py
│ │ ├── audioseal.py
│ │ └── kbits.py
│ ├── losses/
│ │ ├── __init__.py
│ │ ├── balancer.py
│ │ ├── loudnessloss.py
│ │ ├── sisnr.py
│ │ ├── specloss.py
│ │ ├── stftloss.py
│ │ └── wmloss.py
│ ├── metrics/
│ │ ├── __init__.py
│ │ ├── chroma_cosinesim.py
│ │ ├── clap_consistency.py
│ │ ├── fad.py
│ │ ├── kld.py
│ │ ├── miou.py
│ │ ├── pesq.py
│ │ ├── rvm.py
│ │ └── visqol.py
│ ├── models/
│ │ ├── __init__.py
│ │ ├── audiogen.py
│ │ ├── builders.py
│ │ ├── encodec.py
│ │ ├── flow_matching.py
│ │ ├── genmodel.py
│ │ ├── jasco.py
│ │ ├── lm.py
│ │ ├── lm_magnet.py
│ │ ├── loaders.py
│ │ ├── magnet.py
│ │ ├── multibanddiffusion.py
│ │ ├── musicgen.py
│ │ ├── unet.py
│ │ └── watermark.py
│ ├── modules/
│ │ ├── __init__.py
│ │ ├── activations.py
│ │ ├── chroma.py
│ │ ├── codebooks_patterns.py
│ │ ├── conditioners.py
│ │ ├── conv.py
│ │ ├── diffusion_schedule.py
│ │ ├── jasco_conditioners.py
│ │ ├── lstm.py
│ │ ├── rope.py
│ │ ├── seanet.py
│ │ ├── streaming.py
│ │ ├── transformer.py
│ │ ├── unet_transformer.py
│ │ └── watermark.py
│ ├── optim/
│ │ ├── __init__.py
│ │ ├── cosine_lr_scheduler.py
│ │ ├── dadam.py
│ │ ├── ema.py
│ │ ├── fsdp.py
│ │ ├── inverse_sqrt_lr_scheduler.py
│ │ ├── linear_warmup_lr_scheduler.py
│ │ └── polynomial_decay_lr_scheduler.py
│ ├── py.typed
│ ├── quantization/
│ │ ├── __init__.py
│ │ ├── base.py
│ │ ├── core_vq.py
│ │ └── vq.py
│ ├── solvers/
│ │ ├── __init__.py
│ │ ├── audiogen.py
│ │ ├── base.py
│ │ ├── builders.py
│ │ ├── compression.py
│ │ ├── diffusion.py
│ │ ├── jasco.py
│ │ ├── magnet.py
│ │ ├── musicgen.py
│ │ └── watermark.py
│ ├── train.py
│ └── utils/
│ ├── __init__.py
│ ├── audio_effects.py
│ ├── autocast.py
│ ├── best_state.py
│ ├── cache.py
│ ├── checkpoint.py
│ ├── cluster.py
│ ├── deadlock.py
│ ├── export.py
│ ├── export_legacy.py
│ ├── notebook.py
│ ├── profiler.py
│ ├── samples/
│ │ ├── __init__.py
│ │ └── manager.py
│ └── utils.py
├── config/
│ ├── augmentations/
│ │ └── default.yaml
│ ├── conditioner/
│ │ ├── chords2music.yaml
│ │ ├── chroma2music.yaml
│ │ ├── clapemb2music.yaml
│ │ ├── drums2music.yaml
│ │ ├── jasco_chords_drums.yaml
│ │ ├── jasco_chords_drums_melody.yaml
│ │ ├── none.yaml
│ │ ├── style2music.yaml
│ │ ├── text2music.yaml
│ │ └── text2sound.yaml
│ ├── config.yaml
│ ├── dset/
│ │ ├── audio/
│ │ │ ├── audiocaps_16khz.yaml
│ │ │ ├── default.yaml
│ │ │ ├── example.yaml
│ │ │ └── musiccaps_32khz.yaml
│ │ ├── default.yaml
│ │ └── internal/
│ │ ├── music_10k_32khz.yaml
│ │ ├── music_400k_32khz.yaml
│ │ └── sounds_16khz.yaml
│ ├── model/
│ │ ├── encodec/
│ │ │ ├── default.yaml
│ │ │ ├── encodec_base_causal.yaml
│ │ │ ├── encodec_large_nq4_s320.yaml
│ │ │ └── encodec_large_nq4_s640.yaml
│ │ ├── lm/
│ │ │ ├── audiogen_lm.yaml
│ │ │ ├── default.yaml
│ │ │ ├── model_scale/
│ │ │ │ ├── base.yaml
│ │ │ │ ├── large.yaml
│ │ │ │ ├── medium.yaml
│ │ │ │ ├── small.yaml
│ │ │ │ └── xsmall.yaml
│ │ │ └── musicgen_lm.yaml
│ │ ├── none.yaml
│ │ ├── score/
│ │ │ └── basic.yaml
│ │ └── watermark/
│ │ └── default.yaml
│ ├── solver/
│ │ ├── audiogen/
│ │ │ ├── audiogen_base_16khz.yaml
│ │ │ ├── debug.yaml
│ │ │ ├── default.yaml
│ │ │ └── evaluation/
│ │ │ ├── none.yaml
│ │ │ └── objective_eval.yaml
│ │ ├── compression/
│ │ │ ├── debug.yaml
│ │ │ ├── default.yaml
│ │ │ ├── encodec_audiogen_16khz.yaml
│ │ │ ├── encodec_base_24khz.yaml
│ │ │ └── encodec_musicgen_32khz.yaml
│ │ ├── default.yaml
│ │ ├── diffusion/
│ │ │ ├── debug.yaml
│ │ │ ├── default.yaml
│ │ │ └── encodec_24khz.yaml
│ │ ├── jasco/
│ │ │ ├── chords.yaml
│ │ │ ├── chords_drums.yaml
│ │ │ ├── chords_drums_melody.yaml
│ │ │ ├── drums.yaml
│ │ │ └── jasco_32khz_base.yaml
│ │ ├── magnet/
│ │ │ ├── audio_magnet_16khz.yaml
│ │ │ └── magnet_32khz.yaml
│ │ ├── musicgen/
│ │ │ ├── debug.yaml
│ │ │ ├── default.yaml
│ │ │ ├── evaluation/
│ │ │ │ ├── none.yaml
│ │ │ │ └── objective_eval.yaml
│ │ │ ├── musicgen_base_32khz.yaml
│ │ │ ├── musicgen_melody_32khz.yaml
│ │ │ └── musicgen_style_32khz.yaml
│ │ └── watermark/
│ │ ├── debug.yaml
│ │ ├── default.yaml
│ │ └── robustness.yaml
│ └── teams/
│ ├── default.yaml
│ └── labs.yaml
├── dataset/
│ └── example/
│ ├── electro_1.json
│ └── electro_2.json
├── demos/
│ ├── audiogen_demo.ipynb
│ ├── jasco_app.py
│ ├── jasco_demo.ipynb
│ ├── magnet_app.py
│ ├── magnet_demo.ipynb
│ ├── musicgen_app.py
│ ├── musicgen_demo.ipynb
│ ├── musicgen_style_app.py
│ └── musicgen_style_demo.ipynb
├── docs/
│ ├── AUDIOGEN.md
│ ├── CONDITIONING.md
│ ├── DATASETS.md
│ ├── ENCODEC.md
│ ├── JASCO.md
│ ├── MAGNET.md
│ ├── MBD.md
│ ├── METRICS.md
│ ├── MUSICGEN.md
│ ├── MUSICGEN_STYLE.md
│ ├── TRAINING.md
│ └── WATERMARKING.md
├── egs/
│ └── example/
│ └── data.jsonl
├── jasco_demo.ipynb
├── model_cards/
│ ├── AUDIOGEN_MODEL_CARD.md
│ ├── JASCO_MODEL_CARD.md
│ ├── MAGNET_MODEL_CARD.md
│ ├── MUSICGEN_MODEL_CARD.md
│ └── MUSICGEN_STYLE_MODEL_CARD.md
├── mypy.ini
├── requirements.txt
├── scripts/
│ ├── __init__.py
│ ├── chords/
│ │ ├── build_chord_maps.py
│ │ ├── extract_chords.py
│ │ └── job_array_example.sh
│ ├── mos.py
│ ├── resample_dataset.py
│ ├── static/
│ │ └── style.css
│ └── templates/
│ ├── base.html
│ ├── index.html
│ ├── login.html
│ ├── results.html
│ └── survey.html
├── setup.cfg
├── setup.py
└── tests/
├── __init__.py
├── adversarial/
│ ├── __init__.py
│ ├── test_discriminators.py
│ └── test_losses.py
├── common_utils/
│ ├── __init__.py
│ ├── temp_utils.py
│ └── wav_utils.py
├── data/
│ ├── __init__.py
│ ├── test_audio.py
│ ├── test_audio_dataset.py
│ └── test_audio_utils.py
├── losses/
│ ├── __init__.py
│ └── test_losses.py
├── metrics/
│ ├── __init__.py
│ └── test_pesq.py
├── models/
│ ├── test_audiogen.py
│ ├── test_encodec_model.py
│ ├── test_multibanddiffusion.py
│ ├── test_musicgen.py
│ └── test_watermark.py
├── modules/
│ ├── __init__.py
│ ├── test_activations.py
│ ├── test_codebooks_patterns.py
│ ├── test_conv.py
│ ├── test_lstm.py
│ ├── test_rope.py
│ ├── test_seanet.py
│ └── test_transformer.py
├── quantization/
│ └── test_vq.py
└── utils/
├── __init__.py
└── test_audio_effects.py
SYMBOL INDEX (1517 symbols across 145 files)
FILE: audiocraft/adversarial/discriminators/base.py
class MultiDiscriminator (line 19) | class MultiDiscriminator(ABC, nn.Module):
method __init__ (line 22) | def __init__(self):
method forward (line 26) | def forward(self, x: torch.Tensor) -> MultiDiscriminatorOutputType:
method num_discriminators (line 31) | def num_discriminators(self) -> int:
FILE: audiocraft/adversarial/discriminators/mpd.py
function get_padding (line 17) | def get_padding(kernel_size: int, dilation: int = 1) -> int:
class PeriodDiscriminator (line 21) | class PeriodDiscriminator(nn.Module):
method __init__ (line 38) | def __init__(self, period: int, in_channels: int = 1, out_channels: in...
method forward (line 58) | def forward(self, x: torch.Tensor):
class MultiPeriodDiscriminator (line 79) | class MultiPeriodDiscriminator(MultiDiscriminator):
method __init__ (line 88) | def __init__(self, in_channels: int = 1, out_channels: int = 1,
method num_discriminators (line 96) | def num_discriminators(self):
method forward (line 99) | def forward(self, x: torch.Tensor) -> MultiDiscriminatorOutputType:
FILE: audiocraft/adversarial/discriminators/msd.py
class ScaleDiscriminator (line 17) | class ScaleDiscriminator(nn.Module):
method __init__ (line 37) | def __init__(self, in_channels=1, out_channels=1, kernel_sizes: tp.Seq...
method forward (line 83) | def forward(self, x: torch.Tensor):
class MultiScaleDiscriminator (line 95) | class MultiScaleDiscriminator(MultiDiscriminator):
method __init__ (line 105) | def __init__(self, in_channels: int = 1, out_channels: int = 1, downsa...
method num_discriminators (line 114) | def num_discriminators(self):
method forward (line 117) | def forward(self, x: torch.Tensor) -> MultiDiscriminatorOutputType:
FILE: audiocraft/adversarial/discriminators/msstftd.py
function get_2d_padding (line 18) | def get_2d_padding(kernel_size: tp.Tuple[int, int], dilation: tp.Tuple[i...
class DiscriminatorSTFT (line 22) | class DiscriminatorSTFT(nn.Module):
method __init__ (line 41) | def __init__(self, filters: int, in_channels: int = 1, out_channels: i...
method forward (line 81) | def forward(self, x: torch.Tensor):
class MultiScaleSTFTDiscriminator (line 94) | class MultiScaleSTFTDiscriminator(MultiDiscriminator):
method __init__ (line 107) | def __init__(self, filters: int, in_channels: int = 1, out_channels: i...
method num_discriminators (line 120) | def num_discriminators(self):
method _separate_channels (line 123) | def _separate_channels(self, x: torch.Tensor) -> torch.Tensor:
method forward (line 127) | def forward(self, x: torch.Tensor) -> MultiDiscriminatorOutputType:
FILE: audiocraft/adversarial/losses.py
class AdversarialLoss (line 26) | class AdversarialLoss(nn.Module):
method __init__ (line 49) | def __init__(self,
method _save_to_state_dict (line 67) | def _save_to_state_dict(self, destination, prefix, keep_vars):
method _load_from_state_dict (line 73) | def _load_from_state_dict(self, state_dict, prefix, *args, **kwargs):
method get_adversary_pred (line 78) | def get_adversary_pred(self, x):
method train_adv (line 89) | def train_adv(self, fake: torch.Tensor, real: torch.Tensor) -> torch.T...
method forward (line 115) | def forward(self, fake: torch.Tensor, real: torch.Tensor) -> tp.Tuple[...
function get_adv_criterion (line 138) | def get_adv_criterion(loss_type: str) -> tp.Callable:
function get_fake_criterion (line 149) | def get_fake_criterion(loss_type: str) -> tp.Callable:
function get_real_criterion (line 158) | def get_real_criterion(loss_type: str) -> tp.Callable:
function mse_real_loss (line 167) | def mse_real_loss(x: torch.Tensor) -> torch.Tensor:
function mse_fake_loss (line 171) | def mse_fake_loss(x: torch.Tensor) -> torch.Tensor:
function hinge_real_loss (line 175) | def hinge_real_loss(x: torch.Tensor) -> torch.Tensor:
function hinge_fake_loss (line 179) | def hinge_fake_loss(x: torch.Tensor) -> torch.Tensor:
function mse_loss (line 183) | def mse_loss(x: torch.Tensor) -> torch.Tensor:
function hinge_loss (line 189) | def hinge_loss(x: torch.Tensor) -> torch.Tensor:
function hinge2_loss (line 195) | def hinge2_loss(x: torch.Tensor) -> torch.Tensor:
class FeatureMatchingLoss (line 201) | class FeatureMatchingLoss(nn.Module):
method __init__ (line 209) | def __init__(self, loss: nn.Module = torch.nn.L1Loss(), normalize: boo...
method forward (line 214) | def forward(self, fmap_fake: tp.List[torch.Tensor], fmap_real: tp.List...
FILE: audiocraft/data/audio.py
function _init_av (line 31) | def _init_av():
class AudioFileInfo (line 41) | class AudioFileInfo:
function _av_info (line 47) | def _av_info(filepath: tp.Union[str, Path]) -> AudioFileInfo:
function _soundfile_info (line 57) | def _soundfile_info(filepath: tp.Union[str, Path]) -> AudioFileInfo:
function audio_info (line 62) | def audio_info(filepath: tp.Union[str, Path]) -> AudioFileInfo:
function _av_read (line 72) | def _av_read(filepath: tp.Union[str, Path], seek_time: float = 0, durati...
function audio_read (line 116) | def audio_read(filepath: tp.Union[str, Path], seek_time: float = 0.,
function _piping_to_ffmpeg (line 147) | def _piping_to_ffmpeg(out_path: tp.Union[str, Path], wav: torch.Tensor, ...
function audio_write (line 159) | def audio_write(stem_name: tp.Union[str, Path],
function get_spec (line 234) | def get_spec(y, sr=16000, n_fft=4096, hop_length=128, dur=8) -> np.ndarray:
function save_spectrograms (line 256) | def save_spectrograms(
FILE: audiocraft/data/audio_dataset.py
class BaseInfo (line 39) | class BaseInfo:
method _dict2fields (line 42) | def _dict2fields(cls, dictionary: dict):
method from_dict (line 49) | def from_dict(cls, dictionary: dict):
method to_dict (line 53) | def to_dict(self):
class AudioMeta (line 61) | class AudioMeta(BaseInfo):
method from_dict (line 71) | def from_dict(cls, dictionary: dict):
method to_dict (line 77) | def to_dict(self):
class SegmentInfo (line 85) | class SegmentInfo(BaseInfo):
function _get_audio_meta (line 101) | def _get_audio_meta(file_path: str, minimal: bool = True) -> AudioMeta:
function _resolve_audio_meta (line 118) | def _resolve_audio_meta(m: AudioMeta, fast: bool = True) -> AudioMeta:
function find_audio_files (line 145) | def find_audio_files(path: tp.Union[Path, str],
function load_audio_meta (line 204) | def load_audio_meta(path: tp.Union[str, Path],
function save_audio_meta (line 228) | def save_audio_meta(path: tp.Union[str, Path], meta: tp.List[AudioMeta]):
class AudioDataset (line 244) | class AudioDataset:
method __init__ (line 295) | def __init__(self,
method start_epoch (line 350) | def start_epoch(self, epoch: int):
method __len__ (line 353) | def __len__(self):
method _get_sampling_probabilities (line 356) | def _get_sampling_probabilities(self, normalized: bool = True):
method _get_file_permutation (line 373) | def _get_file_permutation(num_files: int, permutation_index: int, base...
method sample_file (line 380) | def sample_file(self, index: int, rng: torch.Generator) -> AudioMeta:
method _audio_read (line 404) | def _audio_read(self, path: str, seek_time: float = 0, duration: float...
method __getitem__ (line 413) | def __getitem__(self, index: int) -> tp.Union[torch.Tensor, tp.Tuple[t...
method collater (line 462) | def collater(self, samples):
method _filter_duration (line 502) | def _filter_duration(self, meta: tp.List[AudioMeta]) -> tp.List[AudioM...
method from_meta (line 524) | def from_meta(cls, root: tp.Union[str, Path], **kwargs):
method from_path (line 544) | def from_path(cls, root: tp.Union[str, Path], minimal_meta: bool = True,
function main (line 562) | def main():
FILE: audiocraft/data/audio_utils.py
function convert_audio_channels (line 21) | def convert_audio_channels(wav: torch.Tensor, channels: int = 2) -> torc...
function convert_audio (line 54) | def convert_audio(wav: torch.Tensor, from_rate: float,
function normalize_loudness (line 62) | def normalize_loudness(wav: torch.Tensor, sample_rate: int, loudness_hea...
function _clip_wav (line 91) | def _clip_wav(wav: torch.Tensor, log_clipping: bool = False, stem_name: ...
function normalize_audio (line 103) | def normalize_audio(wav: torch.Tensor, normalize: bool = True,
function f32_pcm (line 155) | def f32_pcm(wav: torch.Tensor) -> torch.Tensor:
function i16_pcm (line 172) | def i16_pcm(wav: torch.Tensor) -> torch.Tensor:
function compress (line 195) | def compress(wav: torch.Tensor, sr: int,
function get_mp3 (line 233) | def get_mp3(wav_tensor: torch.Tensor, sr: int, bitrate: str = "128k") ->...
function get_aac (line 274) | def get_aac(
FILE: audiocraft/data/info_audio_dataset.py
function _clusterify_meta (line 25) | def _clusterify_meta(meta: AudioMeta) -> AudioMeta:
function clusterify_all_meta (line 33) | def clusterify_all_meta(meta: tp.List[AudioMeta]) -> tp.List[AudioMeta]:
class AudioInfo (line 39) | class AudioInfo(SegmentWithAttributes):
method to_condition_attributes (line 50) | def to_condition_attributes(self) -> ConditioningAttributes:
class InfoAudioDataset (line 54) | class InfoAudioDataset(AudioDataset):
method __init__ (line 59) | def __init__(self, meta: tp.List[AudioMeta], **kwargs):
method __getitem__ (line 62) | def __getitem__(self, index: int) -> tp.Union[torch.Tensor, tp.Tuple[t...
function get_keyword_or_keyword_list (line 71) | def get_keyword_or_keyword_list(value: tp.Optional[str]) -> tp.Union[tp....
function get_string (line 79) | def get_string(value: tp.Optional[str]) -> tp.Optional[str]:
function get_keyword (line 87) | def get_keyword(value: tp.Optional[str]) -> tp.Optional[str]:
function get_keyword_list (line 95) | def get_keyword_list(values: tp.Union[str, tp.List[str]]) -> tp.Optional...
FILE: audiocraft/data/jasco_dataset.py
class JascoInfo (line 23) | class JascoInfo(MusicInfo):
method to_condition_attributes (line 32) | def to_condition_attributes(self) -> ConditioningAttributes:
class MelodyData (line 50) | class MelodyData:
method __init__ (line 55) | def __init__(self,
method load_saliency_from_saliency_dict (line 112) | def load_saliency_from_saliency_dict(self,
method get_null_salience (line 150) | def get_null_salience(self) -> torch.Tensor:
method __call__ (line 153) | def __call__(self, x: MusicInfo) -> torch.Tensor:
class JascoDataset (line 173) | class JascoDataset(MusicDataset):
method from_meta (line 183) | def from_meta(cls, root: tp.Union[str, Path], **kwargs):
method __init__ (line 210) | def __init__(self, *args,
method _get_relevant_sublist (line 239) | def _get_relevant_sublist(self, chords, timestamp):
method _get_chords (line 269) | def _get_chords(self, music_info: MusicInfo, effective_segment_dur: fl...
method __getitem__ (line 296) | def __getitem__(self, index):
FILE: audiocraft/data/music_dataset.py
class MusicInfo (line 37) | class MusicInfo(AudioInfo):
method has_music_meta (line 57) | def has_music_meta(self) -> bool:
method to_condition_attributes (line 60) | def to_condition_attributes(self) -> ConditioningAttributes:
method attribute_getter (line 76) | def attribute_getter(attribute):
method from_dict (line 92) | def from_dict(cls, dictionary: dict, fields_required: bool = False):
function augment_music_info_description (line 115) | def augment_music_info_description(music_info: MusicInfo, merge_text_p: ...
class Paraphraser (line 167) | class Paraphraser:
method __init__ (line 168) | def __init__(self, paraphrase_source: tp.Union[str, Path], paraphrase_...
method sample_paraphrase (line 175) | def sample_paraphrase(self, audio_path: str, description: str):
class MusicDataset (line 187) | class MusicDataset(InfoAudioDataset):
method __init__ (line 204) | def __init__(self, *args, info_fields_required: bool = True,
method __getitem__ (line 220) | def __getitem__(self, index):
function get_musical_key (line 252) | def get_musical_key(value: tp.Optional[str]) -> tp.Optional[str]:
function get_bpm (line 263) | def get_bpm(value: tp.Optional[str]) -> tp.Optional[float]:
FILE: audiocraft/data/sound_dataset.py
class SoundInfo (line 35) | class SoundInfo(SegmentWithAttributes):
method has_sound_meta (line 42) | def has_sound_meta(self) -> bool:
method to_condition_attributes (line 45) | def to_condition_attributes(self) -> ConditioningAttributes:
method attribute_getter (line 57) | def attribute_getter(attribute):
method from_dict (line 65) | def from_dict(cls, dictionary: dict, fields_required: bool = False):
class SoundDataset (line 87) | class SoundDataset(InfoAudioDataset):
method __init__ (line 104) | def __init__(
method _get_info_path (line 129) | def _get_info_path(self, path: tp.Union[str, Path]) -> Path:
method __getitem__ (line 142) | def __getitem__(self, index):
method collater (line 163) | def collater(self, samples):
function rms_f (line 173) | def rms_f(x: torch.Tensor) -> torch.Tensor:
function normalize (line 177) | def normalize(audio: torch.Tensor, target_level: int = -25) -> torch.Ten...
function is_clipped (line 185) | def is_clipped(audio: torch.Tensor, clipping_threshold: float = 0.99) ->...
function mix_pair (line 189) | def mix_pair(src: torch.Tensor, dst: torch.Tensor, min_overlap: float) -...
function snr_mixer (line 199) | def snr_mixer(clean: torch.Tensor, noise: torch.Tensor, snr: int, min_ov...
function snr_mix (line 252) | def snr_mix(src: torch.Tensor, dst: torch.Tensor, snr_low: int, snr_high...
function mix_text (line 261) | def mix_text(src_text: str, dst_text: str):
function mix_samples (line 268) | def mix_samples(wavs: torch.Tensor, infos: tp.List[SoundInfo], aug_p: fl...
FILE: audiocraft/data/zip.py
class PathInZip (line 22) | class PathInZip:
method __init__ (line 36) | def __init__(self, path: str) -> None:
method from_paths (line 42) | def from_paths(cls, zip_path: str, file_path: str):
method __str__ (line 45) | def __str__(self) -> str:
function _open_zip (line 49) | def _open_zip(path: str, mode: MODE = 'r'):
function set_zip_cache_size (line 56) | def set_zip_cache_size(max_size: int):
function open_file_in_zip (line 66) | def open_file_in_zip(path_in_zip: PathInZip, mode: str = 'r') -> typing.IO:
FILE: audiocraft/environment.py
class AudioCraftEnvironment (line 25) | class AudioCraftEnvironment:
method __init__ (line 49) | def __init__(self) -> None:
method _get_cluster_config (line 74) | def _get_cluster_config(self) -> omegaconf.DictConfig:
method instance (line 79) | def instance(cls):
method reset (line 85) | def reset(cls):
method get_team (line 90) | def get_team(cls) -> str:
method get_cluster (line 97) | def get_cluster(cls) -> str:
method get_dora_dir (line 104) | def get_dora_dir(cls) -> Path:
method get_reference_dir (line 114) | def get_reference_dir(cls) -> Path:
method get_slurm_exclude (line 122) | def get_slurm_exclude(cls) -> tp.Optional[str]:
method get_slurm_partitions (line 128) | def get_slurm_partitions(cls, partition_types: tp.Optional[tp.List[str...
method resolve_reference_path (line 146) | def resolve_reference_path(cls, path: tp.Union[str, Path]) -> Path:
method apply_dataset_mappers (line 167) | def apply_dataset_mappers(cls, path: str) -> str:
FILE: audiocraft/grids/_base_explorers.py
function get_sheep_ping (line 14) | def get_sheep_ping(sheep) -> tp.Optional[str]:
class BaseExplorer (line 31) | class BaseExplorer(ABC, Explorer):
method stages (line 40) | def stages(self):
method get_grid_meta (line 43) | def get_grid_meta(self):
method get_grid_metrics (line 55) | def get_grid_metrics(self):
method process_sheep (line 60) | def process_sheep(self, sheep, history):
FILE: audiocraft/grids/audiogen/audiogen_base_16khz.py
function explorer (line 12) | def explorer(launcher):
FILE: audiocraft/grids/audiogen/audiogen_pretrained_16khz_eval.py
function eval (line 26) | def eval(launcher, batch_size: int = 32):
function explorer (line 49) | def explorer(launcher):
FILE: audiocraft/grids/compression/_explorers.py
class CompressionExplorer (line 12) | class CompressionExplorer(BaseExplorer):
method stages (line 15) | def stages(self):
method get_grid_meta (line 18) | def get_grid_meta(self):
method get_grid_metrics (line 28) | def get_grid_metrics(self):
FILE: audiocraft/grids/compression/debug.py
function explorer (line 22) | def explorer(launcher):
FILE: audiocraft/grids/compression/encodec_audiogen_16khz.py
function explorer (line 20) | def explorer(launcher):
FILE: audiocraft/grids/compression/encodec_base_24khz.py
function explorer (line 20) | def explorer(launcher):
FILE: audiocraft/grids/compression/encodec_musicgen_32khz.py
function explorer (line 20) | def explorer(launcher):
FILE: audiocraft/grids/diffusion/4_bands_base_32khz.py
function explorer (line 17) | def explorer(launcher):
FILE: audiocraft/grids/diffusion/_explorers.py
class DiffusionExplorer (line 12) | class DiffusionExplorer(BaseExplorer):
method stages (line 15) | def stages(self):
method get_grid_meta (line 18) | def get_grid_meta(self):
method get_grid_metrics (line 28) | def get_grid_metrics(self):
FILE: audiocraft/grids/magnet/audio_magnet_16khz.py
function explorer (line 12) | def explorer(launcher):
FILE: audiocraft/grids/magnet/audio_magnet_pretrained_16khz_eval.py
function eval (line 26) | def eval(launcher, batch_size: int = 32):
function explorer (line 47) | def explorer(launcher):
FILE: audiocraft/grids/magnet/magnet_32khz.py
function explorer (line 12) | def explorer(launcher):
FILE: audiocraft/grids/magnet/magnet_pretrained_32khz_eval.py
function eval (line 26) | def eval(launcher, batch_size: int = 32):
function explorer (line 47) | def explorer(launcher):
FILE: audiocraft/grids/musicgen/_explorers.py
class LMExplorer (line 14) | class LMExplorer(BaseExplorer):
method stages (line 17) | def stages(self) -> tp.List[str]:
method get_grid_metrics (line 20) | def get_grid_metrics(self):
method process_sheep (line 45) | def process_sheep(self, sheep, history):
class GenerationEvalExplorer (line 69) | class GenerationEvalExplorer(BaseExplorer):
method stages (line 72) | def stages(self) -> tp.List[str]:
method get_grid_metrics (line 75) | def get_grid_metrics(self):
FILE: audiocraft/grids/musicgen/musicgen_base_32khz.py
function explorer (line 12) | def explorer(launcher):
FILE: audiocraft/grids/musicgen/musicgen_base_cached_32khz.py
function explorer (line 12) | def explorer(launcher):
FILE: audiocraft/grids/musicgen/musicgen_clapemb_32khz.py
function explorer (line 12) | def explorer(launcher):
FILE: audiocraft/grids/musicgen/musicgen_melody_32khz.py
function explorer (line 12) | def explorer(launcher):
FILE: audiocraft/grids/musicgen/musicgen_pretrained_32khz_eval.py
function eval (line 26) | def eval(launcher, batch_size: int = 32, eval_melody: bool = False):
function explorer (line 63) | def explorer(launcher):
FILE: audiocraft/grids/musicgen/musicgen_stereo_finetune_32khz.py
function explorer (line 13) | def explorer(launcher):
FILE: audiocraft/grids/musicgen/musicgen_style_32khz.py
function explorer (line 11) | def explorer(launcher):
FILE: audiocraft/grids/watermarking/_explorers.py
class WatermarkingMbExplorer (line 12) | class WatermarkingMbExplorer(BaseExplorer):
method stages (line 15) | def stages(self):
method get_grid_meta (line 18) | def get_grid_meta(self):
method get_grid_metrics (line 27) | def get_grid_metrics(self):
class WatermarkingExplorer (line 66) | class WatermarkingExplorer(BaseExplorer):
method stages (line 69) | def stages(self):
method get_grid_meta (line 72) | def get_grid_meta(self):
method get_grid_metrics (line 81) | def get_grid_metrics(self):
FILE: audiocraft/grids/watermarking/audioseal.py
function explorer (line 15) | def explorer(launcher):
FILE: audiocraft/grids/watermarking/kbits.py
function explorer (line 16) | def explorer(launcher):
FILE: audiocraft/losses/balancer.py
class Balancer (line 14) | class Balancer:
method __init__ (line 61) | def __init__(self, weights: tp.Dict[str, float], balance_grads: bool =...
method metrics (line 74) | def metrics(self):
method backward (line 77) | def backward(self, losses: tp.Dict[str, torch.Tensor], input: torch.Te...
FILE: audiocraft/losses/loudnessloss.py
function basic_loudness (line 18) | def basic_loudness(waveform: torch.Tensor, sample_rate: int) -> torch.Te...
function _unfold (line 53) | def _unfold(a: torch.Tensor, kernel_size: int, stride: int) -> torch.Ten...
class FLoudnessRatio (line 69) | class FLoudnessRatio(nn.Module):
method __init__ (line 82) | def __init__(
method forward (line 101) | def forward(self, out_sig: torch.Tensor, ref_sig: torch.Tensor) -> tor...
class TLoudnessRatio (line 114) | class TLoudnessRatio(nn.Module):
method __init__ (line 125) | def __init__(
method forward (line 137) | def forward(self, out_sig: torch.Tensor, ref_sig: torch.Tensor) -> tor...
class TFLoudnessRatio (line 153) | class TFLoudnessRatio(nn.Module):
method __init__ (line 166) | def __init__(
method forward (line 187) | def forward(self, out_sig: torch.Tensor, ref_sig: torch.Tensor) -> tor...
FILE: audiocraft/losses/sisnr.py
function _unfold (line 15) | def _unfold(a: torch.Tensor, kernel_size: int, stride: int) -> torch.Ten...
function _center (line 31) | def _center(x: torch.Tensor) -> torch.Tensor:
function _norm2 (line 35) | def _norm2(x: torch.Tensor) -> torch.Tensor:
class SISNR (line 39) | class SISNR(nn.Module):
method __init__ (line 56) | def __init__(
method forward (line 69) | def forward(self, out_sig: torch.Tensor, ref_sig: torch.Tensor) -> tor...
FILE: audiocraft/losses/specloss.py
class MelSpectrogramWrapper (line 18) | class MelSpectrogramWrapper(nn.Module):
method __init__ (line 35) | def __init__(self, n_fft: int = 1024, hop_length: int = 256, win_lengt...
method forward (line 48) | def forward(self, x):
class MelSpectrogramL1Loss (line 65) | class MelSpectrogramL1Loss(torch.nn.Module):
method __init__ (line 80) | def __init__(self, sample_rate: int, n_fft: int = 1024, hop_length: in...
method forward (line 89) | def forward(self, x, y):
class MultiScaleMelSpectrogramLoss (line 96) | class MultiScaleMelSpectrogramLoss(nn.Module):
method __init__ (line 110) | def __init__(self, sample_rate: int, range_start: int = 6, range_end: ...
method forward (line 137) | def forward(self, x, y):
FILE: audiocraft/losses/stftloss.py
function _stft (line 17) | def _stft(x: torch.Tensor, fft_size: int, hop_length: int, win_length: int,
class SpectralConvergenceLoss (line 45) | class SpectralConvergenceLoss(nn.Module):
method __init__ (line 48) | def __init__(self, epsilon: float = torch.finfo(torch.float32).eps):
method forward (line 52) | def forward(self, x_mag: torch.Tensor, y_mag: torch.Tensor):
class LogSTFTMagnitudeLoss (line 64) | class LogSTFTMagnitudeLoss(nn.Module):
method __init__ (line 70) | def __init__(self, epsilon: float = torch.finfo(torch.float32).eps):
method forward (line 74) | def forward(self, x_mag: torch.Tensor, y_mag: torch.Tensor):
class STFTLosses (line 86) | class STFTLosses(nn.Module):
method __init__ (line 97) | def __init__(self, n_fft: int = 1024, hop_length: int = 120, win_lengt...
method forward (line 109) | def forward(self, x: torch.Tensor, y: torch.Tensor) -> tp.Tuple[torch....
class STFTLoss (line 129) | class STFTLoss(nn.Module):
method __init__ (line 142) | def __init__(self, n_fft: int = 1024, hop_length: int = 120, win_lengt...
method forward (line 151) | def forward(self, x: torch.Tensor, y: torch.Tensor) -> tp.Tuple[torch....
class MRSTFTLoss (line 164) | class MRSTFTLoss(nn.Module):
method __init__ (line 177) | def __init__(self, n_ffts: tp.Sequence[int] = [1024, 2048, 512], hop_l...
method forward (line 189) | def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
FILE: audiocraft/losses/wmloss.py
class WMDetectionLoss (line 13) | class WMDetectionLoss(nn.Module):
method __init__ (line 15) | def __init__(self, p_weight: float = 1.0, n_weight: float = 1.0) -> None:
method forward (line 21) | def forward(self, positive, negative, mask, message=None):
class WMMbLoss (line 55) | class WMMbLoss(nn.Module):
method __init__ (line 56) | def __init__(self, temperature: float, loss_type: Literal["bce", "mse"...
method forward (line 73) | def forward(self, positive, negative, mask, message):
FILE: audiocraft/metrics/chroma_cosinesim.py
class ChromaCosineSimilarityMetric (line 14) | class ChromaCosineSimilarityMetric(torchmetrics.Metric):
method __init__ (line 28) | def __init__(self, sample_rate: int, n_chroma: int, radix2_exp: int, a...
method update (line 38) | def update(self, preds: torch.Tensor, targets: torch.Tensor,
method compute (line 69) | def compute(self) -> float:
FILE: audiocraft/metrics/clap_consistency.py
class TextConsistencyMetric (line 24) | class TextConsistencyMetric(torchmetrics.Metric):
method update (line 27) | def update(self, audio: torch.Tensor, text: tp.List[str], sizes: torch...
method compute (line 30) | def compute(self):
class CLAPTextConsistencyMetric (line 34) | class CLAPTextConsistencyMetric(TextConsistencyMetric):
method __init__ (line 47) | def __init__(self, model_path: tp.Union[str, Path], model_arch: str = ...
method _initialize_model (line 55) | def _initialize_model(self, model_path: tp.Union[str, Path], model_arc...
method _tokenizer (line 63) | def _tokenizer(self, texts: tp.Union[str, tp.List[str]]) -> dict:
method update (line 67) | def update(self, audio: torch.Tensor, text: tp.List[str], sizes: torch...
method compute (line 81) | def compute(self):
FILE: audiocraft/metrics/fad.py
class FrechetAudioDistanceMetric (line 29) | class FrechetAudioDistanceMetric(torchmetrics.Metric):
method __init__ (line 145) | def __init__(self, bin: tp.Union[Path, str], model_path: tp.Union[Path...
method reset (line 167) | def reset(self, log_folder: tp.Optional[tp.Union[Path, str]] = None):
method update (line 182) | def update(self, preds: torch.Tensor, targets: torch.Tensor,
method _get_samples_name (line 222) | def _get_samples_name(self, is_background: bool):
method _create_embedding_beams (line 225) | def _create_embedding_beams(self, is_background: bool, gpu_index: tp.O...
method _compute_fad_score (line 259) | def _compute_fad_score(self, gpu_index: tp.Optional[int] = None):
method _log_process_result (line 283) | def _log_process_result(self, returncode: int, log_file: tp.Union[Path...
method _parallel_create_embedding_beams (line 293) | def _parallel_create_embedding_beams(self, num_of_gpus: int):
method _sequential_create_embedding_beams (line 303) | def _sequential_create_embedding_beams(self):
method _local_compute_frechet_audio_distance (line 313) | def _local_compute_frechet_audio_distance(self):
method compute (line 323) | def compute(self) -> float:
FILE: audiocraft/metrics/kld.py
class _patch_passt_stft (line 22) | class _patch_passt_stft:
method __init__ (line 24) | def __init__(self):
method __enter__ (line 27) | def __enter__(self):
method __exit__ (line 32) | def __exit__(self, *exc):
function kl_divergence (line 36) | def kl_divergence(pred_probs: torch.Tensor, target_probs: torch.Tensor, ...
class KLDivergenceMetric (line 53) | class KLDivergenceMetric(torchmetrics.Metric):
method __init__ (line 62) | def __init__(self):
method _get_label_distribution (line 69) | def _get_label_distribution(self, x: torch.Tensor, sizes: torch.Tensor,
method update (line 82) | def update(self, preds: torch.Tensor, targets: torch.Tensor,
method compute (line 105) | def compute(self) -> dict:
class PasstKLDivergenceMetric (line 116) | class PasstKLDivergenceMetric(KLDivergenceMetric):
method __init__ (line 131) | def __init__(self, pretrained_length: tp.Optional[float] = None):
method _initialize_model (line 135) | def _initialize_model(self, pretrained_length: tp.Optional[float] = No...
method _load_base_model (line 145) | def _load_base_model(self, pretrained_length: tp.Optional[float]):
method _process_audio (line 172) | def _process_audio(self, wav: torch.Tensor, sample_rate: int, wav_len:...
method _get_model_preds (line 187) | def _get_model_preds(self, wav: torch.Tensor) -> torch.Tensor:
method _get_label_distribution (line 198) | def _get_label_distribution(self, x: torch.Tensor, sizes: torch.Tensor,
FILE: audiocraft/metrics/miou.py
function calculate_miou (line 10) | def calculate_miou(y_pred: torch.Tensor, y_true: torch.Tensor) -> float:
FILE: audiocraft/metrics/pesq.py
class PesqMetric (line 14) | class PesqMetric(torchmetrics.Metric):
method __init__ (line 23) | def __init__(self, sample_rate: int):
method update (line 30) | def update(self, preds: torch.Tensor, targets: torch.Tensor):
method compute (line 45) | def compute(self) -> torch.Tensor:
FILE: audiocraft/metrics/rvm.py
function db_to_scale (line 13) | def db_to_scale(volume: tp.Union[float, torch.Tensor]):
function scale_to_db (line 17) | def scale_to_db(scale: torch.Tensor, min_volume: float = -120):
class RelativeVolumeMel (line 22) | class RelativeVolumeMel(nn.Module):
method __init__ (line 69) | def __init__(self, sample_rate: int = 24000, n_mels: int = 80, n_fft: ...
method forward (line 84) | def forward(self, estimate: torch.Tensor, ground_truth: torch.Tensor) ...
FILE: audiocraft/metrics/visqol.py
class ViSQOL (line 22) | class ViSQOL:
method __init__ (line 56) | def __init__(self, bin: tp.Union[Path, str], mode: str = "audio",
method _get_target_sr (line 67) | def _get_target_sr(self, mode: str) -> int:
method _prepare_files (line 75) | def _prepare_files(
method _flush_files (line 132) | def _flush_files(self, tmp_dir: tp.Union[Path, str]):
method _collect_moslqo_score (line 136) | def _collect_moslqo_score(self, results_csv_path: tp.Union[Path, str])...
method _collect_debug_data (line 146) | def _collect_debug_data(self, debug_json_path: tp.Union[Path, str]) ->...
method visqol_model (line 153) | def visqol_model(self):
method _run_visqol (line 156) | def _run_visqol(
method __call__ (line 181) | def __call__(
FILE: audiocraft/models/audiogen.py
class AudioGen (line 23) | class AudioGen(BaseGenModel):
method __init__ (line 34) | def __init__(self, name: str, compression_model: CompressionModel, lm:...
method get_pretrained (line 40) | def get_pretrained(name: str = 'facebook/audiogen-medium', device=None):
method set_generation_params (line 63) | def set_generation_params(self, use_sampling: bool = True, top_k: int ...
FILE: audiocraft/models/builders.py
function get_quantizer (line 44) | def get_quantizer(
function get_encodec_autoencoder (line 56) | def get_encodec_autoencoder(encoder_name: str, cfg: omegaconf.DictConfig):
function get_compression_model (line 70) | def get_compression_model(cfg: omegaconf.DictConfig) -> CompressionModel:
function get_jasco_model (line 94) | def get_jasco_model(cfg: omegaconf.DictConfig,
function get_lm_model (line 136) | def get_lm_model(cfg: omegaconf.DictConfig) -> LMModel:
function get_conditioner_provider (line 178) | def get_conditioner_provider(
function get_condition_fuser (line 230) | def get_condition_fuser(cfg: omegaconf.DictConfig) -> ConditionFuser:
function get_codebooks_pattern_provider (line 240) | def get_codebooks_pattern_provider(
function get_debug_compression_model (line 257) | def get_debug_compression_model(device="cpu", sample_rate: int = 32000):
function get_diffusion_model (line 291) | def get_diffusion_model(cfg: omegaconf.DictConfig):
function get_processor (line 298) | def get_processor(cfg, sample_rate: int = 24000):
function get_debug_lm_model (line 309) | def get_debug_lm_model(device="cpu"):
function get_wrapped_compression_model (line 338) | def get_wrapped_compression_model(
function get_watermark_model (line 354) | def get_watermark_model(cfg: omegaconf.DictConfig) -> WMModel:
FILE: audiocraft/models/encodec.py
class CompressionModel (line 28) | class CompressionModel(ABC, nn.Module):
method forward (line 34) | def forward(self, x: torch.Tensor) -> qt.QuantizedResult:
method encode (line 38) | def encode(self, x: torch.Tensor) -> tp.Tuple[torch.Tensor, tp.Optiona...
method decode (line 43) | def decode(self, codes: torch.Tensor, scale: tp.Optional[torch.Tensor]...
method decode_latent (line 48) | def decode_latent(self, codes: torch.Tensor):
method channels (line 54) | def channels(self) -> int:
method frame_rate (line 59) | def frame_rate(self) -> float:
method sample_rate (line 64) | def sample_rate(self) -> int:
method cardinality (line 69) | def cardinality(self) -> int:
method num_codebooks (line 74) | def num_codebooks(self) -> int:
method total_codebooks (line 79) | def total_codebooks(self) -> int:
method set_num_codebooks (line 83) | def set_num_codebooks(self, n: int):
method get_pretrained (line 88) | def get_pretrained(
class EncodecModel (line 125) | class EncodecModel(CompressionModel):
method __init__ (line 144) | def __init__(self,
method total_codebooks (line 168) | def total_codebooks(self):
method num_codebooks (line 173) | def num_codebooks(self):
method set_num_codebooks (line 177) | def set_num_codebooks(self, n: int):
method cardinality (line 182) | def cardinality(self):
method preprocess (line 186) | def preprocess(self, x: torch.Tensor) -> tp.Tuple[torch.Tensor, tp.Opt...
method postprocess (line 198) | def postprocess(self,
method forward (line 206) | def forward(self, x: torch.Tensor) -> qt.QuantizedResult:
method encode (line 223) | def encode(self, x: torch.Tensor) -> tp.Tuple[torch.Tensor, tp.Optiona...
method decode (line 240) | def decode(self, codes: torch.Tensor, scale: tp.Optional[torch.Tensor]...
method decode_latent (line 257) | def decode_latent(self, codes: torch.Tensor):
class DAC (line 262) | class DAC(CompressionModel):
method __init__ (line 263) | def __init__(self, model_type: str = "44khz"):
method forward (line 274) | def forward(self, x: torch.Tensor) -> qt.QuantizedResult:
method encode (line 278) | def encode(self, x: torch.Tensor) -> tp.Tuple[torch.Tensor, tp.Optiona...
method decode (line 282) | def decode(self, codes: torch.Tensor, scale: tp.Optional[torch.Tensor]...
method decode_latent (line 287) | def decode_latent(self, codes: torch.Tensor):
method channels (line 292) | def channels(self) -> int:
method frame_rate (line 296) | def frame_rate(self) -> float:
method sample_rate (line 300) | def sample_rate(self) -> int:
method cardinality (line 304) | def cardinality(self) -> int:
method num_codebooks (line 308) | def num_codebooks(self) -> int:
method total_codebooks (line 312) | def total_codebooks(self) -> int:
method set_num_codebooks (line 315) | def set_num_codebooks(self, n: int):
class HFEncodecCompressionModel (line 323) | class HFEncodecCompressionModel(CompressionModel):
method __init__ (line 326) | def __init__(self, model: HFEncodecModel):
method forward (line 340) | def forward(self, x: torch.Tensor) -> qt.QuantizedResult:
method encode (line 344) | def encode(self, x: torch.Tensor) -> tp.Tuple[torch.Tensor, tp.Optiona...
method decode (line 352) | def decode(self, codes: torch.Tensor, scale: tp.Optional[torch.Tensor]...
method decode_latent (line 360) | def decode_latent(self, codes: torch.Tensor):
method channels (line 365) | def channels(self) -> int:
method frame_rate (line 369) | def frame_rate(self) -> float:
method sample_rate (line 374) | def sample_rate(self) -> int:
method cardinality (line 378) | def cardinality(self) -> int:
method num_codebooks (line 382) | def num_codebooks(self) -> int:
method total_codebooks (line 386) | def total_codebooks(self) -> int:
method set_num_codebooks (line 389) | def set_num_codebooks(self, n: int):
class InterleaveStereoCompressionModel (line 397) | class InterleaveStereoCompressionModel(CompressionModel):
method __init__ (line 409) | def __init__(self, model: CompressionModel, per_timestep: bool = False):
method total_codebooks (line 416) | def total_codebooks(self):
method num_codebooks (line 420) | def num_codebooks(self):
method set_num_codebooks (line 428) | def set_num_codebooks(self, n: int):
method num_virtual_steps (line 436) | def num_virtual_steps(self) -> float:
method frame_rate (line 443) | def frame_rate(self) -> float:
method sample_rate (line 447) | def sample_rate(self) -> int:
method channels (line 451) | def channels(self) -> int:
method cardinality (line 455) | def cardinality(self):
method forward (line 460) | def forward(self, x: torch.Tensor) -> qt.QuantizedResult:
method encode (line 463) | def encode(self, x: torch.Tensor) -> tp.Tuple[torch.Tensor, tp.Optiona...
method get_left_right_codes (line 481) | def get_left_right_codes(self, codes: torch.Tensor) -> tp.Tuple[torch....
method decode (line 488) | def decode(self, codes: torch.Tensor, scale: tp.Optional[torch.Tensor]...
method decode_latent (line 504) | def decode_latent(self, codes: torch.Tensor):
FILE: audiocraft/models/flow_matching.py
class FMOutput (line 35) | class FMOutput:
class CFGTerm (line 40) | class CFGTerm:
method __init__ (line 48) | def __init__(self, conditions, weight):
method drop_irrelevant_conds (line 52) | def drop_irrelevant_conds(self, conditions):
class AllCFGTerm (line 63) | class AllCFGTerm(CFGTerm):
method __init__ (line 67) | def __init__(self, conditions, weight):
method drop_irrelevant_conds (line 71) | def drop_irrelevant_conds(self):
class NullCFGTerm (line 75) | class NullCFGTerm(CFGTerm):
method __init__ (line 79) | def __init__(self, conditions, weight):
method drop_irrelevant_conds (line 83) | def drop_irrelevant_conds(self):
class TextCFGTerm (line 92) | class TextCFGTerm(CFGTerm):
method __init__ (line 97) | def __init__(self, conditions, weight, model_att_dropout):
method drop_irrelevant_conds (line 116) | def drop_irrelevant_conds(self):
class FlowMatchingModel (line 121) | class FlowMatchingModel(StreamingModule):
method __init__ (line 150) | def __init__(self, condition_provider: JascoConditioningProvider,
method _get_timestep_embedding (line 209) | def _get_timestep_embedding(self, timesteps, embedding_dim):
method _embed_time_parameter (line 232) | def _embed_time_parameter(self, t: torch.Tensor):
method _init_weights (line 244) | def _init_weights(self, weight_init: tp.Optional[str], depthwise_init:...
method _align_seq_length (line 276) | def _align_seq_length(self,
method forward (line 289) | def forward(self,
method _multi_source_cfg_preprocess (line 345) | def _multi_source_cfg_preprocess(self,
method estimated_vector_field (line 386) | def estimated_vector_field(self, z, t, condition_tensors=None, cfg_ter...
method _multi_source_cfg_postprocess (line 403) | def _multi_source_cfg_postprocess(self, v_thetas, cfg_terms):
method generate (line 419) | def generate(self,
FILE: audiocraft/models/genmodel.py
class BaseGenModel (line 28) | class BaseGenModel(ABC):
method __init__ (line 39) | def __init__(self, name: str, compression_model: CompressionModel, lm:...
method frame_rate (line 81) | def frame_rate(self) -> float:
method sample_rate (line 86) | def sample_rate(self) -> int:
method audio_channels (line 91) | def audio_channels(self) -> int:
method set_custom_progress_callback (line 95) | def set_custom_progress_callback(self, progress_callback: tp.Optional[...
method set_generation_params (line 100) | def set_generation_params(self, *args, **kwargs):
method get_pretrained (line 106) | def get_pretrained(name: str, device=None):
method _prepare_tokens_and_attributes (line 110) | def _prepare_tokens_and_attributes(
method generate_unconditional (line 135) | def generate_unconditional(self, num_samples: int, progress: bool = Fa...
method generate (line 151) | def generate(self, descriptions: tp.List[str], progress: bool = False,...
method generate_continuation (line 166) | def generate_continuation(self, prompt: torch.Tensor, prompt_sample_ra...
method _generate_tokens (line 193) | def _generate_tokens(self, attributes: tp.List[ConditioningAttributes],
method generate_audio (line 262) | def generate_audio(self, gen_tokens: torch.Tensor) -> torch.Tensor:
FILE: audiocraft/models/jasco.py
class JASCO (line 24) | class JASCO(BaseGenModel):
method __init__ (line 30) | def __init__(self, chords_mapping_path='assets/chord_to_index_mapping....
method get_pretrained (line 43) | def get_pretrained(name: str = 'facebook/jasco-chords-drums-400M', dev...
method set_generation_params (line 66) | def set_generation_params(self,
method _unnormalized_latents (line 85) | def _unnormalized_latents(self, latents: torch.Tensor) -> torch.Tensor:
method generate_audio (line 91) | def generate_audio(self, gen_latents: torch.Tensor) -> torch.Tensor:
method _generate_tokens (line 99) | def _generate_tokens(self, attributes: tp.List[ConditioningAttributes],
method _prepare_chord_conditions (line 137) | def _prepare_chord_conditions(
method _prepare_drums_conditions (line 176) | def _prepare_drums_conditions(self,
method _prepare_melody_conditions (line 214) | def _prepare_melody_conditions(
method _prepare_temporal_conditions (line 240) | def _prepare_temporal_conditions(
method generate_music (line 269) | def generate_music(
method generate (line 318) | def generate(self, descriptions: tp.List[str], progress: bool = False,...
FILE: audiocraft/models/lm.py
function get_init_fn (line 37) | def get_init_fn(method: str, input_dim: int, init_depth: tp.Optional[int...
function init_layer (line 65) | def init_layer(m: nn.Module,
class ScaledEmbedding (line 98) | class ScaledEmbedding(nn.Embedding):
method __init__ (line 101) | def __init__(self, *args, lr=None, **kwargs):
method make_optim_group (line 105) | def make_optim_group(self):
class LMOutput (line 113) | class LMOutput:
class LMModel (line 120) | class LMModel(StreamingModule):
method __init__ (line 145) | def __init__(self, pattern_provider: CodebooksPatternProvider, conditi...
method _init_weights (line 179) | def _init_weights(self, weight_init: tp.Optional[str], depthwise_init:...
method special_token_id (line 214) | def special_token_id(self) -> int:
method num_codebooks (line 218) | def num_codebooks(self) -> int:
method forward (line 221) | def forward(self, sequence: torch.Tensor,
method compute_predictions (line 270) | def compute_predictions(
method _sample_next_token (line 323) | def _sample_next_token(self,
method generate (line 421) | def generate(self,
FILE: audiocraft/models/lm_magnet.py
class MagnetLMModel (line 26) | class MagnetLMModel(LMModel):
method __init__ (line 37) | def __init__(self, subcodes_context: int = 5, compression_model_framer...
method restricted_context_attn_mask (line 48) | def restricted_context_attn_mask(self, seq_len: int, device: torch.dev...
method _stage_attn_mask (line 69) | def _stage_attn_mask(self, stage: int, seq_len: int, num_heads: int,
method _build_attn_masks (line 102) | def _build_attn_masks(self, compression_model_framerate: int, segment_...
method generate (line 118) | def generate(self,
method _generate_magnet (line 152) | def _generate_magnet(self,
method _generate_stage (line 265) | def _generate_stage(self,
method _construct_spans_mask (line 442) | def _construct_spans_mask(self, span_starts: torch.Tensor, T: int, dev...
method _least_probable_span_masking (line 461) | def _least_probable_span_masking(self, scores: torch.Tensor, num_maske...
FILE: audiocraft/models/loaders.py
function get_audiocraft_cache_dir (line 36) | def get_audiocraft_cache_dir() -> tp.Optional[str]:
function _get_state_dict (line 40) | def _get_state_dict(
function load_compression_model_ckpt (line 74) | def load_compression_model_ckpt(file_or_url_or_id: tp.Union[Path, str], ...
function load_compression_model (line 78) | def load_compression_model(
function load_lm_model_ckpt (line 94) | def load_lm_model_ckpt(file_or_url_or_id: tp.Union[Path, str], cache_dir...
function _delete_param (line 98) | def _delete_param(cfg: DictConfig, full_name: str):
function load_lm_model (line 111) | def load_lm_model(file_or_url_or_id: tp.Union[Path, str], device='cpu', ...
function load_lm_model_magnet (line 129) | def load_lm_model_magnet(file_or_url_or_id: tp.Union[Path, str], compres...
function load_jasco_model (line 158) | def load_jasco_model(file_or_url_or_id: tp.Union[Path, str],
function load_mbd_ckpt (line 175) | def load_mbd_ckpt(file_or_url_or_id: tp.Union[Path, str],
function load_diffusion_models (line 181) | def load_diffusion_models(file_or_url_or_id: tp.Union[Path, str],
function load_audioseal_models (line 206) | def load_audioseal_models(
FILE: audiocraft/models/magnet.py
class MAGNeT (line 18) | class MAGNeT(BaseGenModel):
method __init__ (line 23) | def __init__(self, **kwargs):
method get_pretrained (line 30) | def get_pretrained(name: str = 'facebook/magnet-small-10secs', device=...
method set_generation_params (line 60) | def set_generation_params(self, use_sampling: bool = True, top_k: int ...
FILE: audiocraft/models/multibanddiffusion.py
class DiffusionProcess (line 25) | class DiffusionProcess:
method __init__ (line 32) | def __init__(self, model: DiffusionUnet, noise_schedule: NoiseSchedule...
method generate (line 36) | def generate(self, condition: torch.Tensor, initial_noise: torch.Tensor,
class MultiBandDiffusion (line 48) | class MultiBandDiffusion:
method __init__ (line 55) | def __init__(self, DPs: tp.List[DiffusionProcess], codec_model: Compre...
method sample_rate (line 61) | def sample_rate(self) -> int:
method get_mbd_musicgen (line 65) | def get_mbd_musicgen(device=None):
method get_mbd_24khz (line 81) | def get_mbd_24khz(bw: float = 3.0,
method get_condition (line 113) | def get_condition(self, wav: torch.Tensor, sample_rate: int) -> torch....
method get_emb (line 126) | def get_emb(self, codes: torch.Tensor):
method generate (line 133) | def generate(self, emb: torch.Tensor, size: tp.Optional[torch.Size] = ...
method re_eq (line 151) | def re_eq(self, wav: torch.Tensor, ref: torch.Tensor, n_bands: int = 3...
method regenerate (line 167) | def regenerate(self, wav: torch.Tensor, sample_rate: int):
method tokens_to_wav (line 182) | def tokens_to_wav(self, tokens: torch.Tensor, n_bands: int = 32):
FILE: audiocraft/models/musicgen.py
class MusicGen (line 40) | class MusicGen(BaseGenModel):
method __init__ (line 51) | def __init__(self, name: str, compression_model: CompressionModel, lm:...
method get_pretrained (line 57) | def get_pretrained(name: str = 'facebook/musicgen-melody', device=None):
method set_generation_params (line 96) | def set_generation_params(self, use_sampling: bool = True, top_k: int ...
method set_style_conditioner_params (line 134) | def set_style_conditioner_params(self, eval_q: int = 3, excerpt_length...
method generate_with_chroma (line 155) | def generate_with_chroma(self, descriptions: tp.List[str], melody_wavs...
method _prepare_tokens_and_attributes (line 194) | def _prepare_tokens_and_attributes(
method _generate_tokens (line 251) | def _generate_tokens(self, attributes: tp.List[ConditioningAttributes],
FILE: audiocraft/models/unet.py
class Output (line 21) | class Output:
function get_model (line 25) | def get_model(cfg, channels: int, side: int, num_steps: int):
class ResBlock (line 33) | class ResBlock(nn.Module):
method __init__ (line 34) | def __init__(self, channels: int, kernel: int = 3, norm_groups: int = 4,
method forward (line 52) | def forward(self, x):
class DecoderLayer (line 58) | class DecoderLayer(nn.Module):
method __init__ (line 59) | def __init__(self, chin: int, chout: int, kernel: int = 4, stride: int...
method forward (line 72) | def forward(self, x: torch.Tensor) -> torch.Tensor:
class EncoderLayer (line 80) | class EncoderLayer(nn.Module):
method __init__ (line 81) | def __init__(self, chin: int, chout: int, kernel: int = 4, stride: int...
method forward (line 94) | def forward(self, x: torch.Tensor) -> torch.Tensor:
class BLSTM (line 107) | class BLSTM(nn.Module):
method __init__ (line 110) | def __init__(self, dim, layers=2):
method forward (line 115) | def forward(self, x):
class DiffusionUnet (line 123) | class DiffusionUnet(nn.Module):
method __init__ (line 124) | def __init__(self, chin: int = 3, hidden: int = 24, depth: int = 3, gr...
method forward (line 163) | def forward(self, x: torch.Tensor, step: tp.Union[int, torch.Tensor], ...
FILE: audiocraft/models/watermark.py
class WMModel (line 17) | class WMModel(ABC, nn.Module):
method get_watermark (line 24) | def get_watermark(
method detect_watermark (line 36) | def detect_watermark(self, x: torch.Tensor) -> torch.Tensor:
class AudioSeal (line 49) | class AudioSeal(WMModel):
method __init__ (line 54) | def __init__(
method get_watermark (line 67) | def get_watermark(
method detect_watermark (line 75) | def detect_watermark(self, x: torch.Tensor) -> torch.Tensor:
method forward (line 93) | def forward( # generator
method get_pretrained (line 105) | def get_pretrained(name="base", device=None) -> WMModel:
FILE: audiocraft/modules/activations.py
class CustomGLU (line 13) | class CustomGLU(nn.Module):
method __init__ (line 33) | def __init__(self, activation: nn.Module, dim: int = -1):
method forward (line 38) | def forward(self, x: Tensor):
class SwiGLU (line 44) | class SwiGLU(CustomGLU):
method __init__ (line 52) | def __init__(self, dim: int = -1):
class GeGLU (line 56) | class GeGLU(CustomGLU):
method __init__ (line 64) | def __init__(self, dim: int = -1):
class ReGLU (line 68) | class ReGLU(CustomGLU):
method __init__ (line 76) | def __init__(self, dim: int = -1):
function get_activation_fn (line 80) | def get_activation_fn(
FILE: audiocraft/modules/chroma.py
class ChromaExtractor (line 16) | class ChromaExtractor(nn.Module):
method __init__ (line 29) | def __init__(self, sample_rate: int, n_chroma: int = 12, radix2_exp: i...
method forward (line 46) | def forward(self, wav: torch.Tensor) -> torch.Tensor:
FILE: audiocraft/modules/codebooks_patterns.py
class Pattern (line 22) | class Pattern:
method __post_init__ (line 50) | def __post_init__(self):
method _validate_layout (line 57) | def _validate_layout(self):
method num_sequence_steps (line 79) | def num_sequence_steps(self):
method max_delay (line 83) | def max_delay(self):
method valid_layout (line 91) | def valid_layout(self):
method starts_with_special_token (line 95) | def starts_with_special_token(self):
method get_sequence_coords_with_timestep (line 98) | def get_sequence_coords_with_timestep(self, t: int, q: tp.Optional[int...
method get_steps_with_timestep (line 113) | def get_steps_with_timestep(self, t: int, q: tp.Optional[int] = None) ...
method get_first_step_with_timesteps (line 116) | def get_first_step_with_timesteps(self, t: int, q: tp.Optional[int] = ...
method _build_pattern_sequence_scatter_indexes (line 120) | def _build_pattern_sequence_scatter_indexes(self, timesteps: int, n_q:...
method build_pattern_sequence (line 154) | def build_pattern_sequence(self, z: torch.Tensor, special_token: int, ...
method _build_reverted_sequence_scatter_indexes (line 181) | def _build_reverted_sequence_scatter_indexes(self, sequence_steps: int...
method revert_pattern_sequence (line 225) | def revert_pattern_sequence(self, s: torch.Tensor, special_token: int,...
method revert_pattern_logits (line 250) | def revert_pattern_logits(self, logits: torch.Tensor, special_token: f...
class CodebooksPatternProvider (line 272) | class CodebooksPatternProvider(ABC):
method __init__ (line 290) | def __init__(self, n_q: int, cached: bool = True):
method get_pattern (line 296) | def get_pattern(self, timesteps: int) -> Pattern:
class DelayedPatternProvider (line 305) | class DelayedPatternProvider(CodebooksPatternProvider):
method __init__ (line 328) | def __init__(self, n_q: int, delays: tp.Optional[tp.List[int]] = None,
method get_pattern (line 339) | def get_pattern(self, timesteps: int) -> Pattern:
class ParallelPatternProvider (line 359) | class ParallelPatternProvider(DelayedPatternProvider):
method __init__ (line 368) | def __init__(self, n_q: int, empty_initial: int = 0):
class UnrolledPatternProvider (line 372) | class UnrolledPatternProvider(CodebooksPatternProvider):
method __init__ (line 423) | def __init__(self, n_q: int, flattening: tp.Optional[tp.List[int]] = N...
method _build_flattened_codebooks (line 437) | def _build_flattened_codebooks(self, delays: tp.List[int], flattening:...
method _num_inner_steps (line 457) | def _num_inner_steps(self):
method num_virtual_steps (line 462) | def num_virtual_steps(self, timesteps: int) -> int:
method get_pattern (line 465) | def get_pattern(self, timesteps: int) -> Pattern:
class CoarseFirstPattern (line 493) | class CoarseFirstPattern(CodebooksPatternProvider):
method __init__ (line 507) | def __init__(self, n_q: int, delays: tp.Optional[tp.List[int]] = None):
method get_pattern (line 515) | def get_pattern(self, timesteps: int) -> Pattern:
class MusicLMPattern (line 530) | class MusicLMPattern(CodebooksPatternProvider):
method __init__ (line 538) | def __init__(self, n_q: int, group_by: int = 2):
method get_pattern (line 542) | def get_pattern(self, timesteps: int) -> Pattern:
FILE: audiocraft/modules/conditioners.py
class JascoCondConst (line 46) | class JascoCondConst(Enum):
class WavCondition (line 55) | class WavCondition(tp.NamedTuple):
class JointEmbedCondition (line 63) | class JointEmbedCondition(tp.NamedTuple):
class SymbolicCondition (line 72) | class SymbolicCondition(tp.NamedTuple):
class ConditioningAttributes (line 78) | class ConditioningAttributes:
method __getitem__ (line 84) | def __getitem__(self, item):
method text_attributes (line 88) | def text_attributes(self):
method wav_attributes (line 92) | def wav_attributes(self):
method joint_embed_attributes (line 96) | def joint_embed_attributes(self):
method symbolic_attributes (line 100) | def symbolic_attributes(self):
method attributes (line 104) | def attributes(self):
method to_flat_dict (line 112) | def to_flat_dict(self):
method from_flat_dict (line 121) | def from_flat_dict(cls, x):
class SegmentWithAttributes (line 129) | class SegmentWithAttributes(SegmentInfo):
method to_condition_attributes (line 134) | def to_condition_attributes(self) -> ConditioningAttributes:
function nullify_condition (line 138) | def nullify_condition(condition: ConditionType, dim: int = 1):
function nullify_wav (line 165) | def nullify_wav(cond: WavCondition) -> WavCondition:
function nullify_joint_embed (line 184) | def nullify_joint_embed(embed: JointEmbedCondition) -> JointEmbedCondition:
function nullify_chords (line 201) | def nullify_chords(sym_cond: SymbolicCondition, null_chord_idx: int = 19...
function nullify_melody (line 212) | def nullify_melody(sym_cond: SymbolicCondition) -> SymbolicCondition:
function _drop_description_condition (line 223) | def _drop_description_condition(conditions: tp.List[ConditioningAttribut...
class Tokenizer (line 239) | class Tokenizer:
method __call__ (line 243) | def __call__(self, texts: tp.List[tp.Optional[str]]) -> tp.Tuple[torch...
class WhiteSpaceTokenizer (line 247) | class WhiteSpaceTokenizer(Tokenizer):
method __init__ (line 256) | def __init__(self, n_bins: int, pad_idx: int = 0, language: str = "en_...
method __call__ (line 269) | def __call__(self, texts: tp.List[tp.Optional[str]],
class NoopTokenizer (line 315) | class NoopTokenizer(Tokenizer):
method __init__ (line 325) | def __init__(self, n_bins: int, pad_idx: int = 0):
method __call__ (line 329) | def __call__(self, texts: tp.List[tp.Optional[str]]) -> tp.Tuple[torch...
class BaseConditioner (line 345) | class BaseConditioner(nn.Module):
method __init__ (line 355) | def __init__(self, dim: int, output_dim: int):
method tokenize (line 362) | def tokenize(self, *args, **kwargs) -> tp.Any:
method forward (line 370) | def forward(self, inputs: tp.Any) -> ConditionType:
class TextConditioner (line 383) | class TextConditioner(BaseConditioner):
class LUTConditioner (line 387) | class LUTConditioner(TextConditioner):
method __init__ (line 397) | def __init__(self, n_bins: int, dim: int, output_dim: int, tokenizer: ...
method tokenize (line 408) | def tokenize(self, x: tp.List[tp.Optional[str]]) -> tp.Tuple[torch.Ten...
method forward (line 414) | def forward(self, inputs: tp.Tuple[torch.Tensor, torch.Tensor]) -> Con...
class T5Conditioner (line 422) | class T5Conditioner(TextConditioner):
method __init__ (line 450) | def __init__(self, name: str, output_dim: int, finetune: bool, device:...
method tokenize (line 490) | def tokenize(self, x: tp.List[tp.Optional[str]]) -> tp.Dict[str, torch...
method forward (line 509) | def forward(self, inputs: tp.Dict[str, torch.Tensor]) -> ConditionType:
class WaveformConditioner (line 518) | class WaveformConditioner(BaseConditioner):
method __init__ (line 529) | def __init__(self, dim: int, output_dim: int, device: tp.Union[torch.d...
method tokenize (line 535) | def tokenize(self, x: WavCondition) -> WavCondition:
method _get_wav_embedding (line 540) | def _get_wav_embedding(self, x: WavCondition) -> torch.Tensor:
method _downsampling_factor (line 544) | def _downsampling_factor(self):
method forward (line 548) | def forward(self, x: WavCondition) -> ConditionType:
class ChromaStemConditioner (line 571) | class ChromaStemConditioner(WaveformConditioner):
method __init__ (line 593) | def __init__(self, output_dim: int, sample_rate: int, n_chroma: int, r...
method _downsampling_factor (line 618) | def _downsampling_factor(self) -> int:
method _load_eval_wavs (line 621) | def _load_eval_wavs(self, path: tp.Optional[str], num_samples: int) ->...
method reset_eval_wavs (line 642) | def reset_eval_wavs(self, eval_wavs: tp.Optional[torch.Tensor]) -> None:
method has_eval_wavs (line 645) | def has_eval_wavs(self) -> bool:
method _sample_eval_wavs (line 648) | def _sample_eval_wavs(self, num_samples: int) -> torch.Tensor:
method _get_chroma_len (line 657) | def _get_chroma_len(self) -> int:
method _get_stemmed_wav (line 664) | def _get_stemmed_wav(self, wav: torch.Tensor, sample_rate: int) -> tor...
method _extract_chroma (line 678) | def _extract_chroma(self, wav: torch.Tensor) -> torch.Tensor:
method _compute_wav_embedding (line 684) | def _compute_wav_embedding(self, wav: torch.Tensor, sample_rate: int) ...
method _get_full_chroma_for_cache (line 694) | def _get_full_chroma_for_cache(self, path: tp.Union[str, Path], x: Wav...
method _extract_chroma_chunk (line 702) | def _extract_chroma_chunk(self, full_chroma: torch.Tensor, x: WavCondi...
method _get_wav_embedding (line 718) | def _get_wav_embedding(self, x: WavCondition) -> torch.Tensor:
method tokenize (line 752) | def tokenize(self, x: WavCondition) -> WavCondition:
class FeatureExtractor (line 762) | class FeatureExtractor(WaveformConditioner):
method __init__ (line 790) | def __init__(
method _get_wav_embedding (line 827) | def _get_wav_embedding(self, x: WavCondition) -> torch.Tensor:
method _downsampling_factor (line 854) | def _downsampling_factor(self):
method _get_mask_wav (line 860) | def _get_mask_wav(self, x: WavCondition, start: int) -> tp.Union[torch...
class StyleConditioner (line 872) | class StyleConditioner(FeatureExtractor):
method __init__ (line 897) | def __init__(self, transformer_scale: str = 'default', ds_factor: int ...
method _get_wav_embedding (line 937) | def _get_wav_embedding(self, wav: WavCondition) -> torch.Tensor:
method set_params (line 970) | def set_params(self, eval_q: int = 3,
method _downsampling_factor (line 987) | def _downsampling_factor(self):
method forward (line 991) | def forward(self, x: WavCondition) -> ConditionType:
class JointEmbeddingConditioner (line 1006) | class JointEmbeddingConditioner(BaseConditioner):
method __init__ (line 1020) | def __init__(self, dim: int, output_dim: int, device: str, attribute: ...
method _get_embed (line 1039) | def _get_embed(self, x: JointEmbedCondition) -> tp.Tuple[torch.Tensor,...
method forward (line 1048) | def forward(self, x: JointEmbedCondition) -> ConditionType:
method tokenize (line 1063) | def tokenize(self, x: JointEmbedCondition) -> JointEmbedCondition:
class CLAPEmbeddingConditioner (line 1067) | class CLAPEmbeddingConditioner(JointEmbeddingConditioner):
method __init__ (line 1094) | def __init__(self, dim: int, output_dim: int, device: str, attribute: ...
method _tokenizer (line 1135) | def _tokenizer(self, texts: tp.Union[str, tp.List[str]]) -> dict:
method _compute_text_embedding (line 1139) | def _compute_text_embedding(self, text: tp.List[str]) -> torch.Tensor:
method _get_text_embedding_for_cache (line 1151) | def _get_text_embedding_for_cache(self, path: tp.Union[Path, str],
method _preprocess_wav (line 1158) | def _preprocess_wav(self, wav: torch.Tensor, length: torch.Tensor, sam...
method _compute_wav_embedding (line 1179) | def _compute_wav_embedding(self, wav: torch.Tensor, length: torch.Tensor,
method _get_wav_embedding_for_cache (line 1214) | def _get_wav_embedding_for_cache(self, path: tp.Union[str, Path],
method _extract_wav_embedding_chunk (line 1230) | def _extract_wav_embedding_chunk(self, full_embed: torch.Tensor, x: Jo...
method _get_text_embedding (line 1251) | def _get_text_embedding(self, x: JointEmbedCondition) -> torch.Tensor:
method _get_wav_embedding (line 1265) | def _get_wav_embedding(self, x: JointEmbedCondition) -> torch.Tensor:
method tokenize (line 1278) | def tokenize(self, x: JointEmbedCondition) -> JointEmbedCondition:
method _get_embed (line 1291) | def _get_embed(self, x: JointEmbedCondition) -> tp.Tuple[torch.Tensor,...
function dropout_symbolic_conditions (line 1304) | def dropout_symbolic_conditions(sample: ConditioningAttributes,
function dropout_condition (line 1337) | def dropout_condition(sample: ConditioningAttributes,
class DropoutModule (line 1372) | class DropoutModule(nn.Module):
method __init__ (line 1374) | def __init__(self, seed: int = 1234):
class AttributeDropout (line 1380) | class AttributeDropout(DropoutModule):
method __init__ (line 1397) | def __init__(self, p: tp.Dict[str, tp.Dict[str, float]], active_on_eva...
method forward (line 1405) | def forward(self, samples: tp.List[ConditioningAttributes]) -> tp.List...
method __repr__ (line 1423) | def __repr__(self):
class ClassifierFreeGuidanceDropout (line 1427) | class ClassifierFreeGuidanceDropout(DropoutModule):
method __init__ (line 1435) | def __init__(self, p: float, seed: int = 1234):
method forward (line 1439) | def forward(self, samples: tp.List[ConditioningAttributes],
method __repr__ (line 1465) | def __repr__(self):
class ConditioningProvider (line 1469) | class ConditioningProvider(nn.Module):
method __init__ (line 1476) | def __init__(self, conditioners: tp.Dict[str, BaseConditioner], device...
method joint_embed_conditions (line 1482) | def joint_embed_conditions(self):
method has_joint_embed_conditions (line 1486) | def has_joint_embed_conditions(self):
method text_conditions (line 1490) | def text_conditions(self):
method wav_conditions (line 1494) | def wav_conditions(self):
method has_wav_condition (line 1498) | def has_wav_condition(self):
method tokenize (line 1501) | def tokenize(self, inputs: tp.List[ConditioningAttributes]) -> tp.Dict...
method forward (line 1529) | def forward(self, tokenized: tp.Dict[str, tp.Any]) -> tp.Dict[str, Con...
method _collate_text (line 1547) | def _collate_text(self, samples: tp.List[ConditioningAttributes]) -> t...
method _collate_wavs (line 1574) | def _collate_wavs(self, samples: tp.List[ConditioningAttributes]) -> t...
method _collate_joint_embeds (line 1618) | def _collate_joint_embeds(self, samples: tp.List[ConditioningAttribute...
class ConditionFuser (line 1672) | class ConditionFuser(StreamingModule):
method __init__ (line 1689) | def __init__(self, fuse2cond: tp.Dict[str, tp.List[str]], cross_attent...
method forward (line 1703) | def forward(
FILE: audiocraft/modules/conv.py
function apply_parametrization_norm (line 21) | def apply_parametrization_norm(module: nn.Module, norm: str = 'none'):
function get_norm_module (line 33) | def get_norm_module(module: nn.Module, causal: bool = False, norm: str =...
function get_extra_padding_for_conv1d (line 47) | def get_extra_padding_for_conv1d(x: torch.Tensor, kernel_size: int, stri...
function pad_for_conv1d (line 56) | def pad_for_conv1d(x: torch.Tensor, kernel_size: int, stride: int, paddi...
function pad1d (line 71) | def pad1d(x: torch.Tensor, paddings: tp.Tuple[int, int], mode: str = 'co...
function unpad1d (line 91) | def unpad1d(x: torch.Tensor, paddings: tp.Tuple[int, int]):
class NormConv1d (line 100) | class NormConv1d(nn.Module):
method __init__ (line 104) | def __init__(self, *args, causal: bool = False, norm: str = 'none',
method forward (line 111) | def forward(self, x):
class NormConv2d (line 117) | class NormConv2d(nn.Module):
method __init__ (line 121) | def __init__(self, *args, norm: str = 'none', norm_kwargs: tp.Dict[str...
method forward (line 127) | def forward(self, x):
class NormConvTranspose1d (line 133) | class NormConvTranspose1d(nn.Module):
method __init__ (line 137) | def __init__(self, *args, causal: bool = False, norm: str = 'none',
method forward (line 144) | def forward(self, x):
class NormConvTranspose2d (line 150) | class NormConvTranspose2d(nn.Module):
method __init__ (line 154) | def __init__(self, *args, norm: str = 'none', norm_kwargs: tp.Dict[str...
method forward (line 159) | def forward(self, x):
class StreamableConv1d (line 165) | class StreamableConv1d(nn.Module):
method __init__ (line 169) | def __init__(self, in_channels: int, out_channels: int,
method forward (line 185) | def forward(self, x):
class StreamableConvTranspose1d (line 204) | class StreamableConvTranspose1d(nn.Module):
method __init__ (line 208) | def __init__(self, in_channels: int, out_channels: int,
method forward (line 221) | def forward(self, x):
FILE: audiocraft/modules/diffusion_schedule.py
function betas_from_alpha_bar (line 20) | def betas_from_alpha_bar(alpha_bar):
class SampleProcessor (line 25) | class SampleProcessor(torch.nn.Module):
method project_sample (line 26) | def project_sample(self, x: torch.Tensor):
method return_sample (line 30) | def return_sample(self, z: torch.Tensor):
class MultiBandProcessor (line 35) | class MultiBandProcessor(SampleProcessor):
method __init__ (line 57) | def __init__(self, n_bands: int = 8, sample_rate: float = 24_000,
method mean (line 77) | def mean(self):
method std (line 82) | def std(self):
method target_std (line 87) | def target_std(self):
method project_sample (line 91) | def project_sample(self, x: torch.Tensor):
method return_sample (line 104) | def return_sample(self, x: torch.Tensor):
class NoiseSchedule (line 112) | class NoiseSchedule:
method __init__ (line 127) | def __init__(self, beta_t0: float = 1e-4, beta_t1: float = 0.02, num_s...
method get_beta (line 149) | def get_beta(self, step: tp.Union[int, torch.Tensor]):
method get_initial_noise (line 155) | def get_initial_noise(self, x: torch.Tensor):
method get_alpha_bar (line 160) | def get_alpha_bar(self, step: tp.Optional[tp.Union[int, torch.Tensor]]...
method get_training_item (line 169) | def get_training_item(self, x: torch.Tensor, tensor_step: bool = False...
method generate (line 192) | def generate(self, model: torch.nn.Module, initial: tp.Optional[torch....
method generate_subsampled (line 238) | def generate_subsampled(self, model: torch.nn.Module, initial: torch.T...
FILE: audiocraft/modules/jasco_conditioners.py
class MelodyConditioner (line 15) | class MelodyConditioner(BaseConditioner):
method __init__ (line 23) | def __init__(self, card: int, out_dim: int, device: tp.Union[torch.dev...
method tokenize (line 27) | def tokenize(self, x: SymbolicCondition) -> SymbolicCondition:
method forward (line 30) | def forward(self, x: SymbolicCondition) -> ConditionType:
class ChordsEmbConditioner (line 36) | class ChordsEmbConditioner(BaseConditioner):
method __init__ (line 44) | def __init__(self, card: int, out_dim: int, device: tp.Union[torch.dev...
method tokenize (line 50) | def tokenize(self, x: SymbolicCondition) -> SymbolicCondition:
method forward (line 53) | def forward(self, x: SymbolicCondition) -> ConditionType:
class DrumsConditioner (line 59) | class DrumsConditioner(WaveformConditioner):
method __init__ (line 60) | def __init__(self, out_dim: int, sample_rate: int, blurring_factor: in...
method create_embedding_cache (line 93) | def create_embedding_cache(self, cache_path):
method _get_drums_stem (line 100) | def _get_drums_stem(self, wav: torch.Tensor, sample_rate: int) -> torc...
method _temporal_blur (line 111) | def _temporal_blur(self, z: torch.Tensor):
method _extract_coarse_drum_codes (line 125) | def _extract_coarse_drum_codes(self, wav: torch.Tensor, sample_rate: i...
method _calc_coarse_drum_codes_for_cache (line 140) | def _calc_coarse_drum_codes_for_cache(self, path: tp.Union[str, Path],
method _load_drum_codes_chunk (line 161) | def _load_drum_codes_chunk(self, full_coarse_drum_codes: torch.Tensor,...
method _get_wav_embedding (line 179) | def _get_wav_embedding(self, x: WavCondition) -> torch.Tensor:
method tokenize (line 206) | def tokenize(self, x: WavCondition) -> WavCondition:
class JascoConditioningProvider (line 216) | class JascoConditioningProvider(ConditioningProvider):
method __init__ (line 224) | def __init__(self, *args,
method tokenize (line 233) | def tokenize(self, inputs: tp.List[ConditioningAttributes]) -> tp.Dict...
method _collate_symbolic (line 262) | def _collate_symbolic(self, samples: tp.List[ConditioningAttributes],
FILE: audiocraft/modules/lstm.py
class StreamableLSTM (line 10) | class StreamableLSTM(nn.Module):
method __init__ (line 14) | def __init__(self, dimension: int, num_layers: int = 2, skip: bool = T...
method forward (line 19) | def forward(self, x):
FILE: audiocraft/modules/rope.py
class XPos (line 13) | class XPos(nn.Module):
method __init__ (line 24) | def __init__(self, dim: int, smoothing: float = 0.4, base_scale: int =...
method get_decay (line 38) | def get_decay(self, start: int, end: int):
class RotaryEmbedding (line 49) | class RotaryEmbedding(nn.Module):
method __init__ (line 60) | def __init__(self, dim: int, max_period: float = 10000.0, xpos: bool =...
method get_rotation (line 75) | def get_rotation(self, start: int, end: int):
method rotate (line 84) | def rotate(self, x: torch.Tensor, start: int = 0, time_dim: int = 1, i...
method rotate_qk (line 106) | def rotate_qk(self, query: torch.Tensor, key: torch.Tensor, start: int...
FILE: audiocraft/modules/seanet.py
class SEANetResnetBlock (line 16) | class SEANetResnetBlock(nn.Module):
method __init__ (line 33) | def __init__(self, dim: int, kernel_sizes: tp.List[int] = [3, 1], dila...
method forward (line 59) | def forward(self, x):
class SEANetEncoder (line 63) | class SEANetEncoder(nn.Module):
method __init__ (line 91) | def __init__(self, channels: int = 1, dimension: int = 128, n_filters:...
method forward (line 152) | def forward(self, x):
class SEANetDecoder (line 156) | class SEANetDecoder(nn.Module):
method __init__ (line 186) | def __init__(self, channels: int = 1, dimension: int = 128, n_filters:...
method forward (line 256) | def forward(self, z):
FILE: audiocraft/modules/streaming.py
class StreamingModule (line 20) | class StreamingModule(nn.Module):
method __init__ (line 43) | def __init__(self) -> None:
method _apply_named_streaming (line 48) | def _apply_named_streaming(self, fn: tp.Any):
method _set_streaming (line 53) | def _set_streaming(self, streaming: bool):
method streaming (line 59) | def streaming(self):
method reset_streaming (line 68) | def reset_streaming(self):
method get_streaming_state (line 75) | def get_streaming_state(self) -> State:
method set_streaming_state (line 88) | def set_streaming_state(self, state: State):
method flush (line 107) | def flush(self, x: tp.Optional[torch.Tensor] = None):
class StreamingSequential (line 122) | class StreamingSequential(StreamingModule, nn.Sequential):
method flush (line 125) | def flush(self, x: tp.Optional[torch.Tensor] = None):
FILE: audiocraft/modules/transformer.py
function set_efficient_attention_backend (line 31) | def set_efficient_attention_backend(backend: str = 'torch'):
function _get_attention_time_dimension (line 38) | def _get_attention_time_dimension(memory_efficient: bool) -> int:
function _is_profiled (line 45) | def _is_profiled() -> bool:
function create_norm_fn (line 54) | def create_norm_fn(norm_type: str, dim: int, **kwargs) -> nn.Module:
function create_sin_embedding (line 70) | def create_sin_embedding(positions: torch.Tensor, dim: int, max_period: ...
function expand_repeated_kv (line 92) | def expand_repeated_kv(x: torch.Tensor, n_rep: int, memory_efficient: bo...
class LayerScale (line 112) | class LayerScale(nn.Module):
method __init__ (line 123) | def __init__(self, channels: int, init: float = 1e-4, channel_last: bo...
method forward (line 131) | def forward(self, x: torch.Tensor):
class StreamingMultiheadAttention (line 138) | class StreamingMultiheadAttention(StreamingModule):
method __init__ (line 164) | def __init__(self, embed_dim: int, num_heads: int, dropout: float = 0....
method _load_from_state_dict (line 224) | def _load_from_state_dict(self, state_dict, prefix, *args, **kwargs):
method _get_mask (line 233) | def _get_mask(self, current_steps: int, device: torch.device, dtype: t...
method _complete_kv (line 266) | def _complete_kv(self, k, v):
method _apply_rope (line 300) | def _apply_rope(self, query: torch.Tensor, key: torch.Tensor):
method forward (line 315) | def forward(self, query: torch.Tensor, key: torch.Tensor, value: torch...
class StreamingTransformerLayer (line 454) | class StreamingTransformerLayer(nn.TransformerEncoderLayer):
method __init__ (line 488) | def __init__(self, d_model: int, num_heads: int, dim_feedforward: int ...
method _cross_attention_block (line 542) | def _cross_attention_block(self, src: torch.Tensor,
method forward (line 550) | def forward(self, src: torch.Tensor, src_mask: tp.Optional[torch.Tenso...
class StreamingTransformer (line 577) | class StreamingTransformer(StreamingModule):
method __init__ (line 614) | def __init__(self, d_model: int, num_heads: int, num_layers: int, dim_...
method _apply_layer (line 662) | def _apply_layer(self, layer, *args, **kwargs):
method forward (line 693) | def forward(self, x: torch.Tensor, *args, **kwargs):
method make_optim_group (line 715) | def make_optim_group(self):
function _verify_xformers_memory_efficient_compat (line 726) | def _verify_xformers_memory_efficient_compat():
function _verify_xformers_internal_compat (line 740) | def _verify_xformers_internal_compat():
function _is_custom (line 754) | def _is_custom(custom: bool, memory_efficient: bool):
FILE: audiocraft/modules/unet_transformer.py
class UnetTransformer (line 6) | class UnetTransformer(StreamingTransformer):
method __init__ (line 20) | def __init__(self, d_model: int, num_layers: int, skip_connections: bo...
method forward (line 32) | def forward(self, x: torch.Tensor, *args, **kwargs):
FILE: audiocraft/modules/watermark.py
function pad (line 13) | def pad(
function mix (line 42) | def mix(
FILE: audiocraft/optim/cosine_lr_scheduler.py
class CosineLRScheduler (line 13) | class CosineLRScheduler(_LRScheduler):
method __init__ (line 23) | def __init__(self, optimizer: Optimizer, total_steps: int, warmup_step...
method _get_sched_lr (line 33) | def _get_sched_lr(self, lr: float, step: int):
method get_lr (line 47) | def get_lr(self):
FILE: audiocraft/optim/dadam.py
function to_real (line 19) | def to_real(x):
class DAdaptAdam (line 26) | class DAdaptAdam(torch.optim.Optimizer):
method __init__ (line 58) | def __init__(self, params, lr=1.0,
method supports_memory_efficient_fp16 (line 95) | def supports_memory_efficient_fp16(self):
method supports_flat_params (line 99) | def supports_flat_params(self):
method step (line 102) | def step(self, closure=None):
FILE: audiocraft/optim/ema.py
function _get_all_non_persistent_buffers_set (line 17) | def _get_all_non_persistent_buffers_set(module: nn.Module, root: str = "...
function _get_named_tensors (line 32) | def _get_named_tensors(module: nn.Module):
class ModuleDictEMA (line 40) | class ModuleDictEMA:
method __init__ (line 45) | def __init__(self, module_dict: nn.ModuleDict, decay: float = 0.999,
method _init (line 55) | def _init(self):
method step (line 64) | def step(self):
method state_dict (line 78) | def state_dict(self):
method load_state_dict (line 81) | def load_state_dict(self, state):
FILE: audiocraft/optim/fsdp.py
function is_fsdp_used (line 22) | def is_fsdp_used() -> bool:
function is_sharded_tensor (line 32) | def is_sharded_tensor(x: tp.Any) -> bool:
function switch_to_full_state_dict (line 37) | def switch_to_full_state_dict(models: tp.List[FSDP]):
function wrap_with_fsdp (line 51) | def wrap_with_fsdp(cfg, model: torch.nn.Module,
function purge_fsdp (line 120) | def purge_fsdp(model: FSDP):
class _FSDPFixStateDict (line 149) | class _FSDPFixStateDict(FSDP):
method _name_without_fsdp_prefix (line 151) | def _name_without_fsdp_prefix(name: str) -> str:
method state_dict (line 157) | def state_dict(self, *args, **kwargs) -> tp.Dict[str, tp.Any]: # type...
method load_state_dict (line 164) | def load_state_dict(self, state: tp.Dict[str, tp.Any]): # type: ignore
function _fix_post_backward_hook (line 186) | def _fix_post_backward_hook():
FILE: audiocraft/optim/inverse_sqrt_lr_scheduler.py
class InverseSquareRootLRScheduler (line 13) | class InverseSquareRootLRScheduler(_LRScheduler):
method __init__ (line 22) | def __init__(self, optimizer: Optimizer, warmup_steps: int, warmup_ini...
method _get_sched_lr (line 27) | def _get_sched_lr(self, lr: float, step: int):
method get_lr (line 37) | def get_lr(self):
FILE: audiocraft/optim/linear_warmup_lr_scheduler.py
class LinearWarmupLRScheduler (line 13) | class LinearWarmupLRScheduler(_LRScheduler):
method __init__ (line 22) | def __init__(self, optimizer: Optimizer, warmup_steps: int, warmup_ini...
method _get_sched_lr (line 27) | def _get_sched_lr(self, lr: float, step: int):
method get_lr (line 34) | def get_lr(self):
FILE: audiocraft/optim/polynomial_decay_lr_scheduler.py
class PolynomialDecayLRScheduler (line 11) | class PolynomialDecayLRScheduler(_LRScheduler):
method __init__ (line 22) | def __init__(self, optimizer: Optimizer, warmup_steps: int, total_step...
method _get_sched_lr (line 31) | def _get_sched_lr(self, lr: float, step: int):
method get_lr (line 46) | def get_lr(self):
FILE: audiocraft/quantization/base.py
class QuantizedResult (line 19) | class QuantizedResult:
class BaseQuantizer (line 27) | class BaseQuantizer(nn.Module):
method forward (line 31) | def forward(self, x: torch.Tensor, frame_rate: int) -> QuantizedResult:
method encode (line 40) | def encode(self, x: torch.Tensor) -> torch.Tensor:
method decode (line 44) | def decode(self, codes: torch.Tensor) -> torch.Tensor:
method total_codebooks (line 49) | def total_codebooks(self):
method num_codebooks (line 54) | def num_codebooks(self):
method set_num_codebooks (line 58) | def set_num_codebooks(self, n: int):
class DummyQuantizer (line 63) | class DummyQuantizer(BaseQuantizer):
method __init__ (line 66) | def __init__(self):
method forward (line 69) | def forward(self, x: torch.Tensor, frame_rate: int):
method encode (line 73) | def encode(self, x: torch.Tensor) -> torch.Tensor:
method decode (line 80) | def decode(self, codes: torch.Tensor) -> torch.Tensor:
method total_codebooks (line 88) | def total_codebooks(self):
method num_codebooks (line 93) | def num_codebooks(self):
method set_num_codebooks (line 97) | def set_num_codebooks(self, n: int):
FILE: audiocraft/quantization/core_vq.py
function exists (line 16) | def exists(val: tp.Optional[tp.Any]) -> bool:
function default (line 20) | def default(val: tp.Any, d: tp.Any) -> tp.Any:
function l2norm (line 24) | def l2norm(t):
function ema_inplace (line 28) | def ema_inplace(moving_avg, new, decay: float):
function laplace_smoothing (line 32) | def laplace_smoothing(x, n_categories: int, epsilon: float = 1e-5):
function uniform_init (line 36) | def uniform_init(*shape: int):
function sample_vectors (line 42) | def sample_vectors(samples, num: int):
function kmeans (line 53) | def kmeans(samples, num_clusters: int, num_iters: int = 10):
function orthogonal_loss_fn (line 78) | def orthogonal_loss_fn(t):
class EuclideanCodebook (line 87) | class EuclideanCodebook(nn.Module):
method __init__ (line 103) | def __init__(
method init_embed_ (line 130) | def init_embed_(self, data):
method replace_ (line 142) | def replace_(self, samples, mask):
method expire_codes_ (line 148) | def expire_codes_(self, batch_samples):
method preprocess (line 160) | def preprocess(self, x):
method quantize (line 164) | def quantize(self, x):
method postprocess_emb (line 174) | def postprocess_emb(self, embed_ind, shape):
method dequantize (line 177) | def dequantize(self, embed_ind):
method encode (line 181) | def encode(self, x):
method decode (line 191) | def decode(self, embed_ind):
method forward (line 195) | def forward(self, x):
class VectorQuantization (line 222) | class VectorQuantization(nn.Module):
method __init__ (line 244) | def __init__(
method codebook (line 283) | def codebook(self):
method inited (line 287) | def inited(self):
method _preprocess (line 290) | def _preprocess(self, x):
method _postprocess (line 295) | def _postprocess(self, quantize):
method encode (line 300) | def encode(self, x):
method decode (line 306) | def decode(self, embed_ind):
method forward (line 312) | def forward(self, x):
class ResidualVectorQuantization (line 351) | class ResidualVectorQuantization(nn.Module):
method __init__ (line 356) | def __init__(self, *, num_quantizers, **kwargs):
method forward (line 362) | def forward(self, x, n_q: tp.Optional[int] = None):
method encode (line 386) | def encode(self, x: torch.Tensor, n_q: tp.Optional[int] = None) -> tor...
method decode (line 398) | def decode(self, q_indices: torch.Tensor) -> torch.Tensor:
FILE: audiocraft/quantization/vq.py
class ResidualVectorQuantizer (line 16) | class ResidualVectorQuantizer(BaseQuantizer):
method __init__ (line 35) | def __init__(
method forward (line 76) | def forward(self, x: torch.Tensor, frame_rate: int):
method encode (line 87) | def encode(self, x: torch.Tensor) -> torch.Tensor:
method decode (line 98) | def decode(self, codes: torch.Tensor) -> torch.Tensor:
method total_codebooks (line 106) | def total_codebooks(self):
method num_codebooks (line 110) | def num_codebooks(self):
method set_num_codebooks (line 113) | def set_num_codebooks(self, n: int):
FILE: audiocraft/solvers/audiogen.py
class AudioGenSolver (line 10) | class AudioGenSolver(musicgen.MusicGenSolver):
FILE: audiocraft/solvers/base.py
class StandardSolver (line 27) | class StandardSolver(ABC, flashy.BaseSolver):
method __init__ (line 38) | def __init__(self, cfg: omegaconf.DictConfig):
method autocast (line 98) | def autocast(self):
method _get_state_source (line 102) | def _get_state_source(self, name) -> flashy.state.StateDictSource:
method best_metric_name (line 107) | def best_metric_name(self) -> tp.Optional[str]:
method register_best_state (line 114) | def register_best_state(self, *args: str):
method register_ema (line 127) | def register_ema(self, *args: str):
method wrap_with_fsdp (line 141) | def wrap_with_fsdp(self, model: torch.nn.Module, *args, **kwargs):
method update_best_state_from_stage (line 147) | def update_best_state_from_stage(self, stage_name: str = 'valid'):
method _load_new_state_dict (line 189) | def _load_new_state_dict(self, state_dict: dict) -> dict:
method swap_best_state (line 198) | def swap_best_state(self):
method swap_ema_state (line 210) | def swap_ema_state(self):
method is_training (line 226) | def is_training(self):
method log_model_summary (line 229) | def log_model_summary(self, model: nn.Module):
method build_model (line 236) | def build_model(self):
method initialize_ema (line 240) | def initialize_ema(self):
method build_dataloaders (line 256) | def build_dataloaders(self):
method show (line 261) | def show(self):
method log_updates (line 266) | def log_updates(self):
method checkpoint_path (line 270) | def checkpoint_path(self, **kwargs):
method epoch_checkpoint_path (line 274) | def epoch_checkpoint_path(self, epoch: int, **kwargs):
method checkpoint_path_with_name (line 278) | def checkpoint_path_with_name(self, name: str, **kwargs):
method save_checkpoints (line 282) | def save_checkpoints(self):
method load_from_pretrained (line 311) | def load_from_pretrained(self, name: str) -> dict:
method load_checkpoints (line 314) | def load_checkpoints(self, load_best: bool = False, ignore_state_keys:...
method restore (line 432) | def restore(self, load_best: bool = False, replay_metrics: bool = False,
method commit (line 456) | def commit(self, save_checkpoints: bool = True):
method run_epoch (line 466) | def run_epoch(self):
method run (line 489) | def run(self):
method should_stop_training (line 501) | def should_stop_training(self) -> bool:
method should_run_stage (line 505) | def should_run_stage(self, stage_name) -> bool:
method run_step (line 513) | def run_step(self, idx: int, batch: tp.Any, metrics: dict):
method common_train_valid (line 517) | def common_train_valid(self, dataset_split: str, **kwargs: tp.Any):
method train (line 559) | def train(self):
method valid (line 563) | def valid(self):
method evaluate (line 568) | def evaluate(self):
method generate (line 573) | def generate(self):
method run_one_stage (line 577) | def run_one_stage(self, stage_name: str):
method get_eval_solver_from_sig (line 597) | def get_eval_solver_from_sig(sig: str, dtype: tp.Optional[str] = None,
FILE: audiocraft/solvers/builders.py
class DatasetType (line 37) | class DatasetType(Enum):
function get_solver (line 44) | def get_solver(cfg: omegaconf.DictConfig) -> StandardSolver:
function get_optim_parameter_groups (line 68) | def get_optim_parameter_groups(model: nn.Module):
function get_optimizer (line 95) | def get_optimizer(params: tp.Union[nn.Module, tp.Iterable[torch.Tensor]]...
function get_lr_scheduler (line 124) | def get_lr_scheduler(optimizer: torch.optim.Optimizer,
function get_ema (line 168) | def get_ema(module_dict: nn.ModuleDict, cfg: omegaconf.DictConfig) -> tp...
function get_loss (line 189) | def get_loss(loss_name: str, cfg: omegaconf.DictConfig):
function get_balancer (line 206) | def get_balancer(loss_weights: tp.Dict[str, float], cfg: omegaconf.DictC...
function get_adversary (line 212) | def get_adversary(name: str, cfg: omegaconf.DictConfig) -> nn.Module:
function get_adversarial_losses (line 223) | def get_adversarial_losses(cfg) -> nn.ModuleDict:
function get_visqol (line 256) | def get_visqol(cfg: omegaconf.DictConfig) -> metrics.ViSQOL:
function get_fad (line 262) | def get_fad(cfg: omegaconf.DictConfig) -> metrics.FrechetAudioDistanceMe...
function get_kldiv (line 270) | def get_kldiv(cfg: omegaconf.DictConfig) -> metrics.KLDivergenceMetric:
function get_text_consistency (line 280) | def get_text_consistency(cfg: omegaconf.DictConfig) -> metrics.TextConsi...
function get_chroma_cosine_similarity (line 290) | def get_chroma_cosine_similarity(cfg: omegaconf.DictConfig) -> metrics.C...
function get_audio_datasets (line 297) | def get_audio_datasets(cfg: omegaconf.DictConfig,
FILE: audiocraft/solvers/compression.py
class CompressionSolver (line 27) | class CompressionSolver(base.StandardSolver):
method __init__ (line 34) | def __init__(self, cfg: omegaconf.DictConfig):
method best_metric_name (line 55) | def best_metric_name(self) -> tp.Optional[str]:
method build_model (line 59) | def build_model(self):
method build_dataloaders (line 68) | def build_dataloaders(self):
method show (line 72) | def show(self):
method run_step (line 83) | def run_step(self, idx: int, batch: torch.Tensor, metrics: dict):
method run_epoch (line 176) | def run_epoch(self):
method evaluate (line 183) | def evaluate(self):
method generate (line 213) | def generate(self):
method load_from_pretrained (line 236) | def load_from_pretrained(self, name: str) -> dict:
method model_from_checkpoint (line 269) | def model_from_checkpoint(checkpoint_path: tp.Union[Path, str],
method wrapped_model_from_checkpoint (line 304) | def wrapped_model_from_checkpoint(cfg: omegaconf.DictConfig,
function evaluate_audio_reconstruction (line 320) | def evaluate_audio_reconstruction(y_pred: torch.Tensor, y: torch.Tensor,...
FILE: audiocraft/solvers/diffusion.py
class PerStageMetrics (line 25) | class PerStageMetrics:
method __init__ (line 30) | def __init__(self, num_steps: int, num_stages: int = 4):
method __call__ (line 34) | def __call__(self, losses: dict, step: tp.Union[int, torch.Tensor]):
class DataProcess (line 53) | class DataProcess:
method __init__ (line 67) | def __init__(self, initial_sr: int = 24000, target_sr: int = 16000, us...
method process_data (line 95) | def process_data(self, x, metric=False):
method inverse_process (line 107) | def inverse_process(self, x):
class DiffusionSolver (line 114) | class DiffusionSolver(base.StandardSolver):
method __init__ (line 122) | def __init__(self, cfg: omegaconf.DictConfig):
method best_metric_name (line 155) | def best_metric_name(self) -> tp.Optional[str]:
method get_condition (line 162) | def get_condition(self, wav: torch.Tensor) -> torch.Tensor:
method build_model (line 168) | def build_model(self):
method build_dataloaders (line 178) | def build_dataloaders(self):
method show (line 182) | def show(self):
method run_step (line 186) | def run_step(self, idx: int, batch: torch.Tensor, metrics: dict):
method run_epoch (line 215) | def run_epoch(self):
method evaluate (line 223) | def evaluate(self):
method regenerate (line 253) | def regenerate(self, wav: torch.Tensor, step_list: tp.Optional[list] =...
method generate (line 262) | def generate(self):
FILE: audiocraft/solvers/jasco.py
class JascoSolver (line 19) | class JascoSolver(musicgen.MusicGenSolver):
method __init__ (line 25) | def __init__(self, cfg: DictConfig):
method build_model (line 39) | def build_model(self) -> None:
method _get_latents (line 55) | def _get_latents(self, audio):
method _prepare_latents_and_attributes (line 60) | def _prepare_latents_and_attributes(
method _normalized_latents (line 104) | def _normalized_latents(self, latents: torch.Tensor) -> torch.Tensor:
method _unnormalized_latents (line 108) | def _unnormalized_latents(self, latents: torch.Tensor) -> torch.Tensor:
method _z (line 112) | def _z(self, z_0: torch.Tensor, z_1: torch.Tensor, t: torch.Tensor, si...
method _vector_field (line 116) | def _vector_field(self, z_0: torch.Tensor, z_1: torch.Tensor, sigma_mi...
method _compute_loss (line 121) | def _compute_loss(self, t: torch.Tensor, v_theta: torch.Tensor, v: tor...
method run_step (line 134) | def run_step(self, idx: int, batch: tp.Tuple[torch.Tensor, tp.List[Seg...
method _decode_latents (line 216) | def _decode_latents(self, latents):
method run_generate_step (line 220) | def run_generate_step(self, batch: tp.Tuple[torch.Tensor, tp.List[Segm...
FILE: audiocraft/solvers/magnet.py
class MagnetSolver (line 21) | class MagnetSolver(musicgen.MusicGenSolver):
method __init__ (line 25) | def __init__(self, cfg: DictConfig):
method build_model (line 47) | def build_model(self) -> None:
method _calc_mean_maskrate_to_u_LUT (line 53) | def _calc_mean_maskrate_to_u_LUT(self, T: int):
method _non_spans_mask (line 87) | def _non_spans_mask(self, mask_probs: torch.Tensor, B: int, T: int, de...
method _spans_mask (line 102) | def _spans_mask(self, mask_probs: torch.Tensor, B: int, T: int, device...
method _get_mask (line 127) | def _get_mask(self, mask_probs: torch.Tensor, B: int, T: int, device: ...
method _compute_cross_entropy_magnet (line 143) | def _compute_cross_entropy_magnet(self, logits: torch.Tensor,
method run_step (line 172) | def run_step(self, idx: int, batch: tp.Tuple[torch.Tensor, tp.List[Seg...
class AudioMagnetSolver (line 271) | class AudioMagnetSolver(MagnetSolver):
FILE: audiocraft/solvers/musicgen.py
class MusicGenSolver (line 32) | class MusicGenSolver(base.StandardSolver):
method __init__ (line 39) | def __init__(self, cfg: omegaconf.DictConfig):
method get_eval_solver_from_sig (line 66) | def get_eval_solver_from_sig(sig: str, dtype: tp.Optional[str] = None,
method get_formatter (line 102) | def get_formatter(self, stage_name: str) -> flashy.Formatter:
method best_metric_name (line 111) | def best_metric_name(self) -> tp.Optional[str]:
method initialize_optimization (line 114) | def initialize_optimization(self) -> None:
method build_model (line 140) | def build_model(self) -> None:
method build_dataloaders (line 171) | def build_dataloaders(self) -> None:
method show (line 175) | def show(self) -> None:
method load_state_dict (line 182) | def load_state_dict(self, state: dict) -> None:
method load_from_pretrained (line 209) | def load_from_pretrained(self, name: str):
method _compute_cross_entropy (line 219) | def _compute_cross_entropy(
method _get_audio_tokens (line 253) | def _get_audio_tokens(self, audio: torch.Tensor):
method _prepare_tokens_and_attributes (line 259) | def _prepare_tokens_and_attributes(
method run_step (line 363) | def run_step(self, idx: int, batch: tp.Tuple[torch.Tensor, tp.List[Seg...
method run_generate_step (line 445) | def run_generate_step(self, batch: tp.Tuple[torch.Tensor, tp.List[Segm...
method generate_audio (line 511) | def generate_audio(self) -> dict:
method generate (line 611) | def generate(self) -> dict:
method run_epoch (line 617) | def run_epoch(self):
method train (line 623) | def train(self):
method evaluate_audio_generation (line 636) | def evaluate_audio_generation(self) -> dict:
method evaluate (line 741) | def evaluate(self) -> dict:
FILE: audiocraft/solvers/watermark.py
function get_encodec_audio_effect (line 45) | def get_encodec_audio_effect(encodec_cfg: DictConfig, sr: int) -> tp.Dict:
function random_message (line 69) | def random_message(nbits: int, batch_size: int) -> torch.Tensor:
class WatermarkSolver (line 76) | class WatermarkSolver(base.StandardSolver):
method __init__ (line 79) | def __init__(self, cfg: DictConfig):
method _init_losses (line 93) | def _init_losses(self):
method _init_augmentations (line 133) | def _init_augmentations(self):
method best_metric_name (line 162) | def best_metric_name(self) -> tp.Optional[str]:
method build_model (line 166) | def build_model(self):
method build_dataloaders (line 176) | def build_dataloaders(self):
method show (line 180) | def show(self):
method crop (line 185) | def crop(
method run_step (line 251) | def run_step(self, idx: int, batch: torch.Tensor, metrics: dict):
method run_epoch (line 393) | def run_epoch(self):
method evaluate (line 400) | def evaluate(self) -> dict:
method generate (line 533) | def generate(self):
method load_from_pretrained (line 576) | def load_from_pretrained(self, name: str) -> dict:
method model_from_checkpoint (line 580) | def model_from_checkpoint(
function evaluate_localizations (line 617) | def evaluate_localizations(predictions, true_predictions, name):
function evaluate_augmentations (line 633) | def evaluate_augmentations(
function evaluate_audio_watermark (line 654) | def evaluate_audio_watermark(
function tensor_pesq (line 672) | def tensor_pesq(y_pred: torch.Tensor, y: torch.Tensor, sr: int):
function compute_accuracy (line 677) | def compute_accuracy(positive, negative):
function compute_FPR (line 685) | def compute_FPR(negative):
function compute_FNR (line 691) | def compute_FNR(positive):
function _bit_acc (line 697) | def _bit_acc(decoded, original):
function compute_bit_acc (line 702) | def compute_bit_acc(positive, original, mask=None):
FILE: audiocraft/train.py
function resolve_config_dset_paths (line 30) | def resolve_config_dset_paths(cfg):
function get_solver (line 38) | def get_solver(cfg):
function get_solver_from_xp (line 52) | def get_solver_from_xp(xp: XP, override_cfg: tp.Optional[tp.Union[dict, ...
function get_solver_from_sig (line 97) | def get_solver_from_sig(sig: str, *args, **kwargs):
function init_seed_and_system (line 105) | def init_seed_and_system(cfg):
function main (line 131) | def main(cfg):
FILE: audiocraft/utils/audio_effects.py
function select_audio_effects (line 24) | def select_audio_effects(
function get_audio_effects (line 84) | def get_audio_effects(cfg: omegaconf.DictConfig):
function audio_effect_return (line 99) | def audio_effect_return(
function generate_pink_noise (line 109) | def generate_pink_noise(length: int) -> torch.Tensor:
function compress_with_encodec (line 121) | def compress_with_encodec(
function apply_compression_skip_grad (line 146) | def apply_compression_skip_grad(tensor: torch.Tensor, compression_fn, **...
class AudioEffects (line 177) | class AudioEffects:
method speed (line 179) | def speed(
method updownresample (line 206) | def updownresample(
method echo (line 223) | def echo(
method random_noise (line 278) | def random_noise(
method pink_noise (line 289) | def pink_noise(
method lowpass_filter (line 302) | def lowpass_filter(
method highpass_filter (line 315) | def highpass_filter(
method bandpass_filter (line 328) | def bandpass_filter(
method smooth (line 358) | def smooth(
method boost_audio (line 390) | def boost_audio(
method duck_audio (line 399) | def duck_audio(
method identity (line 408) | def identity(
method mp3_compression (line 414) | def mp3_compression(
method aac_compression (line 436) | def aac_compression(
FILE: audiocraft/utils/autocast.py
class TorchAutocast (line 10) | class TorchAutocast:
method __init__ (line 21) | def __init__(self, enabled: bool, *args, **kwargs):
method __enter__ (line 24) | def __enter__(self):
method __exit__ (line 37) | def __exit__(self, *args, **kwargs):
FILE: audiocraft/utils/best_state.py
class BestStateDictManager (line 21) | class BestStateDictManager(flashy.state.StateDictSource):
method __init__ (line 36) | def __init__(self, device: tp.Union[torch.device, str] = 'cpu',
method _get_parameter_ids (line 43) | def _get_parameter_ids(self, state_dict):
method _validate_no_parameter_ids_overlap (line 46) | def _validate_no_parameter_ids_overlap(self, name: str, param_ids: dict):
method update (line 53) | def update(self, name: str, source: flashy.state.StateDictSource):
method register (line 58) | def register(self, name: str, source: flashy.state.StateDictSource):
method state_dict (line 75) | def state_dict(self) -> flashy.state.StateDict:
method load_state_dict (line 78) | def load_state_dict(self, state: flashy.state.StateDict):
FILE: audiocraft/utils/cache.py
function get_full_embed (line 24) | def get_full_embed(full_embed: torch.Tensor, x: tp.Any, idx: int, device...
class EmbeddingCache (line 39) | class EmbeddingCache:
method __init__ (line 60) | def __init__(self, cache_path: tp.Union[str, Path], device: tp.Union[s...
method _get_cache_path (line 79) | def _get_cache_path(self, path: tp.Union[Path, str]):
method _get_full_embed_from_cache (line 85) | def _get_full_embed_from_cache(cache: Path):
method get_embed_from_cache (line 94) | def get_embed_from_cache(self, paths: tp.List[Path], x: tp.Any) -> tor...
method populate_embed_cache (line 124) | def populate_embed_cache(self, paths: tp.List[Path], x: tp.Any) -> None:
class CachedBatchWriter (line 161) | class CachedBatchWriter:
method __init__ (line 180) | def __init__(self, cache_folder: Path):
method start_epoch (line 185) | def start_epoch(self, epoch: int):
method _get_zip_path (line 193) | def _get_zip_path(cache_folder: Path, epoch: int, index: int):
method _zip_path (line 197) | def _zip_path(self):
method save (line 201) | def save(self, *content):
class CachedBatchLoader (line 224) | class CachedBatchLoader:
method __init__ (line 237) | def __init__(self, cache_folder: Path, batch_size: int,
method __len__ (line 246) | def __len__(self):
method start_epoch (line 250) | def start_epoch(self, epoch: int):
method _zip_path (line 255) | def _zip_path(self, index: int):
method _load_one (line 259) | def _load_one(self, index: int):
method __iter__ (line 297) | def __iter__(self):
FILE: audiocraft/utils/checkpoint.py
class CheckpointSource (line 22) | class CheckpointSource(Enum):
function checkpoint_name (line 28) | def checkpoint_name(name: tp.Optional[str] = None, rank: tp.Optional[int...
function is_sharded_checkpoint (line 51) | def is_sharded_checkpoint(path: Path) -> bool:
function resolve_checkpoint_path (line 56) | def resolve_checkpoint_path(sig_or_path: tp.Union[Path, str], name: tp.O...
function load_checkpoint (line 87) | def load_checkpoint(checkpoint_path: Path, is_sharded: bool = False) -> ...
function save_checkpoint (line 98) | def save_checkpoint(state: tp.Any, checkpoint_path: Path, is_sharded: bo...
function flush_stale_checkpoints (line 104) | def flush_stale_checkpoints(checkpoint_path: Path, keep_last: tp.Optiona...
function check_sharded_checkpoint (line 125) | def check_sharded_checkpoint(checkpoint_path: Path, rank0_checkpoint_pat...
function _safe_save_checkpoint (line 142) | def _safe_save_checkpoint(state: tp.Any, checkpoint_path: Path, is_shard...
FILE: audiocraft/utils/cluster.py
class ClusterType (line 19) | class ClusterType(Enum):
function _guess_cluster_type (line 27) | def _guess_cluster_type() -> ClusterType:
function get_cluster_type (line 45) | def get_cluster_type(
function get_slurm_parameters (line 54) | def get_slurm_parameters(
FILE: audiocraft/utils/deadlock.py
class DeadlockDetect (line 18) | class DeadlockDetect:
method __init__ (line 19) | def __init__(self, use: bool = False, timeout: float = 120.):
method update (line 24) | def update(self, stage: str):
method __enter__ (line 28) | def __enter__(self):
method __exit__ (line 33) | def __exit__(self, exc_type, exc_val, exc_tb):
method _detector_thread (line 38) | def _detector_thread(self):
FILE: audiocraft/utils/export.py
function export_encodec (line 20) | def export_encodec(checkpoint_path: tp.Union[Path, str], out_file: tp.Un...
function export_pretrained_compression_model (line 36) | def export_pretrained_compression_model(pretrained_encodec: str, out_fil...
function export_lm (line 61) | def export_lm(checkpoint_path: tp.Union[Path, str], out_file: tp.Union[P...
FILE: audiocraft/utils/export_legacy.py
function _clean_lm_cfg (line 20) | def _clean_lm_cfg(cfg: DictConfig):
function export_encodec (line 41) | def export_encodec(checkpoint_path: tp.Union[Path, str], out_file: tp.Un...
function export_lm (line 55) | def export_lm(checkpoint_path: tp.Union[Path, str], out_file: tp.Union[P...
FILE: audiocraft/utils/notebook.py
function display_audio (line 17) | def display_audio(samples: torch.Tensor, sample_rate: int):
FILE: audiocraft/utils/profiler.py
class Profiler (line 17) | class Profiler:
method __init__ (line 20) | def __init__(self, module: torch.nn.Module, enabled: bool = False):
method step (line 28) | def step(self):
method __enter__ (line 32) | def __enter__(self):
method __exit__ (line 36) | def __exit__(self, exc_type, exc_value, exc_tb):
FILE: audiocraft/utils/samples/manager.py
class ReferenceSample (line 42) | class ReferenceSample:
class Sample (line 49) | class Sample:
method __hash__ (line 59) | def __hash__(self):
method audio (line 62) | def audio(self) -> tp.Tuple[torch.Tensor, int]:
method audio_prompt (line 65) | def audio_prompt(self) -> tp.Optional[tp.Tuple[torch.Tensor, int]]:
method audio_reference (line 68) | def audio_reference(self) -> tp.Optional[tp.Tuple[torch.Tensor, int]]:
class SampleManager (line 72) | class SampleManager:
method __init__ (line 89) | def __init__(self, xp: dora.XP, map_reference_to_sample_id: bool = Fal...
method latest_epoch (line 98) | def latest_epoch(self):
method _load_samples (line 102) | def _load_samples(self):
method _load_sample (line 110) | def _load_sample(json_file: Path) -> Sample:
method _init_hash (line 126) | def _init_hash(self):
method _get_tensor_id (line 129) | def _get_tensor_id(self, tensor: torch.Tensor) -> str:
method _get_sample_id (line 134) | def _get_sample_id(self, index: int, prompt_wav: tp.Optional[torch.Ten...
method _store_audio (line 173) | def _store_audio(self, wav: torch.Tensor, stem_path: Path, overwrite: ...
method add_sample (line 196) | def add_sample(self, sample_wav: torch.Tensor, epoch: int, index: int ...
method add_samples (line 238) | def add_samples(self, samples_wavs: torch.Tensor, epoch: int,
method get_samples (line 269) | def get_samples(self, epoch: int = -1, max_epoch: int = -1, exclude_pr...
function slugify (line 305) | def slugify(value: tp.Any, allow_unicode: bool = False):
function _match_stable_samples (line 328) | def _match_stable_samples(samples_per_xp: tp.List[tp.Set[Sample]]) -> tp...
function _match_unstable_samples (line 343) | def _match_unstable_samples(samples_per_xp: tp.List[tp.Set[Sample]]) -> ...
function get_samples_for_xps (line 358) | def get_samples_for_xps(xps: tp.List[dora.XP], **kwargs) -> tp.Dict[str,...
FILE: audiocraft/utils/utils.py
function model_hash (line 25) | def model_hash(model: torch.nn.Module) -> str:
function dict_from_config (line 35) | def dict_from_config(cfg: omegaconf.DictConfig) -> dict:
function random_subset (line 48) | def random_subset(dataset, max_samples: int, seed: int = 42) -> torch.ut...
function get_loader (line 57) | def get_loader(dataset, num_samples: tp.Optional[int], batch_size: int,
function get_dataset_from_loader (line 80) | def get_dataset_from_loader(dataloader):
function multinomial (line 88) | def multinomial(input: torch.Tensor, num_samples: int, replacement=False...
function sample_top_k (line 108) | def sample_top_k(probs: torch.Tensor, k: int) -> torch.Tensor:
function sample_top_p (line 125) | def sample_top_p(probs: torch.Tensor, p: float) -> torch.Tensor:
class DummyPoolExecutor (line 144) | class DummyPoolExecutor:
class DummyResult (line 148) | class DummyResult:
method __init__ (line 149) | def __init__(self, func, *args, **kwargs):
method result (line 154) | def result(self):
method __init__ (line 157) | def __init__(self, workers, mp_context=None):
method submit (line 160) | def submit(self, func, *args, **kwargs):
method __enter__ (line 163) | def __enter__(self):
method __exit__ (line 166) | def __exit__(self, exc_type, exc_value, exc_tb):
function get_pool_executor (line 170) | def get_pool_executor(num_workers: int, mp_context=None):
function length_to_mask (line 174) | def length_to_mask(lengths: torch.Tensor, max_len: tp.Optional[int] = No...
function hash_trick (line 190) | def hash_trick(word: str, vocab_size: int) -> int:
function with_rank_rng (line 203) | def with_rank_rng(base_seed: int = 1234):
function collate (line 226) | def collate(tensors: tp.List[torch.Tensor], dim: int = 0) -> tp.Tuple[to...
function copy_state (line 250) | def copy_state(state: tp.Any, device: tp.Union[torch.device, str] = 'cpu',
function swap_state (line 264) | def swap_state(model, state, **kwargs):
function warn_once (line 274) | def warn_once(logger, msg):
function is_jsonable (line 279) | def is_jsonable(x: tp.Any):
function load_clap_state_dict (line 288) | def load_clap_state_dict(clap_model, path: tp.Union[str, Path]):
function construct_frame_chords (line 300) | def construct_frame_chords(
FILE: demos/jasco_app.py
function _call_nostderr (line 32) | def _call_nostderr(*args, **kwargs):
function interrupt (line 45) | def interrupt():
class FileCleaner (line 50) | class FileCleaner:
method __init__ (line 51) | def __init__(self, file_lifetime: float = 3600):
method add (line 55) | def add(self, path: tp.Union[str, Path]):
method _cleanup (line 59) | def _cleanup(self):
function chords_string_to_list (line 73) | def chords_string_to_list(chords: str):
function load_model (line 85) | def load_model(version='facebook/jasco-chords-drums-400M'):
function _do_predictions (line 93) | def _do_predictions(texts, chords, melody_matrix, drum_prompt, progress=...
function predict_full (line 140) | def predict_full(model,
function ui_full (line 180) | def ui_full(launch_kwargs):
FILE: demos/magnet_app.py
function _call_nostderr (line 37) | def _call_nostderr(*args, **kwargs):
function interrupt (line 50) | def interrupt():
class FileCleaner (line 55) | class FileCleaner:
method __init__ (line 56) | def __init__(self, file_lifetime: float = 3600):
method add (line 60) | def add(self, path: tp.Union[str, Path]):
method _cleanup (line 64) | def _cleanup(self):
function make_waveform (line 78) | def make_waveform(*args, **kwargs):
function load_model (line 88) | def load_model(version='facebook/magnet-small-10secs'):
function _do_predictions (line 96) | def _do_predictions(texts, progress=False, gradio_progress=None, **gen_k...
function predict_batched (line 125) | def predict_batched(texts, melodies):
function predict_full (line 133) | def predict_full(model, model_path, text, temperature, topp,
function ui_full (line 175) | def ui_full(launch_kwargs):
FILE: demos/musicgen_app.py
function _call_nostderr (line 44) | def _call_nostderr(*args, **kwargs):
function interrupt (line 57) | def interrupt():
class FileCleaner (line 62) | class FileCleaner:
method __init__ (line 63) | def __init__(self, file_lifetime: float = 3600):
method add (line 67) | def add(self, path: tp.Union[str, Path]):
method _cleanup (line 71) | def _cleanup(self):
function make_waveform (line 84) | def make_waveform(*args, **kwargs):
function load_model (line 94) | def load_model(version='facebook/musicgen-melody'):
function load_diffusion (line 105) | def load_diffusion():
function _do_predictions (line 112) | def _do_predictions(texts, melodies, duration, progress=False, gradio_pr...
function predict_batched (line 174) | def predict_batched(texts, melodies):
function predict_full (line 182) | def predict_full(model, model_path, decoder, text, melody, duration, top...
function toggle_audio_src (line 230) | def toggle_audio_src(choice):
function toggle_diffusion (line 237) | def toggle_diffusion(choice):
function ui_full (line 244) | def ui_full(launch_kwargs):
function ui_batched (line 387) | def ui_batched(launch_kwargs):
FILE: demos/musicgen_style_app.py
function _call_nostderr (line 39) | def _call_nostderr(*args, **kwargs):
function interrupt (line 52) | def interrupt():
class FileCleaner (line 57) | class FileCleaner:
method __init__ (line 58) | def __init__(self, file_lifetime: float = 3600):
method add (line 62) | def add(self, path: tp.Union[str, Path]):
method _cleanup (line 66) | def _cleanup(self):
function make_waveform (line 79) | def make_waveform(*args, **kwargs):
function load_model (line 89) | def load_model(version='facebook/musicgen-style'):
function load_diffusion (line 100) | def load_diffusion():
function _do_predictions (line 107) | def _do_predictions(texts, melodies, duration, top_k, top_p, temperature...
function predict_full (line 164) | def predict_full(model, model_path, decoder, text, melody, duration, top...
function toggle_audio_src (line 220) | def toggle_audio_src(choice):
function toggle_diffusion (line 227) | def toggle_diffusion(choice):
function ui_full (line 234) | def ui_full(launch_kwargs):
FILE: scripts/chords/build_chord_maps.py
function parse_args (line 12) | def parse_args():
function get_chord_dict (line 25) | def get_chord_dict(chord_folder: str):
function get_predefined_chord_to_index_map (line 50) | def get_predefined_chord_to_index_map(path_to_chords_to_index_map: str):
FILE: scripts/chords/extract_chords.py
function parse_args (line 11) | def parse_args():
function save_to_db_cb (line 22) | def save_to_db_cb(tgt_dir: str):
FILE: scripts/mos.py
function normalize_path (line 43) | def normalize_path(path: Path):
function get_full_path (line 51) | def get_full_path(normalized_path: Path):
function get_signature (line 57) | def get_signature(xps: tp.List[str]):
function ensure_logged (line 63) | def ensure_logged(func):
function login (line 76) | def login():
function index (line 98) | def index():
function survey (line 135) | def survey(signature):
function audio (line 236) | def audio(path: str):
function mean (line 242) | def mean(x):
function std (line 246) | def std(x):
function results (line 253) | def results(signature):
FILE: scripts/resample_dataset.py
function read_txt_files (line 22) | def read_txt_files(path: tp.Union[str, Path]):
function read_egs_files (line 31) | def read_egs_files(path: tp.Union[str, Path]):
function process_dataset (line 45) | def process_dataset(args, n_shards: int, node_index: int, task_index: tp...
FILE: tests/adversarial/test_discriminators.py
class TestMultiPeriodDiscriminator (line 18) | class TestMultiPeriodDiscriminator:
method test_mpd_discriminator (line 20) | def test_mpd_discriminator(self):
class TestMultiScaleDiscriminator (line 33) | class TestMultiScaleDiscriminator:
method test_msd_discriminator (line 35) | def test_msd_discriminator(self):
class TestMultiScaleStftDiscriminator (line 49) | class TestMultiScaleStftDiscriminator:
method test_msstftd_discriminator (line 51) | def test_msstftd_discriminator(self):
FILE: tests/adversarial/test_losses.py
class TestAdversarialLoss (line 22) | class TestAdversarialLoss:
method test_adversarial_single_multidiscriminator (line 24) | def test_adversarial_single_multidiscriminator(self):
method test_adversarial_feat_loss (line 45) | def test_adversarial_feat_loss(self):
class TestGeneratorAdversarialLoss (line 65) | class TestGeneratorAdversarialLoss:
method test_hinge_generator_adv_loss (line 67) | def test_hinge_generator_adv_loss(self):
method test_mse_generator_adv_loss (line 76) | def test_mse_generator_adv_loss(self):
class TestDiscriminatorAdversarialLoss (line 88) | class TestDiscriminatorAdversarialLoss:
method _disc_loss (line 90) | def _disc_loss(self, loss_type: str, fake: torch.Tensor, real: torch.T...
method test_hinge_discriminator_adv_loss (line 97) | def test_hinge_discriminator_adv_loss(self):
method test_mse_discriminator_adv_loss (line 105) | def test_mse_discriminator_adv_loss(self):
class TestFeatureMatchingLoss (line 115) | class TestFeatureMatchingLoss:
method test_features_matching_loss_base (line 117) | def test_features_matching_loss_base(self):
method test_features_matching_loss_raises_exception (line 126) | def test_features_matching_loss_raises_exception(self):
method test_features_matching_loss_output (line 141) | def test_features_matching_loss_output(self):
FILE: tests/common_utils/temp_utils.py
class TempDirMixin (line 11) | class TempDirMixin:
method get_base_temp_dir (line 18) | def get_base_temp_dir(cls):
method tearDownClass (line 29) | def tearDownClass(cls):
method id (line 43) | def id(self):
method get_temp_path (line 46) | def get_temp_path(self, *paths):
method get_temp_dir (line 52) | def get_temp_dir(self, *paths):
FILE: tests/common_utils/wav_utils.py
function get_white_noise (line 14) | def get_white_noise(chs: int = 1, num_frames: int = 1):
function get_batch_white_noise (line 19) | def get_batch_white_noise(bs: int = 1, chs: int = 1, num_frames: int = 1):
function save_wav (line 24) | def save_wav(path: str, wav: torch.Tensor, sample_rate: int):
FILE: tests/data/test_audio.py
class TestInfo (line 19) | class TestInfo(TempDirMixin):
method test_info_mp3 (line 21) | def test_info_mp3(self):
method _test_info_format (line 34) | def _test_info_format(self, ext: str):
method test_info_wav (line 48) | def test_info_wav(self):
method test_info_flac (line 51) | def test_info_flac(self):
method test_info_ogg (line 54) | def test_info_ogg(self):
method test_info_m4a (line 57) | def test_info_m4a(self):
class TestRead (line 63) | class TestRead(TempDirMixin):
method test_read_full_wav (line 65) | def test_read_full_wav(self):
method test_read_partial_wav (line 80) | def test_read_partial_wav(self):
method test_read_seek_time_wav (line 97) | def test_read_seek_time_wav(self):
method test_read_seek_time_wav_padded (line 116) | def test_read_seek_time_wav_padded(self):
class TestAvRead (line 139) | class TestAvRead(TempDirMixin):
method test_avread_seek_base (line 141) | def test_avread_seek_base(self):
method test_avread_seek_partial (line 159) | def test_avread_seek_partial(self):
method test_avread_seek_outofbound (line 178) | def test_avread_seek_outofbound(self):
method test_avread_seek_edge (line 193) | def test_avread_seek_edge(self):
class TestAudioWrite (line 212) | class TestAudioWrite(TempDirMixin):
method test_audio_write_wav (line 214) | def test_audio_write_wav(self):
FILE: tests/data/test_audio_dataset.py
class TestAudioMeta (line 31) | class TestAudioMeta(TempDirMixin):
method test_get_audio_meta (line 33) | def test_get_audio_meta(self):
method test_save_audio_meta (line 49) | def test_save_audio_meta(self):
method test_load_audio_meta (line 65) | def test_load_audio_meta(self):
class TestAudioDataset (line 90) | class TestAudioDataset(TempDirMixin):
method _create_audio_files (line 92) | def _create_audio_files(self,
method _create_audio_dataset (line 114) | def _create_audio_dataset(self,
method test_dataset_full (line 135) | def test_dataset_full(self):
method test_dataset_segment (line 152) | def test_dataset_segment(self):
method test_dataset_equal_audio_and_segment_durations (line 170) | def test_dataset_equal_audio_and_segment_durations(self):
method test_dataset_samples (line 192) | def test_dataset_samples(self):
method test_dataset_return_info (line 218) | def test_dataset_return_info(self):
method test_dataset_return_info_no_segment_duration (line 240) | def test_dataset_return_info_no_segment_duration(self):
method test_dataset_collate_fn (line 260) | def test_dataset_collate_fn(self):
method test_dataset_with_meta_collate_fn (line 280) | def test_dataset_with_meta_collate_fn(self, segment_duration):
method test_sample_with_weight (line 308) | def test_sample_with_weight(self, segment_duration, sample_on_weight, ...
method test_meta_duration_filter_all (line 333) | def test_meta_duration_filter_all(self):
method test_meta_duration_filter_long (line 345) | def test_meta_duration_filter_long(self):
FILE: tests/data/test_audio_utils.py
class TestConvertAudioChannels (line 22) | class TestConvertAudioChannels:
method test_convert_audio_channels_downmix (line 24) | def test_convert_audio_channels_downmix(self):
method test_convert_audio_channels_nochange (line 30) | def test_convert_audio_channels_nochange(self):
method test_convert_audio_channels_upmix (line 36) | def test_convert_audio_channels_upmix(self):
method test_convert_audio_channels_upmix_error (line 42) | def test_convert_audio_channels_upmix_error(self):
class TestConvertAudio (line 49) | class TestConvertAudio:
method test_convert_audio_channels_downmix (line 51) | def test_convert_audio_channels_downmix(self):
method test_convert_audio_channels_upmix (line 58) | def test_convert_audio_channels_upmix(self):
method test_convert_audio_upsample (line 65) | def test_convert_audio_upsample(self):
method test_convert_audio_resample (line 74) | def test_convert_audio_resample(self):
method test_convert_pcm (line 83) | def test_convert_pcm(self):
class TestNormalizeAudio (line 92) | class TestNormalizeAudio:
method test_clip_wav (line 94) | def test_clip_wav(self):
method test_normalize_audio_clip (line 101) | def test_normalize_audio_clip(self):
method test_normalize_audio_rms (line 108) | def test_normalize_audio_rms(self):
method test_normalize_audio_peak (line 115) | def test_normalize_audio_peak(self):
FILE: tests/losses/test_losses.py
function test_mel_l1_loss (line 23) | def test_mel_l1_loss():
function test_msspec_loss (line 37) | def test_msspec_loss():
function test_mrstft_loss (line 51) | def test_mrstft_loss():
function test_sisnr_loss (line 62) | def test_sisnr_loss():
function test_stft_loss (line 73) | def test_stft_loss():
function test_wm_loss (line 84) | def test_wm_loss():
function test_loudness_loss (line 96) | def test_loudness_loss():
FILE: tests/metrics/test_pesq.py
function tensor_pesq (line 14) | def tensor_pesq(y_pred: torch.Tensor, y: torch.Tensor, sr: int):
class TestPesq (line 30) | class TestPesq(TempDirMixin):
method test (line 32) | def test(self):
FILE: tests/models/test_audiogen.py
class TestAudioGenModel (line 13) | class TestAudioGenModel:
method get_audiogen (line 14) | def get_audiogen(self):
method test_base (line 19) | def test_base(self):
method test_generate_continuation (line 25) | def test_generate_continuation(self):
method test_generate (line 41) | def test_generate(self):
method test_generate_long (line 47) | def test_generate_long(self):
FILE: tests/models/test_encodec_model.py
class TestEncodecModel (line 17) | class TestEncodecModel:
method _create_encodec_model (line 19) | def _create_encodec_model(self,
method test_model (line 37) | def test_model(self):
method test_model_renorm (line 48) | def test_model_renorm(self):
FILE: tests/models/test_multibanddiffusion.py
class TestMBD (line 18) | class TestMBD:
method _create_mbd (line 20) | def _create_mbd(self,
method test_model (line 43) | def test_model(self):
FILE: tests/models/test_musicgen.py
class TestMusicGenModel (line 13) | class TestMusicGenModel:
method get_musicgen (line 14) | def get_musicgen(self):
method test_base (line 19) | def test_base(self):
method test_generate_unconditional (line 25) | def test_generate_unconditional(self):
method test_generate_continuation (line 30) | def test_generate_continuation(self):
method test_generate (line 46) | def test_generate(self):
method test_generate_long (line 52) | def test_generate_long(self):
method test_generate_two_step_cfg (line 60) | def test_generate_two_step_cfg(self):
FILE: tests/models/test_watermark.py
class TestWatermarkModel (line 13) | class TestWatermarkModel:
method test_base (line 15) | def test_base(self):
FILE: tests/modules/test_activations.py
class TestActivations (line 13) | class TestActivations:
method test_custom_glu_calculation (line 14) | def test_custom_glu_calculation(self):
FILE: tests/modules/test_codebooks_patterns.py
class TestParallelPatternProvider (line 18) | class TestParallelPatternProvider:
method test_get_pattern (line 22) | def test_get_pattern(self, n_q: int, timesteps: int):
method test_pattern_content (line 30) | def test_pattern_content(self, n_q: int, timesteps: int):
method test_pattern_max_delay (line 40) | def test_pattern_max_delay(self, n_q: int, timesteps: int):
class TestDelayedPatternProvider (line 47) | class TestDelayedPatternProvider:
method test_get_pattern (line 51) | def test_get_pattern(self, n_q: int, timesteps: int):
method test_pattern_content (line 65) | def test_pattern_content(self, n_q: int, timesteps: int):
method test_pattern_max_delay (line 75) | def test_pattern_max_delay(self, timesteps: int, delay: list):
class TestUnrolledPatternProvider (line 82) | class TestUnrolledPatternProvider:
method test_get_pattern (line 87) | def test_get_pattern(self, timesteps: int, flattening: list, delays: l...
method test_pattern_max_delay (line 97) | def test_pattern_max_delay(self, timesteps: int, flattening: list, del...
class TestPattern (line 105) | class TestPattern:
method ref_build_pattern_sequence (line 107) | def ref_build_pattern_sequence(self, z: torch.Tensor, pattern: Pattern...
method ref_revert_pattern_sequence (line 121) | def ref_revert_pattern_sequence(self, z: torch.Tensor, pattern: Patter...
method ref_revert_pattern_logits (line 134) | def ref_revert_pattern_logits(self, z: torch.Tensor, pattern: Pattern,...
method _get_pattern_providers (line 149) | def _get_pattern_providers(self, n_q: int):
method test_build_pattern_sequence (line 173) | def test_build_pattern_sequence(self, n_q: int, timesteps: int):
method test_revert_pattern_sequence (line 205) | def test_revert_pattern_sequence(self, n_q: int, timesteps: int):
method test_revert_pattern_logits (line 228) | def test_revert_pattern_logits(self, n_q: int, timesteps: int, card: i...
FILE: tests/modules/test_conv.py
function test_get_extra_padding_for_conv1d (line 25) | def test_get_extra_padding_for_conv1d():
function test_pad1d_zeros (line 30) | def test_pad1d_zeros():
function test_pad1d_reflect (line 52) | def test_pad1d_reflect():
function test_unpad1d (line 74) | def test_unpad1d():
class TestNormConv1d (line 96) | class TestNormConv1d:
method test_norm_conv1d_modules (line 98) | def test_norm_conv1d_modules(self):
class TestNormConvTranspose1d (line 123) | class TestNormConvTranspose1d:
method test_normalizations (line 125) | def test_normalizations(self):
class TestStreamableConv1d (line 151) | class TestStreamableConv1d:
method get_streamable_conv1d_output_length (line 153) | def get_streamable_conv1d_output_length(self, length, kernel_size, str...
method test_streamable_conv1d (line 160) | def test_streamable_conv1d(self):
class TestStreamableConvTranspose1d (line 176) | class TestStreamableConvTranspose1d:
method get_streamable_convtr1d_output_length (line 178) | def get_streamable_convtr1d_output_length(self, length, kernel_size, s...
method test_streamable_convtr1d (line 182) | def test_streamable_convtr1d(self):
FILE: tests/modules/test_lstm.py
class TestStreamableLSTM (line 13) | class TestStreamableLSTM:
method test_lstm (line 15) | def test_lstm(self):
method test_lstm_skip (line 25) | def test_lstm_skip(self):
FILE: tests/modules/test_rope.py
function test_rope (line 13) | def test_rope():
function test_rope_io_dtypes (line 26) | def test_rope_io_dtypes():
function test_transformer_with_rope (line 50) | def test_transformer_with_rope():
function test_rope_streaming (line 66) | def test_rope_streaming():
function test_rope_streaming_past_context (line 94) | def test_rope_streaming_past_context():
function test_rope_memory_efficient (line 124) | def test_rope_memory_efficient():
function test_rope_with_xpos (line 145) | def test_rope_with_xpos():
function test_positional_scale (line 158) | def test_positional_scale():
FILE: tests/modules/test_seanet.py
class TestSEANetModel (line 16) | class TestSEANetModel:
method test_base (line 18) | def test_base(self):
method test_causal (line 28) | def test_causal(self):
method test_conv_skip_connection (line 38) | def test_conv_skip_connection(self):
method test_seanet_encoder_decoder_final_act (line 48) | def test_seanet_encoder_decoder_final_act(self):
method _check_encoder_blocks_norm (line 58) | def _check_encoder_blocks_norm(self, encoder: SEANetEncoder, n_disable...
method test_encoder_disable_norm (line 70) | def test_encoder_disable_norm(self):
method _check_decoder_blocks_norm (line 79) | def _check_decoder_blocks_norm(self, decoder: SEANetDecoder, n_disable...
method test_decoder_disable_norm (line 94) | def test_decoder_disable_norm(self):
method test_disable_norm_raises_exception (line 103) | def test_disable_norm_raises_exception(self):
FILE: tests/modules/test_transformer.py
function test_transformer_causal_streaming (line 16) | def test_transformer_causal_streaming():
function test_transformer_vs_pytorch (line 52) | def test_transformer_vs_pytorch():
function test_streaming_api (line 71) | def test_streaming_api():
function test_memory_efficient (line 88) | def test_memory_efficient():
function test_attention_as_float32 (line 108) | def test_attention_as_float32():
function test_streaming_memory_efficient (line 134) | def test_streaming_memory_efficient():
function test_cross_attention (line 164) | def test_cross_attention():
function test_cross_attention_compat (line 192) | def test_cross_attention_compat():
function test_repeat_kv (line 224) | def test_repeat_kv():
function test_qk_layer_norm (line 241) | def test_qk_layer_norm():
FILE: tests/quantization/test_vq.py
class TestResidualVectorQuantizer (line 12) | class TestResidualVectorQuantizer:
method test_rvq (line 14) | def test_rvq(self):
FILE: tests/utils/test_audio_effects.py
class TestAudioEffect (line 15) | class TestAudioEffect:
method audio_effects (line 19) | def audio_effects(self):
method test_select_empty_effects (line 86) | def test_select_empty_effects(self):
method test_select_wrong_strategy (line 90) | def test_select_wrong_strategy(self):
method test_selection (line 97) | def test_selection(self, audio_effects):
Condensed preview — 296 files, each showing path, character count, and a content snippet. Download the .json file or copy for the full structured content (5,854K chars).
[
{
"path": ".github/actions/audiocraft_build/action.yml",
"chars": 869,
"preview": "name: audiocraft_build\ndescription: 'Build audiocraft env.'\nruns:\n using: \"composite\"\n steps:\n - uses: actions/setup-"
},
{
"path": ".github/workflows/audiocraft_docs.yml",
"chars": 735,
"preview": "name: audiocraft_docs\non:\n push:\n branches: [ main ]\n\njobs:\n run_docs:\n name: Run docs\n runs-on: ubuntu-lates"
},
{
"path": ".github/workflows/audiocraft_linter.yml",
"chars": 348,
"preview": "name: audiocraft_linter\non:\n push:\n branches: [ main ]\n pull_request:\n branches: [ main, audiocraft_pub_main ]\n\n"
},
{
"path": ".github/workflows/audiocraft_tests.yml",
"chars": 480,
"preview": "name: audiocraft_tests\non:\n push:\n branches: [ main ]\n pull_request:\n branches: [ main, audiocraft_pub_main ]\n\nj"
},
{
"path": ".gitignore",
"chars": 623,
"preview": "# Byte-compiled / optimized / DLL files\n__pycache__\n*.py[cod]\n*$py.class\n\n# C extensions\n*.so\n\n# macOS dir files\n.DS_Sto"
},
{
"path": "CHANGELOG.md",
"chars": 3338,
"preview": "# Changelog\n\nAll notable changes to this project will be documented in this file.\n\nThe format is based on [Keep a Change"
},
{
"path": "CODE_OF_CONDUCT.md",
"chars": 3535,
"preview": "# Code of Conduct\n\n## Our Pledge\n\nIn the interest of fostering an open and welcoming environment, we as\ncontributors and"
},
{
"path": "CONTRIBUTING.md",
"chars": 1377,
"preview": "# Contributing to AudioCraft\n\nWe want to make contributing to this project as easy and transparent as\npossible.\n\n## Pull"
},
{
"path": "LICENSE",
"chars": 1088,
"preview": "MIT License\n\nCopyright (c) Meta Platforms, Inc. and affiliates.\n\nPermission is hereby granted, free of charge, to any pe"
},
{
"path": "LICENSE_weights",
"chars": 19329,
"preview": "Attribution-NonCommercial 4.0 International\n\n=======================================================================\n\nCr"
},
{
"path": "MANIFEST.in",
"chars": 368,
"preview": "include Makefile\ninclude LICENSE\ninclude LICENSE_weights\ninclude *.md\ninclude *.ini\ninclude requirements.txt\ninclude aud"
},
{
"path": "Makefile",
"chars": 1823,
"preview": "INTEG=AUDIOCRAFT_DORA_DIR=\"/tmp/magma_$(USER)\" python3 -m dora -v run --clear device=cpu dataset.num_workers=0 optim.epo"
},
{
"path": "README.md",
"chars": 4833,
"preview": "# AudioCraft\n\n![linter "
},
{
"path": "audiocraft/__init__.py",
"chars": 1387,
"preview": "# Copyright (c) Meta Platforms, Inc. and affiliates.\n# All rights reserved.\n#\n# This source code is licensed under the l"
},
{
"path": "audiocraft/adversarial/__init__.py",
"chars": 570,
"preview": "# Copyright (c) Meta Platforms, Inc. and affiliates.\n# All rights reserved.\n#\n# This source code is licensed under the l"
},
{
"path": "audiocraft/adversarial/discriminators/__init__.py",
"chars": 346,
"preview": "# Copyright (c) Meta Platforms, Inc. and affiliates.\n# All rights reserved.\n#\n# This source code is licensed under the l"
},
{
"path": "audiocraft/adversarial/discriminators/base.py",
"chars": 894,
"preview": "# Copyright (c) Meta Platforms, Inc. and affiliates.\n# All rights reserved.\n#\n# This source code is licensed under the l"
},
{
"path": "audiocraft/adversarial/discriminators/mpd.py",
"chars": 4176,
"preview": "# Copyright (c) Meta Platforms, Inc. and affiliates.\n# All rights reserved.\n#\n# This source code is licensed under the l"
},
{
"path": "audiocraft/adversarial/discriminators/msd.py",
"chars": 5926,
"preview": "# Copyright (c) Meta Platforms, Inc. and affiliates.\n# All rights reserved.\n#\n# This source code is licensed under the l"
},
{
"path": "audiocraft/adversarial/discriminators/msstftd.py",
"chars": 6331,
"preview": "# Copyright (c) Meta Platforms, Inc. and affiliates.\n# All rights reserved.\n#\n# This source code is licensed under the l"
},
{
"path": "audiocraft/adversarial/losses.py",
"chars": 9126,
"preview": "# Copyright (c) Meta Platforms, Inc. and affiliates.\n# All rights reserved.\n#\n# This source code is licensed under the l"
},
{
"path": "audiocraft/data/__init__.py",
"chars": 411,
"preview": "# Copyright (c) Meta Platforms, Inc. and affiliates.\n# All rights reserved.\n#\n# This source code is licensed under the l"
},
{
"path": "audiocraft/data/audio.py",
"chars": 13766,
"preview": "# Copyright (c) Meta Platforms, Inc. and affiliates.\n# All rights reserved.\n#\n# This source code is licensed under the l"
},
{
"path": "audiocraft/data/audio_dataset.py",
"chars": 25464,
"preview": "# Copyright (c) Meta Platforms, Inc. and affiliates.\n# All rights reserved.\n#\n# This source code is licensed under the l"
},
{
"path": "audiocraft/data/audio_utils.py",
"chars": 14951,
"preview": "# Copyright (c) Meta Platforms, Inc. and affiliates.\n# All rights reserved.\n#\n# This source code is licensed under the l"
},
{
"path": "audiocraft/data/info_audio_dataset.py",
"chars": 3902,
"preview": "# Copyright (c) Meta Platforms, Inc. and affiliates.\n# All rights reserved.\n#\n# This source code is licensed under the l"
},
{
"path": "audiocraft/data/jasco_dataset.py",
"chars": 13594,
"preview": "# Copyright (c) Meta Platforms, Inc. and affiliates.\n# All rights reserved.\n#\n# This source code is licensed under the l"
},
{
"path": "audiocraft/data/music_dataset.py",
"chars": 11575,
"preview": "# Copyright (c) Meta Platforms, Inc. and affiliates.\n# All rights reserved.\n#\n# This source code is licensed under the l"
},
{
"path": "audiocraft/data/sound_dataset.py",
"chars": 13381,
"preview": "# Copyright (c) Meta Platforms, Inc. and affiliates.\n# All rights reserved.\n#\n# This source code is licensed under the l"
},
{
"path": "audiocraft/data/zip.py",
"chars": 2202,
"preview": "# Copyright (c) Meta Platforms, Inc. and affiliates.\n# All rights reserved.\n#\n# This source code is licensed under the l"
},
{
"path": "audiocraft/environment.py",
"chars": 6741,
"preview": "# Copyright (c) Meta Platforms, Inc. and affiliates.\n# All rights reserved.\n#\n# This source code is licensed under the l"
},
{
"path": "audiocraft/grids/__init__.py",
"chars": 216,
"preview": "# Copyright (c) Meta Platforms, Inc. and affiliates.\n# All rights reserved.\n#\n# This source code is licensed under the l"
},
{
"path": "audiocraft/grids/_base_explorers.py",
"chars": 2639,
"preview": "# Copyright (c) Meta Platforms, Inc. and affiliates.\n# All rights reserved.\n#\n# This source code is licensed under the l"
},
{
"path": "audiocraft/grids/audiogen/__init__.py",
"chars": 220,
"preview": "# Copyright (c) Meta Platforms, Inc. and affiliates.\n# All rights reserved.\n#\n# This source code is licensed under the l"
},
{
"path": "audiocraft/grids/audiogen/audiogen_base_16khz.py",
"chars": 776,
"preview": "# Copyright (c) Meta Platforms, Inc. and affiliates.\n# All rights reserved.\n#\n# This source code is licensed under the l"
},
{
"path": "audiocraft/grids/audiogen/audiogen_pretrained_16khz_eval.py",
"chars": 2483,
"preview": "# Copyright (c) Meta Platforms, Inc. and affiliates.\n# All rights reserved.\n#\n# This source code is licensed under the l"
},
{
"path": "audiocraft/grids/compression/__init__.py",
"chars": 219,
"preview": "# Copyright (c) Meta Platforms, Inc. and affiliates.\n# All rights reserved.\n#\n# This source code is licensed under the l"
},
{
"path": "audiocraft/grids/compression/_explorers.py",
"chars": 1601,
"preview": "# Copyright (c) Meta Platforms, Inc. and affiliates.\n# All rights reserved.\n#\n# This source code is licensed under the l"
},
{
"path": "audiocraft/grids/compression/debug.py",
"chars": 1117,
"preview": "# Copyright (c) Meta Platforms, Inc. and affiliates.\n# All rights reserved.\n#\n# This source code is licensed under the l"
},
{
"path": "audiocraft/grids/compression/encodec_audiogen_16khz.py",
"chars": 1100,
"preview": "# Copyright (c) Meta Platforms, Inc. and affiliates.\n# All rights reserved.\n#\n# This source code is licensed under the l"
},
{
"path": "audiocraft/grids/compression/encodec_base_24khz.py",
"chars": 956,
"preview": "# Copyright (c) Meta Platforms, Inc. and affiliates.\n# All rights reserved.\n#\n# This source code is licensed under the l"
},
{
"path": "audiocraft/grids/compression/encodec_musicgen_32khz.py",
"chars": 1262,
"preview": "# Copyright (c) Meta Platforms, Inc. and affiliates.\n# All rights reserved.\n#\n# This source code is licensed under the l"
},
{
"path": "audiocraft/grids/diffusion/4_bands_base_32khz.py",
"chars": 1073,
"preview": "# Copyright (c) Meta Platforms, Inc. and affiliates.\n# All rights reserved.\n#\n# This source code is licensed under the l"
},
{
"path": "audiocraft/grids/diffusion/__init__.py",
"chars": 221,
"preview": "# Copyright (c) Meta Platforms, Inc. and affiliates.\n# All rights reserved.\n#\n# This source code is licensed under the l"
},
{
"path": "audiocraft/grids/diffusion/_explorers.py",
"chars": 2066,
"preview": "# Copyright (c) Meta Platforms, Inc. and affiliates.\n# All rights reserved.\n#\n# This source code is licensed under the l"
},
{
"path": "audiocraft/grids/magnet/__init__.py",
"chars": 218,
"preview": "# Copyright (c) Meta Platforms, Inc. and affiliates.\n# All rights reserved.\n#\n# This source code is licensed under the l"
},
{
"path": "audiocraft/grids/magnet/audio_magnet_16khz.py",
"chars": 1043,
"preview": "# Copyright (c) Meta Platforms, Inc. and affiliates.\n# All rights reserved.\n#\n# This source code is licensed under the l"
},
{
"path": "audiocraft/grids/magnet/audio_magnet_pretrained_16khz_eval.py",
"chars": 2652,
"preview": "# Copyright (c) Meta Platforms, Inc. and affiliates.\n# All rights reserved.\n#\n# This source code is licensed under the l"
},
{
"path": "audiocraft/grids/magnet/magnet_32khz.py",
"chars": 1482,
"preview": "# Copyright (c) Meta Platforms, Inc. and affiliates.\n# All rights reserved.\n#\n# This source code is licensed under the l"
},
{
"path": "audiocraft/grids/magnet/magnet_pretrained_32khz_eval.py",
"chars": 3266,
"preview": "# Copyright (c) Meta Platforms, Inc. and affiliates.\n# All rights reserved.\n#\n# This source code is licensed under the l"
},
{
"path": "audiocraft/grids/musicgen/__init__.py",
"chars": 220,
"preview": "# Copyright (c) Meta Platforms, Inc. and affiliates.\n# All rights reserved.\n#\n# This source code is licensed under the l"
},
{
"path": "audiocraft/grids/musicgen/_explorers.py",
"chars": 3092,
"preview": "# Copyright (c) Meta Platforms, Inc. and affiliates.\n# All rights reserved.\n#\n# This source code is licensed under the l"
},
{
"path": "audiocraft/grids/musicgen/musicgen_base_32khz.py",
"chars": 1413,
"preview": "# Copyright (c) Meta Platforms, Inc. and affiliates.\n# All rights reserved.\n#\n# This source code is licensed under the l"
},
{
"path": "audiocraft/grids/musicgen/musicgen_base_cached_32khz.py",
"chars": 2311,
"preview": "# Copyright (c) Meta Platforms, Inc. and affiliates.\n# All rights reserved.\n#\n# This source code is licensed under the l"
},
{
"path": "audiocraft/grids/musicgen/musicgen_clapemb_32khz.py",
"chars": 1193,
"preview": "# Copyright (c) Meta Platforms, Inc. and affiliates.\n# All rights reserved.\n#\n# This source code is licensed under the l"
},
{
"path": "audiocraft/grids/musicgen/musicgen_melody_32khz.py",
"chars": 2251,
"preview": "# Copyright (c) Meta Platforms, Inc. and affiliates.\n# All rights reserved.\n#\n# This source code is licensed under the l"
},
{
"path": "audiocraft/grids/musicgen/musicgen_pretrained_32khz_eval.py",
"chars": 3880,
"preview": "# Copyright (c) Meta Platforms, Inc. and affiliates.\n# All rights reserved.\n#\n# This source code is licensed under the l"
},
{
"path": "audiocraft/grids/musicgen/musicgen_stereo_finetune_32khz.py",
"chars": 2139,
"preview": "# Copyright (c) Meta Platforms, Inc. and affiliates.\n# All rights reserved.\n#\n# This source code is licensed under the l"
},
{
"path": "audiocraft/grids/musicgen/musicgen_style_32khz.py",
"chars": 1016,
"preview": "# Copyright (c) Meta Platforms, Inc. and affiliates.\n# All rights reserved.\n#\n# This source code is licensed under the l"
},
{
"path": "audiocraft/grids/watermarking/__init__.py",
"chars": 224,
"preview": "# Copyright (c) Meta Platforms, Inc. and affiliates.\n# All rights reserved.\n#\n# This source code is licensed under the l"
},
{
"path": "audiocraft/grids/watermarking/_explorers.py",
"chars": 3702,
"preview": "# Copyright (c) Meta Platforms, Inc. and affiliates.\n# All rights reserved.\n#\n# This source code is licensed under the l"
},
{
"path": "audiocraft/grids/watermarking/audioseal.py",
"chars": 802,
"preview": "# Copyright (c) Meta Platforms, Inc. and affiliates.\n# All rights reserved.\n#\n# This source code is licensed under the l"
},
{
"path": "audiocraft/grids/watermarking/kbits.py",
"chars": 2787,
"preview": "# Copyright (c) Meta Platforms, Inc. and affiliates.\n# All rights reserved.\n#\n# This source code is licensed under the l"
},
{
"path": "audiocraft/losses/__init__.py",
"chars": 687,
"preview": "# Copyright (c) Meta Platforms, Inc. and affiliates.\n# All rights reserved.\n#\n# This source code is licensed under the l"
},
{
"path": "audiocraft/losses/balancer.py",
"chars": 6612,
"preview": "# Copyright (c) Meta Platforms, Inc. and affiliates.\n# All rights reserved.\n#\n# This source code is licensed under the l"
},
{
"path": "audiocraft/losses/loudnessloss.py",
"chars": 7563,
"preview": "# Copyright (c) Meta Platforms, Inc. and affiliates.\n# All rights reserved.\n#\n# This source code is licensed under the l"
},
{
"path": "audiocraft/losses/sisnr.py",
"chars": 3263,
"preview": "# Copyright (c) Meta Platforms, Inc. and affiliates.\n# All rights reserved.\n#\n# This source code is licensed under the l"
},
{
"path": "audiocraft/losses/specloss.py",
"chars": 6531,
"preview": "# Copyright (c) Meta Platforms, Inc. and affiliates.\n# All rights reserved.\n#\n# This source code is licensed under the l"
},
{
"path": "audiocraft/losses/stftloss.py",
"chars": 8202,
"preview": "# Copyright (c) Meta Platforms, Inc. and affiliates.\n# All rights reserved.\n#\n# This source code is licensed under the l"
},
{
"path": "audiocraft/losses/wmloss.py",
"chars": 4249,
"preview": "# Copyright (c) Meta Platforms, Inc. and affiliates.\n# All rights reserved.\n#\n# This source code is licensed under the l"
},
{
"path": "audiocraft/metrics/__init__.py",
"chars": 592,
"preview": "# Copyright (c) Meta Platforms, Inc. and affiliates.\n# All rights reserved.\n#\n# This source code is licensed under the l"
},
{
"path": "audiocraft/metrics/chroma_cosinesim.py",
"chars": 3674,
"preview": "# Copyright (c) Meta Platforms, Inc. and affiliates.\n# All rights reserved.\n#\n# This source code is licensed under the l"
},
{
"path": "audiocraft/metrics/clap_consistency.py",
"chars": 4525,
"preview": "# Copyright (c) Meta Platforms, Inc. and affiliates.\n# All rights reserved.\n#\n# This source code is licensed under the l"
},
{
"path": "audiocraft/metrics/fad.py",
"chars": 17721,
"preview": "# Copyright (c) Meta Platforms, Inc. and affiliates.\n# All rights reserved.\n#\n# This source code is licensed under the l"
},
{
"path": "audiocraft/metrics/kld.py",
"chars": 10211,
"preview": "# Copyright (c) Meta Platforms, Inc. and affiliates.\n# All rights reserved.\n#\n# This source code is licensed under the l"
},
{
"path": "audiocraft/metrics/miou.py",
"chars": 1575,
"preview": "# Copyright (c) Meta Platforms, Inc. and affiliates.\n# All rights reserved.\n#\n# This source code is licensed under the l"
},
{
"path": "audiocraft/metrics/pesq.py",
"chars": 1549,
"preview": "# Copyright (c) Meta Platforms, Inc. and affiliates.\n# All rights reserved.\n#\n# This source code is licensed under the l"
},
{
"path": "audiocraft/metrics/rvm.py",
"chars": 6107,
"preview": "# Copyright (c) Meta Platforms, Inc. and affiliates.\n# All rights reserved.\n#\n# This source code is licensed under the l"
},
{
"path": "audiocraft/metrics/visqol.py",
"chars": 9694,
"preview": "# Copyright (c) Meta Platforms, Inc. and affiliates.\n# All rights reserved.\n#\n# This source code is licensed under the l"
},
{
"path": "audiocraft/models/__init__.py",
"chars": 770,
"preview": "# Copyright (c) Meta Platforms, Inc. and affiliates.\n# All rights reserved.\n#\n# This source code is licensed under the l"
},
{
"path": "audiocraft/models/audiogen.py",
"chars": 4347,
"preview": "# Copyright (c) Meta Platforms, Inc. and affiliates.\n# All rights reserved.\n#\n# This source code is licensed under the l"
},
{
"path": "audiocraft/models/builders.py",
"chars": 15893,
"preview": "# Copyright (c) Meta Platforms, Inc. and affiliates.\n# All rights reserved.\n#\n# This source code is licensed under the l"
},
{
"path": "audiocraft/models/encodec.py",
"chars": 18020,
"preview": "# Copyright (c) Meta Platforms, Inc. and affiliates.\n# All rights reserved.\n#\n# This source code is licensed under the l"
},
{
"path": "audiocraft/models/flow_matching.py",
"chars": 23409,
"preview": "# Copyright (c) Meta Platforms, Inc. and affiliates.\n# All rights reserved.\n#\n# This source code is licensed under the l"
},
{
"path": "audiocraft/models/genmodel.py",
"chars": 12331,
"preview": "# Copyright (c) Meta Platforms, Inc. and affiliates.\n# All rights reserved.\n#\n# This source code is licensed under the l"
},
{
"path": "audiocraft/models/jasco.py",
"chars": 16118,
"preview": "# Copyright (c) Meta Platforms, Inc. and affiliates.\n# All rights reserved.\n#\n# This source code is licensed under the l"
},
{
"path": "audiocraft/models/lm.py",
"chars": 30735,
"preview": "# Copyright (c) Meta Platforms, Inc. and affiliates.\n# All rights reserved.\n#\n# This source code is licensed under the l"
},
{
"path": "audiocraft/models/lm_magnet.py",
"chars": 25366,
"preview": "# Copyright (c) Meta Platforms, Inc. and affiliates.\n# All rights reserved.\n#\n# This source code is licensed under the l"
},
{
"path": "audiocraft/models/loaders.py",
"chars": 9428,
"preview": "# Copyright (c) Meta Platforms, Inc. and affiliates.\n# All rights reserved.\n#\n# This source code is licensed under the l"
},
{
"path": "audiocraft/models/magnet.py",
"chars": 4400,
"preview": "# Copyright (c) Meta Platforms, Inc. and affiliates.\n# All rights reserved.\n#\n# This source code is licensed under the l"
},
{
"path": "audiocraft/models/multibanddiffusion.py",
"chars": 8737,
"preview": "# Copyright (c) Meta Platforms, Inc. and affiliates.\n# All rights reserved.\n#\n# This source code is licensed under the l"
},
{
"path": "audiocraft/models/musicgen.py",
"chars": 17295,
"preview": "# Copyright (c) Meta Platforms, Inc. and affiliates.\n# All rights reserved.\n#\n# This source code is licensed under the l"
},
{
"path": "audiocraft/models/unet.py",
"chars": 8340,
"preview": "# Copyright (c) Meta Platforms, Inc. and affiliates.\n# All rights reserved.\n#\n# This source code is licensed under the l"
},
{
"path": "audiocraft/models/watermark.py",
"chars": 3560,
"preview": "# Copyright (c) Meta Platforms, Inc. and affiliates.\n# All rights reserved.\n#\n# This source code is licensed under the l"
},
{
"path": "audiocraft/modules/__init__.py",
"chars": 586,
"preview": "# Copyright (c) Meta Platforms, Inc. and affiliates.\n# All rights reserved.\n#\n# This source code is licensed under the l"
},
{
"path": "audiocraft/modules/activations.py",
"chars": 3266,
"preview": "# Copyright (c) Meta Platforms, Inc. and affiliates.\n# All rights reserved.\n#\n# This source code is licensed under the l"
},
{
"path": "audiocraft/modules/chroma.py",
"chars": 3023,
"preview": "# Copyright (c) Meta Platforms, Inc. and affiliates.\n# All rights reserved.\n#\n# This source code is licensed under the l"
},
{
"path": "audiocraft/modules/codebooks_patterns.py",
"chars": 28228,
"preview": "# Copyright (c) Meta Platforms, Inc. and affiliates.\n# All rights reserved.\n#\n# This source code is licensed under the l"
},
{
"path": "audiocraft/modules/conditioners.py",
"chars": 81744,
"preview": "# Copyright (c) Meta Platforms, Inc. and affiliates.\n# All rights reserved.\n#\n# This source code is licensed under the l"
},
{
"path": "audiocraft/modules/conv.py",
"chars": 10496,
"preview": "# Copyright (c) Meta Platforms, Inc. and affiliates.\n# All rights reserved.\n#\n# This source code is licensed under the l"
},
{
"path": "audiocraft/modules/diffusion_schedule.py",
"chars": 12018,
"preview": "# Copyright (c) Meta Platforms, Inc. and affiliates.\n# All rights reserved.\n#\n# This source code is licensed under the l"
},
{
"path": "audiocraft/modules/jasco_conditioners.py",
"chars": 15257,
"preview": "import torch\nimport typing as tp\nfrom itertools import chain\nfrom pathlib import Path\nfrom torch import nn\nfrom .conditi"
},
{
"path": "audiocraft/modules/lstm.py",
"chars": 759,
"preview": "# Copyright (c) Meta Platforms, Inc. and affiliates.\n# All rights reserved.\n#\n# This source code is licensed under the l"
},
{
"path": "audiocraft/modules/rope.py",
"chars": 5649,
"preview": "# Copyright (c) Meta Platforms, Inc. and affiliates.\n# All rights reserved.\n#\n# This source code is licensed under the l"
},
{
"path": "audiocraft/modules/seanet.py",
"chars": 13868,
"preview": "# Copyright (c) Meta Platforms, Inc. and affiliates.\n# All rights reserved.\n#\n# This source code is licensed under the l"
},
{
"path": "audiocraft/modules/streaming.py",
"chars": 4494,
"preview": "# Copyright (c) Meta Platforms, Inc. and affiliates.\n# All rights reserved.\n#\n# This source code is licensed under the l"
},
{
"path": "audiocraft/modules/transformer.py",
"chars": 37408,
"preview": "# Copyright (c) Meta Platforms, Inc. and affiliates.\n# All rights reserved.\n#\n# This source code is licensed under the l"
},
{
"path": "audiocraft/modules/unet_transformer.py",
"chars": 3190,
"preview": "import torch\nimport typing as tp\nfrom .transformer import StreamingTransformer, create_sin_embedding\n\n\nclass UnetTransfo"
},
{
"path": "audiocraft/modules/watermark.py",
"chars": 4029,
"preview": "# Copyright (c) Meta Platforms, Inc. and affiliates.\n# All rights reserved.\n#\n# This source code is licensed under the l"
},
{
"path": "audiocraft/optim/__init__.py",
"chars": 638,
"preview": "# Copyright (c) Meta Platforms, Inc. and affiliates.\n# All rights reserved.\n#\n# This source code is licensed under the l"
},
{
"path": "audiocraft/optim/cosine_lr_scheduler.py",
"chars": 1730,
"preview": "# Copyright (c) Meta Platforms, Inc. and affiliates.\n# All rights reserved.\n#\n# This source code is licensed under the l"
},
{
"path": "audiocraft/optim/dadam.py",
"chars": 8910,
"preview": "# Copyright (c) Meta Platforms, Inc. and affiliates.\n# All rights reserved.\n#\n# This source code is licensed under the l"
},
{
"path": "audiocraft/optim/ema.py",
"chars": 3196,
"preview": "# Copyright (c) Meta Platforms, Inc. and affiliates.\n# All rights reserved.\n#\n# This source code is licensed under the l"
},
{
"path": "audiocraft/optim/fsdp.py",
"chars": 7819,
"preview": "# Copyright (c) Meta Platforms, Inc. and affiliates.\n# All rights reserved.\n#\n# This source code is licensed under the l"
},
{
"path": "audiocraft/optim/inverse_sqrt_lr_scheduler.py",
"chars": 1390,
"preview": "# Copyright (c) Meta Platforms, Inc. and affiliates.\n# All rights reserved.\n#\n# This source code is licensed under the l"
},
{
"path": "audiocraft/optim/linear_warmup_lr_scheduler.py",
"chars": 1272,
"preview": "# Copyright (c) Meta Platforms, Inc. and affiliates.\n# All rights reserved.\n#\n# This source code is licensed under the l"
},
{
"path": "audiocraft/optim/polynomial_decay_lr_scheduler.py",
"chars": 2012,
"preview": "# Copyright (c) Meta Platforms, Inc. and affiliates.\n# All rights reserved.\n#\n# This source code is licensed under the l"
},
{
"path": "audiocraft/py.typed",
"chars": 0,
"preview": ""
},
{
"path": "audiocraft/quantization/__init__.py",
"chars": 329,
"preview": "# Copyright (c) Meta Platforms, Inc. and affiliates.\n# All rights reserved.\n#\n# This source code is licensed under the l"
},
{
"path": "audiocraft/quantization/base.py",
"chars": 3314,
"preview": "# Copyright (c) Meta Platforms, Inc. and affiliates.\n# All rights reserved.\n#\n# This source code is licensed under the l"
},
{
"path": "audiocraft/quantization/core_vq.py",
"chars": 14562,
"preview": "# Copyright (c) Meta Platforms, Inc. and affiliates.\n# All rights reserved.\n#\n# This source code is licensed under the l"
},
{
"path": "audiocraft/quantization/vq.py",
"chars": 4654,
"preview": "# Copyright (c) Meta Platforms, Inc. and affiliates.\n# All rights reserved.\n#\n# This source code is licensed under the l"
},
{
"path": "audiocraft/solvers/__init__.py",
"chars": 574,
"preview": "# Copyright (c) Meta Platforms, Inc. and affiliates.\n# All rights reserved.\n#\n# This source code is licensed under the l"
},
{
"path": "audiocraft/solvers/audiogen.py",
"chars": 655,
"preview": "# Copyright (c) Meta Platforms, Inc. and affiliates.\n# All rights reserved.\n#\n# This source code is licensed under the l"
},
{
"path": "audiocraft/solvers/base.py",
"chars": 31355,
"preview": "# Copyright (c) Meta Platforms, Inc. and affiliates.\n# All rights reserved.\n#\n# This source code is licensed under the l"
},
{
"path": "audiocraft/solvers/builders.py",
"chars": 14518,
"preview": "# Copyright (c) Meta Platforms, Inc. and affiliates.\n# All rights reserved.\n#\n# This source code is licensed under the l"
},
{
"path": "audiocraft/solvers/compression.py",
"chars": 14774,
"preview": "# Copyright (c) Meta Platforms, Inc. and affiliates.\n# All rights reserved.\n#\n# This source code is licensed under the l"
},
{
"path": "audiocraft/solvers/diffusion.py",
"chars": 11336,
"preview": "# Copyright (c) Meta Platforms, Inc. and affiliates.\n# All rights reserved.\n#\n# This source code is licensed under the l"
},
{
"path": "audiocraft/solvers/jasco.py",
"chars": 12449,
"preview": "# Copyright (c) Meta Platforms, Inc. and affiliates.\n# All rights reserved.\n#\n# This source code is licensed under the l"
},
{
"path": "audiocraft/solvers/magnet.py",
"chars": 12387,
"preview": "# Copyright (c) Meta Platforms, Inc. and affiliates.\n# All rights reserved.\n#\n# This source code is licensed under the l"
},
{
"path": "audiocraft/solvers/musicgen.py",
"chars": 37651,
"preview": "# Copyright (c) Meta Platforms, Inc. and affiliates.\n# All rights reserved.\n#\n# This source code is licensed under the l"
},
{
"path": "audiocraft/solvers/watermark.py",
"chars": 28492,
"preview": "# Copyright (c) Meta Platforms, Inc. and affiliates.\n# All rights reserved.\n#\n# This source code is licensed under the l"
},
{
"path": "audiocraft/train.py",
"chars": 6724,
"preview": "# Copyright (c) Meta Platforms, Inc. and affiliates.\n# All rights reserved.\n#\n# This source code is licensed under the l"
},
{
"path": "audiocraft/utils/__init__.py",
"chars": 215,
"preview": "# Copyright (c) Meta Platforms, Inc. and affiliates.\n# All rights reserved.\n#\n# This source code is licensed under the l"
},
{
"path": "audiocraft/utils/audio_effects.py",
"chars": 17113,
"preview": "# Copyright (c) Meta Platforms, Inc. and affiliates.\n# All rights reserved.\n#\n# This source code is licensed under the l"
},
{
"path": "audiocraft/utils/autocast.py",
"chars": 1377,
"preview": "# Copyright (c) Meta Platforms, Inc. and affiliates.\n# All rights reserved.\n#\n# This source code is licensed under the l"
},
{
"path": "audiocraft/utils/best_state.py",
"chars": 3694,
"preview": "# Copyright (c) Meta Platforms, Inc. and affiliates.\n# All rights reserved.\n#\n# This source code is licensed under the l"
},
{
"path": "audiocraft/utils/cache.py",
"chars": 14356,
"preview": "# Copyright (c) Meta Platforms, Inc. and affiliates.\n# All rights reserved.\n#\n# This source code is licensed under the l"
},
{
"path": "audiocraft/utils/checkpoint.py",
"chars": 6129,
"preview": "# Copyright (c) Meta Platforms, Inc. and affiliates.\n# All rights reserved.\n#\n# This source code is licensed under the l"
},
{
"path": "audiocraft/utils/cluster.py",
"chars": 2044,
"preview": "# Copyright (c) Meta Platforms, Inc. and affiliates.\n# All rights reserved.\n#\n# This source code is licensed under the l"
},
{
"path": "audiocraft/utils/deadlock.py",
"chars": 1710,
"preview": "# Copyright (c) Meta Platforms, Inc. and affiliates.\n# All rights reserved.\n#\n# This source code is licensed under the l"
},
{
"path": "audiocraft/utils/export.py",
"chars": 2677,
"preview": "# Copyright (c) Meta Platforms, Inc. and affiliates.\n# All rights reserved.\n#\n# This source code is licensed under the l"
},
{
"path": "audiocraft/utils/export_legacy.py",
"chars": 2403,
"preview": "# Copyright (c) Meta Platforms, Inc. and affiliates.\n# All rights reserved.\n#\n# This source code is licensed under the l"
},
{
"path": "audiocraft/utils/notebook.py",
"chars": 885,
"preview": "# Copyright (c) Meta Platforms, Inc. and affiliates.\n# All rights reserved.\n#\n# This source code is licensed under the l"
},
{
"path": "audiocraft/utils/profiler.py",
"chars": 1209,
"preview": "# Copyright (c) Meta Platforms, Inc. and affiliates.\n# All rights reserved.\n#\n# This source code is licensed under the l"
},
{
"path": "audiocraft/utils/samples/__init__.py",
"chars": 198,
"preview": "# Copyright (c) Meta Platforms, Inc. and affiliates.\n# All rights reserved.\n#\n# This source code is licensed under the l"
},
{
"path": "audiocraft/utils/samples/manager.py",
"chars": 19385,
"preview": "# Copyright (c) Meta Platforms, Inc. and affiliates.\n# All rights reserved.\n#\n# This source code is licensed under the l"
},
{
"path": "audiocraft/utils/utils.py",
"chars": 11556,
"preview": "# Copyright (c) Meta Platforms, Inc. and affiliates.\n# All rights reserved.\n#\n# This source code is licensed under the l"
},
{
"path": "config/augmentations/default.yaml",
"chars": 1828,
"preview": "# @package __global__\n\naudio_effects:\n speed:\n sample_rate: ${sample_rate}\n speed_range: [0.8, 1.2]\n updownresam"
},
{
"path": "config/conditioner/chords2music.yaml",
"chars": 696,
"preview": "# @package __global__\n\nclassifier_free_guidance:\n training_dropout: 0.3 # dropout of all conditions\n inference_coef: "
},
{
"path": "config/conditioner/chroma2music.yaml",
"chars": 817,
"preview": "# @package __global__\n\nclassifier_free_guidance:\n training_dropout: 0.2\n inference_coef: 3.0\n\nattribute_dropout:\n arg"
},
{
"path": "config/conditioner/clapemb2music.yaml",
"chars": 922,
"preview": "# @package __global__\n\nclassifier_free_guidance:\n training_dropout: 0.3\n inference_coef: 3.0\n\nattribute_dropout:\n tex"
},
{
"path": "config/conditioner/drums2music.yaml",
"chars": 765,
"preview": "# @package __global__\n\nclassifier_free_guidance:\n training_dropout: 0.3 # dropout of all conditions\n inference_coef: "
},
{
"path": "config/conditioner/jasco_chords_drums.yaml",
"chars": 929,
"preview": "# @package __global__\n\nclassifier_free_guidance:\n training_dropout: 0.3 # dropout of all conditions\n inference_coef: "
},
{
"path": "config/conditioner/jasco_chords_drums_melody.yaml",
"chars": 1135,
"preview": "# @package __global__\n\nclassifier_free_guidance:\n training_dropout: 0.2 # dropout of all conditions\n inference_coef: "
},
{
"path": "config/conditioner/none.yaml",
"chars": 239,
"preview": "# @package __global__\n\n# No conditioning\n\nclassifier_free_guidance:\n training_dropout: 0\n inference_coef: 1\n\nattribute"
},
{
"path": "config/conditioner/style2music.yaml",
"chars": 1228,
"preview": "# @package __global__\n\nclassifier_free_guidance:\n training_dropout: 0.1\n inference_coef: 3.0\n\nattribute_dropout:\n arg"
},
{
"path": "config/conditioner/text2music.yaml",
"chars": 496,
"preview": "# @package __global__\n\nclassifier_free_guidance:\n training_dropout: 0.3\n inference_coef: 3.0\n\nattribute_dropout: {}\n\nf"
},
{
"path": "config/conditioner/text2sound.yaml",
"chars": 411,
"preview": "# @package __global__\n\nclassifier_free_guidance:\n training_dropout: 0.1\n inference_coef: 3.0\n\nattribute_dropout: {}\n\nf"
},
{
"path": "config/config.yaml",
"chars": 2733,
"preview": "# WARNING: This is the base configuration file shared across ALL solvers in AudioCraft\n# Please don't update this file d"
},
{
"path": "config/dset/audio/audiocaps_16khz.yaml",
"chars": 281,
"preview": "# @package __global__\n\n# AudioCaps dataset\ndatasource:\n max_sample_rate: 16000\n max_channels: 1\n\n train: null # only"
},
{
"path": "config/dset/audio/default.yaml",
"chars": 138,
"preview": "# @package __global__\n\ndatasource:\n max_sample_rate: ???\n max_channels: ???\n\n train: ???\n valid: ???\n evaluate: ???"
},
{
"path": "config/dset/audio/example.yaml",
"chars": 169,
"preview": "# @package __global__\n\ndatasource:\n max_sample_rate: 44100\n max_channels: 2\n\n train: egs/example\n valid: egs/example"
},
{
"path": "config/dset/audio/musiccaps_32khz.yaml",
"chars": 358,
"preview": "# @package __global__\n\n# total samples obtained from MusicCaps = 5469\n# (out of 5521 due to AudioSet corrupted samples)\n"
},
{
"path": "config/dset/default.yaml",
"chars": 300,
"preview": "# @package __global__\n\n# WARNING: This is a base configuration file shared across ALL solvers in AudioCraft\n# Please don"
},
{
"path": "config/dset/internal/music_10k_32khz.yaml",
"chars": 338,
"preview": "# @package __global__\n\n# high quality music dataset with no artist overlap between splits\ndatasource:\n max_sample_rate:"
},
{
"path": "config/dset/internal/music_400k_32khz.yaml",
"chars": 275,
"preview": "# @package __global__\n\ndatasource:\n max_sample_rate: 32000\n max_channels: 1\n\n train: egs/music/music_400k_32khz/train"
},
{
"path": "config/dset/internal/sounds_16khz.yaml",
"chars": 344,
"preview": "# @package __global__\n\n# environmental sounds dataset compiling all datasets\n# with applied filters on tags\ndatasource:\n"
},
{
"path": "config/model/encodec/default.yaml",
"chars": 1112,
"preview": "# @package __global__\n\ncompression_model: encodec\n\nencodec:\n autoencoder: seanet\n quantizer: rvq\n sample_rate: ${samp"
},
{
"path": "config/model/encodec/encodec_base_causal.yaml",
"chars": 112,
"preview": "# @package __global__\n\ndefaults:\n - encodec/default\n\nencodec:\n causal: true\n\nrvq:\n n_q: 32\n q_dropout: true\n"
},
{
"path": "config/model/encodec/encodec_large_nq4_s320.yaml",
"chars": 161,
"preview": "# @package __global__\n\ndefaults:\n - encodec/default\n\nseanet:\n # default ratios are [8, 5, 4, 2]\n n_filters: 64\n\nrvq:\n"
},
{
"path": "config/model/encodec/encodec_large_nq4_s640.yaml",
"chars": 148,
"preview": "# @package __global__\n\ndefaults:\n - encodec/default\n\nseanet:\n ratios: [8, 5, 4, 4]\n n_filters: 64\n\nrvq:\n bins: 2048\n"
},
{
"path": "config/model/lm/audiogen_lm.yaml",
"chars": 738,
"preview": "# @package __global__\n\ndefaults:\n - lm/default\n - override /conditioner: text2sound\n - override /model/lm/model_scale"
},
{
"path": "config/model/lm/default.yaml",
"chars": 1992,
"preview": "# @package __global__\ndefaults:\n - _self_\n - /model/lm/model_scale: base # prefer this group to set model scale instea"
},
{
"path": "config/model/lm/model_scale/base.yaml",
"chars": 102,
"preview": "# @package __global__\n\n# overrides nothing because default is already transformer base (~ 60M params)\n"
},
{
"path": "config/model/lm/model_scale/large.yaml",
"chars": 126,
"preview": "# @package _global_\n\n# gpt2 inspired, even bigger (~3.3B params)\ntransformer_lm:\n dim: 2048\n num_heads: 32\n num_layer"
},
{
"path": "config/model/lm/model_scale/medium.yaml",
"chars": 109,
"preview": "# @package _global_\n\n# gpt2 like (~1.5B params)\ntransformer_lm:\n dim: 1536\n num_heads: 24\n num_layers: 48\n"
},
{
"path": "config/model/lm/model_scale/small.yaml",
"chars": 97,
"preview": "# @package _global_\n\n# 300M Param.\n\ntransformer_lm:\n dim: 1024\n num_heads: 16\n num_layers: 24\n"
},
{
"path": "config/model/lm/model_scale/xsmall.yaml",
"chars": 181,
"preview": "# @package _global_\n# just used for debugging or when we just want to populate the cache\n# and do not care about trainin"
},
{
"path": "config/model/lm/musicgen_lm.yaml",
"chars": 738,
"preview": "# @package __global__\n\ndefaults:\n - lm/default\n - override /conditioner: text2music\n - override /model/lm/model_scale"
},
{
"path": "config/model/none.yaml",
"chars": 157,
"preview": "# @package __global__\n\n# This file exist so that model is recognized as a config group\n# by Hydra, and Dora. A bit weird"
},
{
"path": "config/model/score/basic.yaml",
"chars": 269,
"preview": "# @package _global_\n\ndiffusion_unet:\n hidden: 48\n depth: 4\n res_blocks: 1\n norm_groups: 4\n kernel: 8\n stride: 4\n "
},
{
"path": "config/model/watermark/default.yaml",
"chars": 868,
"preview": "# @package __global__\n\naudioseal:\n autoencoder: seanet\n sample_rate: 16000\n channels: 1\n nbits: 16\n\nseanet:\n dimens"
},
{
"path": "config/solver/audiogen/audiogen_base_16khz.yaml",
"chars": 1772,
"preview": "# @package __global__\n\n# This is the training loop solver\n# for the base AudioGen model (text-to-sound)\n# on monophonic "
},
{
"path": "config/solver/audiogen/debug.yaml",
"chars": 972,
"preview": "# @package __global__\n\n# This is a minimal debugging configuration\n# for MusicGen training solver\ndefaults:\n - audiogen"
},
{
"path": "config/solver/audiogen/default.yaml",
"chars": 769,
"preview": "# @package __global__\n\ndefaults:\n - /solver/musicgen/default\n - _self_\n - /solver/audiogen/evaluation: none\n - overr"
},
{
"path": "config/solver/audiogen/evaluation/none.yaml",
"chars": 67,
"preview": "# @package __global__\n\ndataset:\n evaluate:\n num_samples: 10000\n"
},
{
"path": "config/solver/audiogen/evaluation/objective_eval.yaml",
"chars": 725,
"preview": "# @package __global__\n\n# Setup for execute only on audiocaps for audio generation\n# evaluation with objective metrics\n# "
},
{
"path": "config/solver/compression/debug.yaml",
"chars": 812,
"preview": "# @package __global__\n\ndefaults:\n - compression/default\n - /model: encodec/encodec_base_causal\n - override /dset: aud"
},
{
"path": "config/solver/compression/default.yaml",
"chars": 2955,
"preview": "# @package __global__\n\ndefaults:\n - ../default\n - override /dset: audio/default\n - _self_\n\nsolver: compression\nsample"
},
{
"path": "config/solver/compression/encodec_audiogen_16khz.yaml",
"chars": 177,
"preview": "# @package __global__\n\ndefaults:\n - compression/default\n - /model: encodec/encodec_large_nq4_s320\n - override /dset: "
},
{
"path": "config/solver/compression/encodec_base_24khz.yaml",
"chars": 174,
"preview": "# @package __global__\n\ndefaults:\n - compression/default\n - /model: encodec/encodec_base_causal\n - override /dset: aud"
},
{
"path": "config/solver/compression/encodec_musicgen_32khz.yaml",
"chars": 177,
"preview": "# @package __global__\n\ndefaults:\n - compression/default\n - /model: encodec/encodec_large_nq4_s640\n - override /dset: "
},
{
"path": "config/solver/default.yaml",
"chars": 2534,
"preview": "# @package __global__\n\n# WARNING: This is a base configuration file shared across ALL solvers in AudioCraft\n# Please don"
},
{
"path": "config/solver/diffusion/debug.yaml",
"chars": 1677,
"preview": "# @package __global__\n\ndefaults:\n - /solver/default\n - /model: score/basic\n - override /dset: audio/default\n - _self"
},
{
"path": "config/solver/diffusion/default.yaml",
"chars": 1690,
"preview": "# @package __global__\n\ndefaults:\n - /solver/default\n - /model: score/basic\n - override /dset: audio/default\n - _self"
},
{
"path": "config/solver/diffusion/encodec_24khz.yaml",
"chars": 197,
"preview": "# @package __global__\n\ndefaults:\n - diffusion/default\n - _self_\n\n\nsample_rate: 24000\nchannels: 1\ncompression_model_che"
},
{
"path": "config/solver/jasco/chords.yaml",
"chars": 1972,
"preview": "# @package __global__\n\n# This is the training loop solver\n# for the base MusicGen model (text-to-music)\n# on monophonic "
},
{
"path": "config/solver/jasco/chords_drums.yaml",
"chars": 2203,
"preview": "# @package __global__\n\n# This is the training loop solver\n# for the base MusicGen model (text-to-music)\n# on monophonic "
},
{
"path": "config/solver/jasco/chords_drums_melody.yaml",
"chars": 2518,
"preview": "# @package __global__\n\n# This is the training loop solver\n# for the base MusicGen model (text-to-music)\n# on monophonic "
}
]
// ... and 96 more files (download for full content)
About this extraction
This page contains the full source code of the facebookresearch/audiocraft GitHub repository, extracted and formatted as plain text for AI agents and large language models (LLMs). The extraction includes 296 files (5.5 MB), approximately 1.5M tokens, and a symbol index with 1517 extracted functions, classes, methods, constants, and types. Use this with OpenClaw, Claude, ChatGPT, Cursor, Windsurf, or any other AI tool that accepts text input. You can copy the full output to your clipboard or download it as a .txt file.
Extracted by GitExtract — free GitHub repo to text converter for AI. Built by Nikandr Surkov.