main c9ff851e7b50 cached
227 files
1.1 MB
268.6k tokens
1239 symbols
1 requests
Download .txt
Showing preview only (1,174K chars total). Download the full file or copy to clipboard to get everything.
Repository: GrandaddyShmax/audiocraft_plus
Branch: main
Commit: c9ff851e7b50
Files: 227
Total size: 1.1 MB

Directory structure:
gitextract_7_a5iyu7/

├── .github/
│   └── actions/
│       └── audiocraft_build/
│           └── action.yml
├── .gitignore
├── CHANGELOG.md
├── CODE_OF_CONDUCT.md
├── CONTRIBUTING.md
├── Dockerfile
├── LICENSE
├── LICENSE_weights
├── MANIFEST.in
├── Makefile
├── README.md
├── app.py
├── audiocraft/
│   ├── __init__.py
│   ├── adversarial/
│   │   ├── __init__.py
│   │   ├── discriminators/
│   │   │   ├── __init__.py
│   │   │   ├── base.py
│   │   │   ├── mpd.py
│   │   │   ├── msd.py
│   │   │   └── msstftd.py
│   │   └── losses.py
│   ├── data/
│   │   ├── __init__.py
│   │   ├── audio.py
│   │   ├── audio_dataset.py
│   │   ├── audio_utils.py
│   │   ├── info_audio_dataset.py
│   │   ├── music_dataset.py
│   │   ├── sound_dataset.py
│   │   └── zip.py
│   ├── environment.py
│   ├── grids/
│   │   ├── __init__.py
│   │   ├── _base_explorers.py
│   │   ├── audiogen/
│   │   │   ├── __init__.py
│   │   │   ├── audiogen_base_16khz.py
│   │   │   └── audiogen_pretrained_16khz_eval.py
│   │   ├── compression/
│   │   │   ├── __init__.py
│   │   │   ├── _explorers.py
│   │   │   ├── debug.py
│   │   │   ├── encodec_audiogen_16khz.py
│   │   │   ├── encodec_base_24khz.py
│   │   │   └── encodec_musicgen_32khz.py
│   │   ├── diffusion/
│   │   │   ├── 4_bands_base_32khz.py
│   │   │   ├── __init__.py
│   │   │   └── _explorers.py
│   │   └── musicgen/
│   │       ├── __init__.py
│   │       ├── _explorers.py
│   │       ├── musicgen_base_32khz.py
│   │       ├── musicgen_base_cached_32khz.py
│   │       ├── musicgen_clapemb_32khz.py
│   │       ├── musicgen_melody_32khz.py
│   │       └── musicgen_pretrained_32khz_eval.py
│   ├── losses/
│   │   ├── __init__.py
│   │   ├── balancer.py
│   │   ├── sisnr.py
│   │   ├── specloss.py
│   │   └── stftloss.py
│   ├── metrics/
│   │   ├── __init__.py
│   │   ├── chroma_cosinesim.py
│   │   ├── clap_consistency.py
│   │   ├── fad.py
│   │   ├── kld.py
│   │   ├── rvm.py
│   │   └── visqol.py
│   ├── models/
│   │   ├── __init__.py
│   │   ├── audiogen.py
│   │   ├── builders.py
│   │   ├── encodec.py
│   │   ├── lm.py
│   │   ├── loaders.py
│   │   ├── multibanddiffusion.py
│   │   ├── musicgen.py
│   │   └── unet.py
│   ├── modules/
│   │   ├── __init__.py
│   │   ├── activations.py
│   │   ├── chroma.py
│   │   ├── codebooks_patterns.py
│   │   ├── conditioners.py
│   │   ├── conv.py
│   │   ├── diffusion_schedule.py
│   │   ├── lstm.py
│   │   ├── rope.py
│   │   ├── seanet.py
│   │   ├── streaming.py
│   │   └── transformer.py
│   ├── optim/
│   │   ├── __init__.py
│   │   ├── cosine_lr_scheduler.py
│   │   ├── dadam.py
│   │   ├── ema.py
│   │   ├── fsdp.py
│   │   ├── inverse_sqrt_lr_scheduler.py
│   │   ├── linear_warmup_lr_scheduler.py
│   │   └── polynomial_decay_lr_scheduler.py
│   ├── py.typed
│   ├── quantization/
│   │   ├── __init__.py
│   │   ├── base.py
│   │   ├── core_vq.py
│   │   └── vq.py
│   ├── solvers/
│   │   ├── __init__.py
│   │   ├── audiogen.py
│   │   ├── base.py
│   │   ├── builders.py
│   │   ├── compression.py
│   │   ├── diffusion.py
│   │   └── musicgen.py
│   ├── train.py
│   └── utils/
│       ├── __init__.py
│       ├── autocast.py
│       ├── best_state.py
│       ├── cache.py
│       ├── checkpoint.py
│       ├── cluster.py
│       ├── deadlock.py
│       ├── export.py
│       ├── export_legacy.py
│       ├── notebook.py
│       ├── profiler.py
│       ├── samples/
│       │   ├── __init__.py
│       │   └── manager.py
│       ├── ui.py
│       └── utils.py
├── config/
│   ├── conditioner/
│   │   ├── chroma2music.yaml
│   │   ├── clapemb2music.yaml
│   │   ├── none.yaml
│   │   ├── text2music.yaml
│   │   └── text2sound.yaml
│   ├── config.yaml
│   ├── dset/
│   │   ├── audio/
│   │   │   ├── audiocaps_16khz.yaml
│   │   │   ├── default.yaml
│   │   │   ├── example.yaml
│   │   │   └── musiccaps_32khz.yaml
│   │   ├── default.yaml
│   │   └── internal/
│   │       ├── music_10k_32khz.yaml
│   │       ├── music_400k_32khz.yaml
│   │       └── sounds_16khz.yaml
│   ├── model/
│   │   ├── encodec/
│   │   │   ├── default.yaml
│   │   │   ├── encodec_base_causal.yaml
│   │   │   ├── encodec_large_nq4_s320.yaml
│   │   │   └── encodec_large_nq4_s640.yaml
│   │   ├── lm/
│   │   │   ├── audiogen_lm.yaml
│   │   │   ├── default.yaml
│   │   │   ├── model_scale/
│   │   │   │   ├── base.yaml
│   │   │   │   ├── large.yaml
│   │   │   │   ├── medium.yaml
│   │   │   │   ├── small.yaml
│   │   │   │   └── xsmall.yaml
│   │   │   └── musicgen_lm.yaml
│   │   ├── none.yaml
│   │   └── score/
│   │       └── basic.yaml
│   ├── solver/
│   │   ├── audiogen/
│   │   │   ├── audiogen_base_16khz.yaml
│   │   │   ├── debug.yaml
│   │   │   ├── default.yaml
│   │   │   └── evaluation/
│   │   │       ├── none.yaml
│   │   │       └── objective_eval.yaml
│   │   ├── compression/
│   │   │   ├── debug.yaml
│   │   │   ├── default.yaml
│   │   │   ├── encodec_audiogen_16khz.yaml
│   │   │   ├── encodec_base_24khz.yaml
│   │   │   └── encodec_musicgen_32khz.yaml
│   │   ├── default.yaml
│   │   ├── diffusion/
│   │   │   ├── debug.yaml
│   │   │   ├── default.yaml
│   │   │   └── encodec_24khz.yaml
│   │   └── musicgen/
│   │       ├── debug.yaml
│   │       ├── default.yaml
│   │       ├── evaluation/
│   │       │   ├── none.yaml
│   │       │   └── objective_eval.yaml
│   │       ├── musicgen_base_32khz.yaml
│   │       └── musicgen_melody_32khz.yaml
│   └── teams/
│       ├── default.yaml
│       └── labs.yaml
├── dataset/
│   └── example/
│       ├── electro_1.json
│       └── electro_2.json
├── demos/
│   ├── audiogen_demo.ipynb
│   ├── musicgen_app.py
│   └── musicgen_demo.ipynb
├── dockerignore
├── docs/
│   ├── AUDIOGEN.md
│   ├── CONDITIONING.md
│   ├── DATASETS.md
│   ├── ENCODEC.md
│   ├── MBD.md
│   ├── METRICS.md
│   ├── MUSICGEN.md
│   └── TRAINING.md
├── egs/
│   └── example/
│       └── data.jsonl
├── model_cards/
│   ├── AUDIOGEN_MODEL_CARD.md
│   └── MUSICGEN_MODEL_CARD.md
├── models/
│   └── Put your models here.txt
├── mypy.ini
├── requirements.txt
├── scripts/
│   ├── __init__.py
│   ├── mos.py
│   ├── resample_dataset.py
│   ├── static/
│   │   └── style.css
│   └── templates/
│       ├── base.html
│       ├── index.html
│       ├── login.html
│       ├── results.html
│       └── survey.html
├── setup.cfg
├── setup.py
└── tests/
    ├── __init__.py
    ├── adversarial/
    │   ├── __init__.py
    │   ├── test_discriminators.py
    │   └── test_losses.py
    ├── common_utils/
    │   ├── __init__.py
    │   ├── temp_utils.py
    │   └── wav_utils.py
    ├── data/
    │   ├── __init__.py
    │   ├── test_audio.py
    │   ├── test_audio_dataset.py
    │   └── test_audio_utils.py
    ├── losses/
    │   ├── __init__.py
    │   └── test_losses.py
    ├── models/
    │   ├── test_audiogen.py
    │   ├── test_encodec_model.py
    │   ├── test_multibanddiffusion.py
    │   └── test_musicgen.py
    ├── modules/
    │   ├── __init__.py
    │   ├── test_activations.py
    │   ├── test_codebooks_patterns.py
    │   ├── test_conv.py
    │   ├── test_lstm.py
    │   ├── test_rope.py
    │   ├── test_seanet.py
    │   └── test_transformer.py
    ├── quantization/
    │   └── test_vq.py
    └── utils/
        └── __init__.py

================================================
FILE CONTENTS
================================================

================================================
FILE: .github/actions/audiocraft_build/action.yml
================================================
name: audiocraft_build
description: 'Build audiocraft env.'
runs:
  using: "composite"
  steps:
  - uses: actions/setup-python@v2
    with:
      python-version: 3.8
  - uses: actions/cache@v2
    id: cache
    with:
      path: env
      key: audiocraft_env-${{ hashFiles('**/requirements.txt') }}

  - if: ${{ steps.cache.outputs.cache-hit != 'true' }}
    name: Install dependencies
    shell: bash
    run: |
      sudo apt-get update
      sudo apt-get install libsndfile1-dev ffmpeg
      python3 -m venv env
      .  env/bin/activate
      python -m pip install --upgrade pip
      pip install -e '.[dev]'
  - name: System Dependencies
    shell: bash
    run: |
      sudo apt-get update
      sudo apt-get install libsndfile1-dev ffmpeg


================================================
FILE: .gitignore
================================================
# Byte-compiled / optimized / DLL files
__pycache__
*.py[cod]
*$py.class

# C extensions
*.so

# macOS dir files
.DS_Store

# Distribution / packaging
.Python
env/
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
*.egg-info/
.installed.cfg
*.egg
.ipynb_checkpoints

# Tests and linter
.pytest_cache/
.mypy_cache/
.coverage

# docs
/api_docs

# dotenv
.env
.envrc

# virtualenv
.venv
venv/
ENV/

# egs with manifest files
egs/*
!egs/example
# local datasets
dataset/*
!dataset/example

# personal notebooks & scripts
*/local_scripts
*/notes
.vscode/
/notebooks
/local_scripts
/notes
/cache

================================================
FILE: CHANGELOG.md
================================================
# Changelog

All notable changes to this project will be documented in this file.

The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/).

## [1.0.0] - 2023-08-02

Major revision: added training code for EnCodec, AudioGen, MusicGen, and MultiBandDiffusion.
Added pretrained models for AudioGen and MultiBandDiffusion.

## [0.0.2] - 2023-08-01

Improved the demo and fixed top-p sampling (thanks @jnordberg).

Added a tanh compressor on the output to avoid clipping with some styles (especially piano).
The conditioning is now repeated periodically if it is too short.

More options when launching the Gradio app locally (thanks @ashleykleynhans).

Testing out PyTorch 2.0 memory-efficient attention.

Added extended generation (infinite length) by slowly moving the windows.
Note that other implementations exist: https://github.com/camenduru/MusicGen-colab.

## [0.0.1] - 2023-06-09

Initial release, with model evaluation only.


================================================
FILE: CODE_OF_CONDUCT.md
================================================
# Code of Conduct

## Our Pledge

In the interest of fostering an open and welcoming environment, we as
contributors and maintainers pledge to make participation in our project and
our community a harassment-free experience for everyone, regardless of age, body
size, disability, ethnicity, sex characteristics, gender identity and expression,
level of experience, education, socio-economic status, nationality, personal
appearance, race, religion, or sexual identity and orientation.

## Our Standards

Examples of behavior that contributes to creating a positive environment
include:

* Using welcoming and inclusive language
* Being respectful of differing viewpoints and experiences
* Gracefully accepting constructive criticism
* Focusing on what is best for the community
* Showing empathy towards other community members

Examples of unacceptable behavior by participants include:

* The use of sexualized language or imagery and unwelcome sexual attention or
advances
* Trolling, insulting/derogatory comments, and personal or political attacks
* Public or private harassment
* Publishing others' private information, such as a physical or electronic
address, without explicit permission
* Other conduct which could reasonably be considered inappropriate in a
professional setting

## Our Responsibilities

Project maintainers are responsible for clarifying the standards of acceptable
behavior and are expected to take appropriate and fair corrective action in
response to any instances of unacceptable behavior.

Project maintainers have the right and responsibility to remove, edit, or
reject comments, commits, code, wiki edits, issues, and other contributions
that are not aligned to this Code of Conduct, or to ban temporarily or
permanently any contributor for other behaviors that they deem inappropriate,
threatening, offensive, or harmful.

## Scope

This Code of Conduct applies within all project spaces, and it also applies when
an individual is representing the project or its community in public spaces.
Examples of representing a project or community include using an official
project e-mail address, posting via an official social media account, or acting
as an appointed representative at an online or offline event. Representation of
a project may be further defined and clarified by project maintainers.

This Code of Conduct also applies outside the project spaces when there is a
reasonable belief that an individual's behavior may have a negative impact on
the project or its community.

## Enforcement

Instances of abusive, harassing, or otherwise unacceptable behavior may be
reported by contacting the project team at <opensource-conduct@fb.com>. All
complaints will be reviewed and investigated and will result in a response that
is deemed necessary and appropriate to the circumstances. The project team is
obligated to maintain confidentiality with regard to the reporter of an incident.
Further details of specific enforcement policies may be posted separately.

Project maintainers who do not follow or enforce the Code of Conduct in good
faith may face temporary or permanent repercussions as determined by other
members of the project's leadership.

## Attribution

This Code of Conduct is adapted from the [Contributor Covenant][homepage], version 1.4,
available at https://www.contributor-covenant.org/version/1/4/code-of-conduct.html

[homepage]: https://www.contributor-covenant.org

For answers to common questions about this code of conduct, see
https://www.contributor-covenant.org/faq


================================================
FILE: CONTRIBUTING.md
================================================
# Contributing to AudioCraft

We want to make contributing to this project as easy and transparent as
possible.

## Pull Requests

AudioCraft is the implementation of a research paper.
Therefore, we do not plan on accepting many pull requests for new features.
We certainly welcome them for bug fixes.

1. Fork the repo and create your branch from `main`.
2. If you've added code that should be tested, add tests.
3. If you've changed APIs, update the documentation.
4. Ensure the test suite passes.
5. Make sure your code lints.
6. If you haven't already, complete the Contributor License Agreement ("CLA").

## Contributor License Agreement ("CLA")
In order to accept your pull request, we need you to submit a CLA. You only need
to do this once to work on any of Meta's open source projects.

Complete your CLA here: <https://code.facebook.com/cla>

## Issues
We use GitHub issues to track public bugs. Please ensure your description is
clear and has sufficient instructions to be able to reproduce the issue.

Meta has a [bounty program](https://www.facebook.com/whitehat/) for the safe
disclosure of security bugs. In those cases, please go through the process
outlined on that page and do not file a public issue.

## License
By contributing to AudioCraft, you agree that your contributions will be licensed
under the LICENSE file in the root directory of this source tree.


================================================
FILE: Dockerfile
================================================
FROM nvidia/cuda:11.8.0-base-ubuntu22.04

ENV DEBIAN_FRONTEND=noninteractive \
    PYTHONUNBUFFERED=1 \
    PYTHONIOENCODING=UTF-8
RUN --mount=type=cache,target=/var/cache/apt --mount=type=cache,target=/var/lib/apt apt update &&\
    apt install -y \
    wget \
    git \
    pkg-config \
    python3 \
    python3-pip \
    python-is-python3 \
    ffmpeg \
    libnvrtc11.2 \
    libtcmalloc-minimal4

RUN useradd -m -u 1000 ac
RUN --mount=type=cache,target=/root/.cache python -m pip install --upgrade pip wheel
ENV TORCH_COMMAND="pip install torch==2.0.1+cu118 torchaudio --extra-index-url https://download.pytorch.org/whl/cu118"
RUN --mount=type=cache,target=/root/.cache python -m $TORCH_COMMAND
RUN ln -s /usr/lib/x86_64-linux-gnu/libnvrtc.so.11.2 /usr/lib/x86_64-linux-gnu/libnvrtc.so
USER 1000
RUN mkdir ~/.cache
RUN --mount=type=cache,target=/home/ac/.cache --mount=source=.,target=/home/ac/audiocraft python -m pip install -r /home/ac/audiocraft/requirements.txt
WORKDIR /home/ac/audiocraft

================================================
FILE: LICENSE
================================================
MIT License

Copyright (c) Meta Platforms, Inc. and affiliates.

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.


================================================
FILE: LICENSE_weights
================================================
Attribution-NonCommercial 4.0 International

=======================================================================

Creative Commons Corporation ("Creative Commons") is not a law firm and
does not provide legal services or legal advice. Distribution of
Creative Commons public licenses does not create a lawyer-client or
other relationship. Creative Commons makes its licenses and related
information available on an "as-is" basis. Creative Commons gives no
warranties regarding its licenses, any material licensed under their
terms and conditions, or any related information. Creative Commons
disclaims all liability for damages resulting from their use to the
fullest extent possible.

Using Creative Commons Public Licenses

Creative Commons public licenses provide a standard set of terms and
conditions that creators and other rights holders may use to share
original works of authorship and other material subject to copyright
and certain other rights specified in the public license below. The
following considerations are for informational purposes only, are not
exhaustive, and do not form part of our licenses.

     Considerations for licensors: Our public licenses are
     intended for use by those authorized to give the public
     permission to use material in ways otherwise restricted by
     copyright and certain other rights. Our licenses are
     irrevocable. Licensors should read and understand the terms
     and conditions of the license they choose before applying it.
     Licensors should also secure all rights necessary before
     applying our licenses so that the public can reuse the
     material as expected. Licensors should clearly mark any
     material not subject to the license. This includes other CC-
     licensed material, or material used under an exception or
     limitation to copyright. More considerations for licensors:
	wiki.creativecommons.org/Considerations_for_licensors

     Considerations for the public: By using one of our public
     licenses, a licensor grants the public permission to use the
     licensed material under specified terms and conditions. If
     the licensor's permission is not necessary for any reason--for
     example, because of any applicable exception or limitation to
     copyright--then that use is not regulated by the license. Our
     licenses grant only permissions under copyright and certain
     other rights that a licensor has authority to grant. Use of
     the licensed material may still be restricted for other
     reasons, including because others have copyright or other
     rights in the material. A licensor may make special requests,
     such as asking that all changes be marked or described.
     Although not required by our licenses, you are encouraged to
     respect those requests where reasonable. More_considerations
     for the public: 
	wiki.creativecommons.org/Considerations_for_licensees

=======================================================================

Creative Commons Attribution-NonCommercial 4.0 International Public
License

By exercising the Licensed Rights (defined below), You accept and agree
to be bound by the terms and conditions of this Creative Commons
Attribution-NonCommercial 4.0 International Public License ("Public
License"). To the extent this Public License may be interpreted as a
contract, You are granted the Licensed Rights in consideration of Your
acceptance of these terms and conditions, and the Licensor grants You
such rights in consideration of benefits the Licensor receives from
making the Licensed Material available under these terms and
conditions.

Section 1 -- Definitions.

  a. Adapted Material means material subject to Copyright and Similar
     Rights that is derived from or based upon the Licensed Material
     and in which the Licensed Material is translated, altered,
     arranged, transformed, or otherwise modified in a manner requiring
     permission under the Copyright and Similar Rights held by the
     Licensor. For purposes of this Public License, where the Licensed
     Material is a musical work, performance, or sound recording,
     Adapted Material is always produced where the Licensed Material is
     synched in timed relation with a moving image.

  b. Adapter's License means the license You apply to Your Copyright
     and Similar Rights in Your contributions to Adapted Material in
     accordance with the terms and conditions of this Public License.

  c. Copyright and Similar Rights means copyright and/or similar rights
     closely related to copyright including, without limitation,
     performance, broadcast, sound recording, and Sui Generis Database
     Rights, without regard to how the rights are labeled or
     categorized. For purposes of this Public License, the rights
     specified in Section 2(b)(1)-(2) are not Copyright and Similar
     Rights.
  d. Effective Technological Measures means those measures that, in the
     absence of proper authority, may not be circumvented under laws
     fulfilling obligations under Article 11 of the WIPO Copyright
     Treaty adopted on December 20, 1996, and/or similar international
     agreements.

  e. Exceptions and Limitations means fair use, fair dealing, and/or
     any other exception or limitation to Copyright and Similar Rights
     that applies to Your use of the Licensed Material.

  f. Licensed Material means the artistic or literary work, database,
     or other material to which the Licensor applied this Public
     License.

  g. Licensed Rights means the rights granted to You subject to the
     terms and conditions of this Public License, which are limited to
     all Copyright and Similar Rights that apply to Your use of the
     Licensed Material and that the Licensor has authority to license.

  h. Licensor means the individual(s) or entity(ies) granting rights
     under this Public License.

  i. NonCommercial means not primarily intended for or directed towards
     commercial advantage or monetary compensation. For purposes of
     this Public License, the exchange of the Licensed Material for
     other material subject to Copyright and Similar Rights by digital
     file-sharing or similar means is NonCommercial provided there is
     no payment of monetary compensation in connection with the
     exchange.

  j. Share means to provide material to the public by any means or
     process that requires permission under the Licensed Rights, such
     as reproduction, public display, public performance, distribution,
     dissemination, communication, or importation, and to make material
     available to the public including in ways that members of the
     public may access the material from a place and at a time
     individually chosen by them.

  k. Sui Generis Database Rights means rights other than copyright
     resulting from Directive 96/9/EC of the European Parliament and of
     the Council of 11 March 1996 on the legal protection of databases,
     as amended and/or succeeded, as well as other essentially
     equivalent rights anywhere in the world.

  l. You means the individual or entity exercising the Licensed Rights
     under this Public License. Your has a corresponding meaning.

Section 2 -- Scope.

  a. License grant.

       1. Subject to the terms and conditions of this Public License,
          the Licensor hereby grants You a worldwide, royalty-free,
          non-sublicensable, non-exclusive, irrevocable license to
          exercise the Licensed Rights in the Licensed Material to:

            a. reproduce and Share the Licensed Material, in whole or
               in part, for NonCommercial purposes only; and

            b. produce, reproduce, and Share Adapted Material for
               NonCommercial purposes only.

       2. Exceptions and Limitations. For the avoidance of doubt, where
          Exceptions and Limitations apply to Your use, this Public
          License does not apply, and You do not need to comply with
          its terms and conditions.

       3. Term. The term of this Public License is specified in Section
          6(a).

       4. Media and formats; technical modifications allowed. The
          Licensor authorizes You to exercise the Licensed Rights in
          all media and formats whether now known or hereafter created,
          and to make technical modifications necessary to do so. The
          Licensor waives and/or agrees not to assert any right or
          authority to forbid You from making technical modifications
          necessary to exercise the Licensed Rights, including
          technical modifications necessary to circumvent Effective
          Technological Measures. For purposes of this Public License,
          simply making modifications authorized by this Section 2(a)
          (4) never produces Adapted Material.

       5. Downstream recipients.

            a. Offer from the Licensor -- Licensed Material. Every
               recipient of the Licensed Material automatically
               receives an offer from the Licensor to exercise the
               Licensed Rights under the terms and conditions of this
               Public License.

            b. No downstream restrictions. You may not offer or impose
               any additional or different terms or conditions on, or
               apply any Effective Technological Measures to, the
               Licensed Material if doing so restricts exercise of the
               Licensed Rights by any recipient of the Licensed
               Material.

       6. No endorsement. Nothing in this Public License constitutes or
          may be construed as permission to assert or imply that You
          are, or that Your use of the Licensed Material is, connected
          with, or sponsored, endorsed, or granted official status by,
          the Licensor or others designated to receive attribution as
          provided in Section 3(a)(1)(A)(i).

  b. Other rights.

       1. Moral rights, such as the right of integrity, are not
          licensed under this Public License, nor are publicity,
          privacy, and/or other similar personality rights; however, to
          the extent possible, the Licensor waives and/or agrees not to
          assert any such rights held by the Licensor to the limited
          extent necessary to allow You to exercise the Licensed
          Rights, but not otherwise.

       2. Patent and trademark rights are not licensed under this
          Public License.

       3. To the extent possible, the Licensor waives any right to
          collect royalties from You for the exercise of the Licensed
          Rights, whether directly or through a collecting society
          under any voluntary or waivable statutory or compulsory
          licensing scheme. In all other cases the Licensor expressly
          reserves any right to collect such royalties, including when
          the Licensed Material is used other than for NonCommercial
          purposes.

Section 3 -- License Conditions.

Your exercise of the Licensed Rights is expressly made subject to the
following conditions.

  a. Attribution.

       1. If You Share the Licensed Material (including in modified
          form), You must:

            a. retain the following if it is supplied by the Licensor
               with the Licensed Material:

                 i. identification of the creator(s) of the Licensed
                    Material and any others designated to receive
                    attribution, in any reasonable manner requested by
                    the Licensor (including by pseudonym if
                    designated);

                ii. a copyright notice;

               iii. a notice that refers to this Public License;

                iv. a notice that refers to the disclaimer of
                    warranties;

                 v. a URI or hyperlink to the Licensed Material to the
                    extent reasonably practicable;

            b. indicate if You modified the Licensed Material and
               retain an indication of any previous modifications; and

            c. indicate the Licensed Material is licensed under this
               Public License, and include the text of, or the URI or
               hyperlink to, this Public License.

       2. You may satisfy the conditions in Section 3(a)(1) in any
          reasonable manner based on the medium, means, and context in
          which You Share the Licensed Material. For example, it may be
          reasonable to satisfy the conditions by providing a URI or
          hyperlink to a resource that includes the required
          information.

       3. If requested by the Licensor, You must remove any of the
          information required by Section 3(a)(1)(A) to the extent
          reasonably practicable.

       4. If You Share Adapted Material You produce, the Adapter's
          License You apply must not prevent recipients of the Adapted
          Material from complying with this Public License.

Section 4 -- Sui Generis Database Rights.

Where the Licensed Rights include Sui Generis Database Rights that
apply to Your use of the Licensed Material:

  a. for the avoidance of doubt, Section 2(a)(1) grants You the right
     to extract, reuse, reproduce, and Share all or a substantial
     portion of the contents of the database for NonCommercial purposes
     only;

  b. if You include all or a substantial portion of the database
     contents in a database in which You have Sui Generis Database
     Rights, then the database in which You have Sui Generis Database
     Rights (but not its individual contents) is Adapted Material; and

  c. You must comply with the conditions in Section 3(a) if You Share
     all or a substantial portion of the contents of the database.

For the avoidance of doubt, this Section 4 supplements and does not
replace Your obligations under this Public License where the Licensed
Rights include other Copyright and Similar Rights.

Section 5 -- Disclaimer of Warranties and Limitation of Liability.

  a. UNLESS OTHERWISE SEPARATELY UNDERTAKEN BY THE LICENSOR, TO THE
     EXTENT POSSIBLE, THE LICENSOR OFFERS THE LICENSED MATERIAL AS-IS
     AND AS-AVAILABLE, AND MAKES NO REPRESENTATIONS OR WARRANTIES OF
     ANY KIND CONCERNING THE LICENSED MATERIAL, WHETHER EXPRESS,
     IMPLIED, STATUTORY, OR OTHER. THIS INCLUDES, WITHOUT LIMITATION,
     WARRANTIES OF TITLE, MERCHANTABILITY, FITNESS FOR A PARTICULAR
     PURPOSE, NON-INFRINGEMENT, ABSENCE OF LATENT OR OTHER DEFECTS,
     ACCURACY, OR THE PRESENCE OR ABSENCE OF ERRORS, WHETHER OR NOT
     KNOWN OR DISCOVERABLE. WHERE DISCLAIMERS OF WARRANTIES ARE NOT
     ALLOWED IN FULL OR IN PART, THIS DISCLAIMER MAY NOT APPLY TO YOU.

  b. TO THE EXTENT POSSIBLE, IN NO EVENT WILL THE LICENSOR BE LIABLE
     TO YOU ON ANY LEGAL THEORY (INCLUDING, WITHOUT LIMITATION,
     NEGLIGENCE) OR OTHERWISE FOR ANY DIRECT, SPECIAL, INDIRECT,
     INCIDENTAL, CONSEQUENTIAL, PUNITIVE, EXEMPLARY, OR OTHER LOSSES,
     COSTS, EXPENSES, OR DAMAGES ARISING OUT OF THIS PUBLIC LICENSE OR
     USE OF THE LICENSED MATERIAL, EVEN IF THE LICENSOR HAS BEEN
     ADVISED OF THE POSSIBILITY OF SUCH LOSSES, COSTS, EXPENSES, OR
     DAMAGES. WHERE A LIMITATION OF LIABILITY IS NOT ALLOWED IN FULL OR
     IN PART, THIS LIMITATION MAY NOT APPLY TO YOU.

  c. The disclaimer of warranties and limitation of liability provided
     above shall be interpreted in a manner that, to the extent
     possible, most closely approximates an absolute disclaimer and
     waiver of all liability.

Section 6 -- Term and Termination.

  a. This Public License applies for the term of the Copyright and
     Similar Rights licensed here. However, if You fail to comply with
     this Public License, then Your rights under this Public License
     terminate automatically.

  b. Where Your right to use the Licensed Material has terminated under
     Section 6(a), it reinstates:

       1. automatically as of the date the violation is cured, provided
          it is cured within 30 days of Your discovery of the
          violation; or

       2. upon express reinstatement by the Licensor.

     For the avoidance of doubt, this Section 6(b) does not affect any
     right the Licensor may have to seek remedies for Your violations
     of this Public License.

  c. For the avoidance of doubt, the Licensor may also offer the
     Licensed Material under separate terms or conditions or stop
     distributing the Licensed Material at any time; however, doing so
     will not terminate this Public License.

  d. Sections 1, 5, 6, 7, and 8 survive termination of this Public
     License.

Section 7 -- Other Terms and Conditions.

  a. The Licensor shall not be bound by any additional or different
     terms or conditions communicated by You unless expressly agreed.

  b. Any arrangements, understandings, or agreements regarding the
     Licensed Material not stated herein are separate from and
     independent of the terms and conditions of this Public License.

Section 8 -- Interpretation.

  a. For the avoidance of doubt, this Public License does not, and
     shall not be interpreted to, reduce, limit, restrict, or impose
     conditions on any use of the Licensed Material that could lawfully
     be made without permission under this Public License.

  b. To the extent possible, if any provision of this Public License is
     deemed unenforceable, it shall be automatically reformed to the
     minimum extent necessary to make it enforceable. If the provision
     cannot be reformed, it shall be severed from this Public License
     without affecting the enforceability of the remaining terms and
     conditions.

  c. No term or condition of this Public License will be waived and no
     failure to comply consented to unless expressly agreed to by the
     Licensor.

  d. Nothing in this Public License constitutes or may be interpreted
     as a limitation upon, or waiver of, any privileges and immunities
     that apply to the Licensor or You, including from the legal
     processes of any jurisdiction or authority.

=======================================================================

Creative Commons is not a party to its public
licenses. Notwithstanding, Creative Commons may elect to apply one of
its public licenses to material it publishes and in those instances
will be considered the “Licensor.” The text of the Creative Commons
public licenses is dedicated to the public domain under the CC0 Public
Domain Dedication. Except for the limited purpose of indicating that
material is shared under a Creative Commons public license or as
otherwise permitted by the Creative Commons policies published at
creativecommons.org/policies, Creative Commons does not authorize the
use of the trademark "Creative Commons" or any other trademark or logo
of Creative Commons without its prior written consent including,
without limitation, in connection with any unauthorized modifications
to any of its public licenses or any other arrangements,
understandings, or agreements concerning use of licensed material. For
the avoidance of doubt, this paragraph does not form part of the
public licenses.

Creative Commons may be contacted at creativecommons.org.


================================================
FILE: MANIFEST.in
================================================
# Extra files bundled into the source distribution built by `python setup.py sdist`
# (the Makefile `dist` target): build helpers, licenses, docs, requirements,
# the PEP 561 typing marker, sample audio assets and all Hydra/YAML configs.
include Makefile
include LICENSE
include LICENSE_weights
include *.md
include *.ini
include requirements.txt
include audiocraft/py.typed
include assets/*.mp3
recursive-include conf *.yaml


================================================
FILE: Makefile
================================================
# Base command for the integration smoke tests: tiny CPU-only dora runs
# (1 epoch, 10 train/valid/eval samples, 16 kHz) writing into /tmp/magma_$(USER).
INTEG=AUDIOCRAFT_DORA_DIR="/tmp/magma_$(USER)" python3 -m dora -v run --clear device=cpu dataset.num_workers=0 optim.epochs=1 \
	dataset.train.num_samples=10 dataset.valid.num_samples=10 \
	dataset.evaluate.num_samples=10 dataset.generate.num_samples=2 sample_rate=16000 \
	logging.level=DEBUG
# The compression run saves its last checkpoint; the MusicGen/AudioGen runs below
# reuse it through the //sig/5091833e reference.
INTEG_COMPRESSION = $(INTEG) solver=compression/debug rvq.n_q=2 rvq.bins=48 checkpoint.save_last=true   # SIG is 5091833e
INTEG_MUSICGEN = $(INTEG) solver=musicgen/debug dset=audio/example compression_model_checkpoint=//sig/5091833e \
	transformer_lm.n_q=2 transformer_lm.card=48 transformer_lm.dim=16 checkpoint.save_last=false  # Using compression model from 5091833e
INTEG_AUDIOGEN = $(INTEG) solver=audiogen/debug dset=audio/example compression_model_checkpoint=//sig/5091833e \
	transformer_lm.n_q=2 transformer_lm.card=48 transformer_lm.dim=16 checkpoint.save_last=false  # Using compression model from 5091833e
INTEG_MBD = $(INTEG) solver=diffusion/debug dset=audio/example  \
	checkpoint.save_last=false  # Using compression model from 616d7b3c

default: linter tests

install:
	pip install -U pip
	pip install -U -e '.[dev]'

linter:
	flake8 audiocraft && mypy audiocraft
	flake8 tests && mypy tests

tests:
	coverage run -m pytest tests
	coverage report

# Order matters: the compression run must come first (see INTEG_COMPRESSION).
tests_integ:
	$(INTEG_COMPRESSION)
	$(INTEG_MBD)
	$(INTEG_MUSICGEN)
	$(INTEG_AUDIOGEN)


api_docs:
	pdoc3 --html -o api_docs -f audiocraft

dist:
	python setup.py sdist

# Every target above is a command, not a file it produces; declaring all of them
# phony keeps a same-named file/directory (e.g. `tests/`) from suppressing the recipe.
.PHONY: default install linter tests tests_integ api_docs dist


================================================
FILE: README.md
================================================
# AudioCraft Plus
![docs badge](https://github.com/facebookresearch/audiocraft/workflows/audiocraft_docs/badge.svg)
![linter badge](https://github.com/facebookresearch/audiocraft/workflows/audiocraft_linter/badge.svg)
![tests badge](https://github.com/facebookresearch/audiocraft/workflows/audiocraft_tests/badge.svg)

AudioCraft is a PyTorch library for deep learning research on audio generation. AudioCraft contains inference and training code
for two state-of-the-art AI generative models producing high-quality audio: AudioGen and MusicGen.

<a target="_blank" href="https://colab.research.google.com/github/camenduru/MusicGen-colab/blob/main/MusicGen_ClownOfMadness_plus_colab.ipynb">
  <img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/>
</a>
<a target="_blank" href="https://huggingface.co/spaces/facebook/MusicGen">
  <img src="https://huggingface.co/datasets/huggingface/badges/raw/main/open-in-hf-spaces-sm.svg" alt="Open in Hugging Face"/>
</a>
<br>
<br>

![image](https://github.com/GrandaddyShmax/audiocraft_plus/assets/52707645/043fc037-54a9-48c4-bb5c-bf9b7440d146)


## Features
AudioCraft Plus is an all-in-one WebUI for the original AudioCraft, adding many quality features on top.

- AudioGen Model
- Multiband Diffusion
- Custom Model Support
- Generation Metadata and Audio Info tab
- Mono to Stereo
- Multiprompt/Prompt Segmentation with Structure Prompts
- Video Output Customization
- Music Continuation

## Installation
If you are updating from the previous version of AudioCraft Plus, do the following steps in the AudioCraft Plus folder:
```shell
git pull
pip install transformers --upgrade
pip install  torchmetrics --upgrade
```

#### Otherwise: Clean Installation  
AudioCraft requires Python 3.9, PyTorch 2.0.0. To install AudioCraft, you can run the following:

```shell
# Best to make sure you have torch installed first, in particular before installing xformers.
# Don't run this if you already have PyTorch installed.
pip install 'torch>=2.0'
# Then proceed to one of the following
pip install -U audiocraft  # stable release
pip install -U git+https://git@github.com/GrandaddyShmax/audiocraft_plus#egg=audiocraft  # bleeding edge
pip install -e .  # or if you cloned the repo locally (mandatory if you want to train).
```

We also recommend having `ffmpeg` installed, either through your system or Anaconda:
```bash
sudo apt-get install ffmpeg
# Or if you are using Anaconda or Miniconda
conda install 'ffmpeg<5' -c  conda-forge
```

Installation video thanks to Pogs Cafe:  
[![Untitled](http://img.youtube.com/vi/WjGk4bcbUOI/0.jpg)](http://www.youtube.com/watch?v=WjGk4bcbUOI "Installing MusicGen+ Locally")


Additional installation guide by [radaevm](https://github.com/radaevm) can be found [HERE](https://github.com/GrandaddyShmax/audiocraft_plus/discussions/31)


## Models

At the moment, AudioCraft contains the training code and inference code for:
* [MusicGen](./docs/MUSICGEN.md): A state-of-the-art controllable text-to-music model.
* [AudioGen](./docs/AUDIOGEN.md): A state-of-the-art text-to-sound model.
* [EnCodec](./docs/ENCODEC.md): A state-of-the-art high fidelity neural audio codec.
* [Multi Band Diffusion](./docs/MBD.md): An EnCodec compatible decoder using diffusion.

## Training code

AudioCraft contains PyTorch components for deep learning research in audio and training pipelines for the developed models.
For a general introduction of AudioCraft design principles and instructions to develop your own training pipeline, refer to
the [AudioCraft training documentation](./docs/TRAINING.md).

For reproducing existing work and using the developed training pipelines, refer to the instructions for each specific model
that provides pointers to configuration, example grids and model/task-specific information and FAQ.


## API documentation

We provide some [API documentation](https://facebookresearch.github.io/audiocraft/api_docs/audiocraft/index.html) for AudioCraft.


## FAQ

#### Is the training code available?

Yes! We provide the training code for [EnCodec](./docs/ENCODEC.md), [MusicGen](./docs/MUSICGEN.md) and [Multi Band Diffusion](./docs/MBD.md).

#### Where are the models stored?

Hugging Face stores the models in a specific location, which can be overridden by setting the `AUDIOCRAFT_CACHE_DIR` environment variable.


## License
* The code in this repository is released under the MIT license as found in the [LICENSE file](LICENSE).
* The model weights in this repository are released under the CC-BY-NC 4.0 license as found in the [LICENSE_weights file](LICENSE_weights).


## Citation

For the general framework of AudioCraft, please cite the following.
```
@article{copet2023simple,
    title={Simple and Controllable Music Generation},
    author={Jade Copet and Felix Kreuk and Itai Gat and Tal Remez and David Kant and Gabriel Synnaeve and Yossi Adi and Alexandre Défossez},
    year={2023},
    journal={arXiv preprint arXiv:2306.05284},
}
```

When referring to a specific model, please cite as mentioned in the model-specific README, e.g.,
[./docs/MUSICGEN.md](./docs/MUSICGEN.md), [./docs/AUDIOGEN.md](./docs/AUDIOGEN.md), etc.


================================================
FILE: app.py
================================================
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.

# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

# Updated to account for UI changes from https://github.com/rkfg/audiocraft/blob/long/app.py
# also released under the MIT license.

import argparse
from concurrent.futures import ProcessPoolExecutor
import os
from pathlib import Path
import subprocess as sp
from tempfile import NamedTemporaryFile
import time
import warnings
import glob
import re
from PIL import Image
from pydub import AudioSegment
from datetime import datetime

import json
import shutil
import taglib
import torch
import torchaudio
import gradio as gr
import numpy as np
import typing as tp

from audiocraft.data.audio_utils import convert_audio
from audiocraft.data.audio import audio_write
from audiocraft.models import AudioGen, MusicGen, MultiBandDiffusion
from audiocraft.utils import ui
import random, string

# App version string. Note: several functions below (e.g. load_model and the
# metadata readers) use a local/parameter named `version` that shadows it.
version = "2.0.1"

# Gradio UI theme: lime accents, with hovered buttons using the primary_500 shade.
theme = gr.themes.Base(
    primary_hue="lime",
    secondary_hue="lime",
    neutral_hue="neutral",
).set(
    button_primary_background_fill_hover='*primary_500',
    button_primary_background_fill_hover_dark='*primary_500',
    button_secondary_background_fill_hover='*primary_500',
    button_secondary_background_fill_hover_dark='*primary_500'
)

MODEL = None  # Last used model
# Model cache keyed by model id; when None, load_model reloads from scratch each time.
MODELS = None
# NOTE(review): the two flags below are not read in this part of the file —
# presumably they control model unloading / CPU offload; confirm at their use sites.
UNLOAD_MODEL = False
MOVE_TO_CPU = False
# True when running inside the official "facebook/MusicGen" Hugging Face Space.
IS_BATCHED = "facebook/MusicGen" in os.environ.get('SPACE_ID', '')
print(IS_BATCHED)  # debug output left at import time
# NOTE(review): presumably the batch limits used when IS_BATCHED is set — confirm.
MAX_BATCH_SIZE = 12
BATCHED_DURATION = 15
# Set to True by interrupt(); presumably polled by the generation loop.
INTERRUPTING = False
# NOTE(review): presumably the lazily-loaded MultiBandDiffusion decoder — confirm.
MBD = None
# We have to wrap subprocess call to clean a bit the log when using gr.make_waveform
# Keep the original subprocess.call so _call_nostderr can delegate to it after patching.
_old_call = sp.call

def generate_random_string(length):
    """Return a string of ``length`` random ASCII letters and digits.

    Uses the (non-cryptographic) ``random`` module, one ``random.choice`` per
    character; suitable for throwaway file names, not for secrets.
    """
    alphabet = string.ascii_letters + string.digits
    picked = []
    for _ in range(length):
        picked.append(random.choice(alphabet))
    return "".join(picked)


def resize_video(input_path, output_path, target_width, target_height):
    """Rescale a video to ``target_width`` x ``target_height`` using ffmpeg.

    The output file is overwritten (``-y``) and the audio stream is copied
    without re-encoding. ffmpeg's exit status is not checked.
    """
    scale_filter = f'scale={target_width}:{target_height}'
    command = ['ffmpeg', '-y', '-i', input_path, '-vf', scale_filter, '-c:a', 'copy', output_path]
    sp.run(command)


def _call_nostderr(*args, **kwargs):
    # Avoid ffmpeg vomiting on the logs.
    kwargs['stderr'] = sp.DEVNULL
    kwargs['stdout'] = sp.DEVNULL
    _old_call(*args, **kwargs)


# Patch the subprocess module itself: any code that looks up subprocess.call at
# call time (not only this module) now runs with stdout/stderr discarded.
sp.call = _call_nostderr
# Preallocating the pool of processes.
# The executor is entered manually rather than via `with`; it is not shut down in
# this part of the file — presumably it lives for the whole app session.
pool = ProcessPoolExecutor(4)
pool.__enter__()


def interrupt():
    """Raise the module-level ``INTERRUPTING`` flag.

    Presumably polled by the generation loop so a running job can stop early.
    """
    global INTERRUPTING
    INTERRUPTING = True


class FileCleaner:
    """Delete tracked files once they are older than ``file_lifetime`` seconds.

    Entries are stored in insertion order, and expiry is checked lazily each
    time a file is added: leading entries are purged until the first one that
    has not yet expired.
    """

    def __init__(self, file_lifetime: float = 3600):
        self.file_lifetime = file_lifetime
        # (timestamp when added, Path) pairs in insertion order.
        self.files = []

    def add(self, path: tp.Union[str, Path]):
        """Purge expired entries, then start tracking ``path``."""
        self._cleanup()
        self.files.append((time.time(), Path(path)))

    def _cleanup(self):
        """Unlink (when still present) and forget each leading expired entry."""
        now = time.time()
        while self.files:
            added_at, file_path = self.files[0]
            expired = now - added_at > self.file_lifetime
            if not expired:
                break
            if file_path.exists():
                file_path.unlink()
            self.files.pop(0)


# Module-wide cleaner; tracked files are deleted once older than the default
# lifetime of 3600 seconds (checked lazily whenever add() is called).
file_cleaner = FileCleaner()


def make_waveform(*args, **kwargs):
    """Render a waveform video with ``gr.make_waveform`` and rescale it with ffmpeg.

    ``height`` and ``width`` are required keyword arguments; they are removed
    from ``kwargs`` (not forwarded to Gradio) and clamped to at least 256.
    Without a ``bg_image`` keyword the video is always rescaled to 900x300;
    with one, it is rescaled to the clamped width x height. Warnings raised
    meanwhile are suppressed.

    Returns:
        Path of the rescaled ``.mp4``: a random 12-character name in the
        current working directory.
    """
    started = time.time()
    with warnings.catch_warnings():
        warnings.simplefilter('ignore')
        target_height = max(kwargs.pop('height'), 256)
        target_width = max(kwargs.pop('width'), 256)
        rendered = gr.make_waveform(*args, **kwargs)
        out = f"{generate_random_string(12)}.mp4"
        if kwargs.get('bg_image', None) is None:
            size = (900, 300)
        else:
            size = (target_width, target_height)
        resize_video(rendered, out, *size)
        print("Make a video took", time.time() - started)
        return out


def _load_pretrained(version, custom_model, gen_type):
    """Instantiate a model from scratch, without consulting any cache.

    The custom id loads ``custom_model`` as a MusicGen checkpoint; otherwise
    ``gen_type`` ("audio" -> AudioGen, "music" -> MusicGen) selects the class.
    ``gen_type`` is expected to have been validated by the caller.
    """
    if version == 'GrandaddyShmax/musicgen-custom':
        return MusicGen.get_pretrained(custom_model)
    if gen_type == "audio":
        return AudioGen.get_pretrained(version)
    return MusicGen.get_pretrained(version)


def load_model(version='GrandaddyShmax/musicgen-melody', custom_model=None, gen_type="music"):
    """Make ``version`` the active model (stored in the module-level ``MODEL``).

    Args:
        version: model id to load. 'GrandaddyShmax/musicgen-custom' means
            "load the checkpoint named by ``custom_model``".
        custom_model: identifier passed to ``MusicGen.get_pretrained`` when
            ``version`` is the custom id; ignored otherwise.
        gen_type: "music" (MusicGen) or "audio" (AudioGen); ignored for the custom id.

    When ``MODELS`` is None, every call loads the model from scratch. Otherwise
    ``MODELS`` acts as a cache: the previous model is moved to CPU, a model is
    loaded from disk only the first time it is requested, and later requests
    move the cached instance back to 'cuda'.

    Raises:
        ValueError: if ``gen_type`` is neither "music" nor "audio" for a
            non-custom model.
    """
    global MODEL, MODELS
    is_custom = version == 'GrandaddyShmax/musicgen-custom'
    # Validate before touching the current model so a bad request leaves it usable.
    if not is_custom and gen_type not in ("music", "audio"):
        raise ValueError(f"Unknown gen_type {gen_type!r}; expected 'music' or 'audio'")
    print("Loading model", version)
    if MODELS is None:
        MODEL = _load_pretrained(version, custom_model, gen_type)
        return

    t1 = time.monotonic()
    if MODEL is not None:
        MODEL.to('cpu')  # move to cache
        print("Previous model moved to CPU in %.2fs" % (time.monotonic() - t1))
        t1 = time.monotonic()
    # Every custom checkpoint shares the same `version` id, so the cache key must
    # include the checkpoint as well; otherwise the lookup below would miss.
    cache_key = f"{version}:{custom_model}" if is_custom else version
    if MODELS.get(cache_key) is None:
        print("Loading model %s from disk" % version)
        result = _load_pretrained(version, custom_model, gen_type)
        MODELS[cache_key] = result
        print("Model loaded in %.2fs" % (time.monotonic() - t1))
        MODEL = result
        return
    result = MODELS[cache_key].to('cuda')
    print("Cached model loaded in %.2fs" % (time.monotonic() - t1))
    MODEL = result

# Shown when a .wav/.mp4 carries no usable generation metadata.
_AUDIO_INFO_NO_TAGS = "No tags found. Either the file is not generated by MusicGen+ V1.2.7 and higher or the tags are corrupted. (Discord removes metadata from mp4 and wav files, so you can't use them)"
# Shown when a .json metadata file cannot be parsed.
_AUDIO_INFO_BAD_JSON = "No tags found. Either the file is not generated by MusicGen+ V1.2.7 and higher or the tags are corrupted."

# (label, metadata key) pairs, in the order they are displayed after the version line.
_AUDIO_INFO_FIELDS = (
    ("Global Prompt", 'global_prompt'),
    ("BPM", 'bpm'),
    ("Key", 'key'),
    ("Scale", 'scale'),
    ("Prompts", 'texts'),
    ("Duration", 'duration'),
    ("Overlap", 'overlap'),
    ("Seed", 'seed'),
    ("Audio Mode", 'audio_mode'),
    ("Input Length", 'input_length'),
    ("Channel", 'channel'),
    ("Sample Rate", 'sr_select'),
    ("Model", 'model'),
    ("Custom Model", 'custom_model'),
    ("Decoder", 'decoder'),
    ("Topk", 'topk'),
    ("Topp", 'topp'),
    ("Temperature", 'temperature'),
    ("Classifier Free Guidance", 'cfg_coef'),
)


def _format_audio_info(data):
    """Render a generation-metadata dict as the multi-line text of the Audio Info tab.

    Keys absent from ``data`` are omitted. Values go through ``str`` so numeric
    JSON values do not raise ``TypeError``. An empty global prompt or an empty
    prompt list ("['']") is shown as "none", and the model name is prefixed
    with "<generator>gen-" when a generator is recorded.
    """
    lines = ["Version: " + (str(data['version']) if 'version' in data else "Unknown")]
    for label, key in _AUDIO_INFO_FIELDS:
        if key not in data:
            continue
        value = str(data[key])
        if key == 'global_prompt' and value == "":
            value = "none"
        elif key == 'texts' and value == "['']":
            value = "none"
        elif key == 'model' and 'generator' in data:
            # e.g. generator "music" + model "large" -> "musicgen-large"
            value = str(data['generator']) + "gen-" + value
        lines.append(label + ": " + value)
    return "\n".join(lines)


def get_audio_info(audio_path):
    """Describe the generation settings stored in an uploaded file.

    Args:
        audio_path: uploaded file object whose ``.name`` is a filesystem path
            (presumably Gradio's file wrapper), or None.

    Returns:
        None when ``audio_path`` is None. Otherwise a human-readable string: the
        formatted metadata, or an explanatory message when the extension is not
        .wav/.mp4/.json (case-insensitive) or no parseable metadata is found.
        Metadata is read from a .json file directly, or from the JSON stored in
        the COMMENT tag of a .wav/.mp4.
    """
    if audio_path is None:
        return None
    path = audio_path.name
    lowered = path.lower()
    if not lowered.endswith((".wav", ".mp4", ".json")):
        return "Only .wav ,.mp4 and .json files are supported"
    if lowered.endswith(".json"):
        try:
            with open(path) as json_file:
                data = json.load(json_file)
        except json.JSONDecodeError:
            return _AUDIO_INFO_BAD_JSON
        if not isinstance(data, dict):
            return _AUDIO_INFO_BAD_JSON
    else:
        with taglib.File(path, save_on_exit=False) as song:
            comments = song.tags.get('COMMENT')
        # A missing or empty COMMENT tag means the file was not tagged by this app.
        if not comments:
            return _AUDIO_INFO_NO_TAGS
        try:
            data = json.loads(comments[0])
        except json.JSONDecodeError:
            return _AUDIO_INFO_NO_TAGS
        if not isinstance(data, dict):
            return _AUDIO_INFO_NO_TAGS
    return _format_audio_info(data)


# UI values returned when no usable metadata is available, in info_to_params
# output order: decoder .. unique_prompts (9), ten prompt texts, ten repeat
# counts, then audio_mode .. sr_select (10) — 39 entries. A tuple, so the
# shared default can never be mutated by a caller.
_DEFAULT_MUSIC_PARAMS = (
    "Default", False, "", 120, "C", "Major", "large", None, 1,
    "", "", "", "", "", "", "", "", "", "",
    1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
    "sample", 10, 250, 0, 1.0, 5.0, -1, 12, "stereo", "48000",
)


def _read_music_params_metadata(path):
    """Load the generation-metadata dict from a .json file or a .wav/.mp4 COMMENT tag.

    Returns None when there is no COMMENT tag, or the content is not valid
    JSON describing an object.
    """
    try:
        if path.lower().endswith(".json"):
            with open(path) as json_file:
                data = json.load(json_file)
        else:
            with taglib.File(path, save_on_exit=False) as song:
                comments = song.tags.get('COMMENT')
            if not comments:
                return None
            data = json.loads(comments[0])
    except json.JSONDecodeError:
        return None
    return data if isinstance(data, dict) else None


def _split_music_prompt_texts(texts):
    """Collapse a stringified prompt list into unique prompts and repeat counts.

    Consecutive identical prompts are merged into one entry with a higher repeat
    count; blank prompts are skipped. Both lists are padded to 10 entries, e.g.
    "['a', 'a', '', 'b']" -> (['a', 'b', '', ...], [2, 1, 1, ...], 2).
    """
    segments = re.findall(r"'(.*?)'", str(texts))
    text, repeat = [], []
    for i, elem in enumerate(segments):
        if not elem.strip():
            continue
        if i == 0 or elem != segments[i - 1]:
            text.append(elem)
            repeat.append(1)
        else:
            repeat[-1] += 1
    text.extend([""] * (10 - len(text)))
    repeat.extend([1] * (10 - len(repeat)))
    return text, repeat, len([t for t in text if t])


def info_to_params(audio_path):
    """Map generation metadata stored in an uploaded file back onto MusicGen UI values.

    Args:
        audio_path: uploaded file object whose ``.name`` is a filesystem path, or None.

    Returns:
        A 39-tuple: decoder, struc_prompt, global_prompt, bpm, key, scale, model,
        custom_model, unique_prompts, ten prompt texts, ten repeat counts,
        audio_mode, duration, topk, topp, temperature, cfg_coef, seed, overlap,
        channel, sr_select. ``_DEFAULT_MUSIC_PARAMS`` is returned for None, for
        extensions other than .wav/.mp4/.json (case-insensitive), and for files
        without parseable metadata. Missing keys fall back to the same defaults.
    """
    if audio_path is None or not audio_path.name.lower().endswith((".wav", ".mp4", ".json")):
        return _DEFAULT_MUSIC_PARAMS
    data = _read_music_params_metadata(audio_path.name)
    if data is None:
        return _DEFAULT_MUSIC_PARAMS

    def value(key, default, convert=lambda v: v):
        # Missing keys fall back to the UI default.
        return convert(data[key]) if key in data else default

    def tagged(key, default, convert=lambda v: v):
        # Like value(), but the literal "none" (stored for unused fields) also means default.
        return default if data.get(key, "none") == "none" else convert(data[key])

    if 'texts' in data:
        text, repeat, unique_prompts = _split_music_prompt_texts(data['texts'])
    else:
        text, repeat, unique_prompts = [""] * 10, [1] * 10, 1

    # Only restore a custom model that still exists locally.
    custom_model = None
    if 'custom_model' in data and data['custom_model'] in get_available_folders():
        custom_model = data['custom_model']

    return (
        value('decoder', "Default"),
        'bpm' in data and data['bpm'] != "none",  # structure prompts were used
        value('global_prompt', ""),
        tagged('bpm', 120, int),
        tagged('key', "C"),
        tagged('scale', "Major"),
        value('model', "large"),
        custom_model,
        unique_prompts,
        *text[:10],
        *repeat[:10],
        tagged('audio_mode', "sample"),
        value('duration', 10, int),
        value('topk', 250, float),
        value('topp', 0, float),
        value('temperature', 1.0, float),
        value('cfg_coef', 5.0, float),
        value('seed', -1, int),
        value('overlap', 12, int),
        value('channel', "stereo"),
        value('sr_select', "48000"),
    )


# Values returned by info_to_params_a when no usable metadata is found, in the same
# order as the parsed tuple: decoder, struc_prompt, global_prompt, unique_prompts,
# 10 prompt texts, 10 repeat counts, duration, topk, topp, temperature, cfg_coef,
# seed, overlap, channel, sr_select.
_INFO_PARAMS_A_DEFAULTS = ("Default", False, "", 1, "", "", "", "", "", "", "", "", "", "", 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 10, 250, 0, 1.0, 5.0, -1, 12, "stereo", "48000")


def _split_saved_texts_a(raw):
    """Turn the stored ``texts`` value (normally ``str(list_of_prompts)``) back into a list of strings."""
    import ast

    if isinstance(raw, list):
        return [str(item) for item in raw]
    try:
        parsed = ast.literal_eval(raw)
    except (ValueError, SyntaxError, TypeError, RecursionError):
        parsed = None
    if isinstance(parsed, (list, tuple)) and all(isinstance(item, str) for item in parsed):
        # literal_eval also handles prompts containing apostrophes, which str(list)
        # wraps in double quotes and the quote-scraping regex below would drop.
        return list(parsed)
    # Legacy fallback: grab every single-quoted substring.
    return re.findall(r"'(.*?)'", raw)


def _parse_info_params_a(data):
    """Map a decoded metadata dict onto the UI parameter tuple returned by info_to_params_a."""
    global_prompt = data['global_prompt'] if 'global_prompt' in data else ""
    # Structured prompts are considered enabled whenever a non-empty global prompt was stored.
    struc_prompt = ('global_prompt' in data) and data['global_prompt'] != ""
    decoder = data['decoder'] if 'decoder' in data else "Default"

    text = []
    repeat = []
    if 'texts' in data:
        entries = _split_saved_texts_a(data['texts'])
        # Collapse runs of identical consecutive prompts into (text, repeat count) pairs.
        for i, elem in enumerate(entries):
            if not elem.strip():
                continue
            if i == 0 or elem != entries[i - 1]:
                text.append(elem)
                repeat.append(1)
            else:
                repeat[-1] += 1
    text.extend([""] * (10 - len(text)))
    repeat.extend([1] * (10 - len(repeat)))
    # Keep the count inside the 1..10 range of the "Prompts" slider.
    unique_prompts = min(max(len([t for t in text if t]), 1), 10)

    duration = int(float(data['duration'])) if 'duration' in data else 10
    topk = float(data['topk']) if 'topk' in data else 250
    topp = float(data['topp']) if 'topp' in data else 0
    temperature = float(data['temperature']) if 'temperature' in data else 1.0
    cfg_coef = float(data['cfg_coef']) if 'cfg_coef' in data else 5.0
    seed = int(data['seed']) if 'seed' in data else -1
    overlap = int(float(data['overlap'])) if 'overlap' in data else 12
    channel = data['channel'] if 'channel' in data else "stereo"
    sr_select = data['sr_select'] if 'sr_select' in data else "48000"
    return (decoder, struc_prompt, global_prompt, unique_prompts,
            text[0], text[1], text[2], text[3], text[4], text[5], text[6], text[7], text[8], text[9],
            repeat[0], repeat[1], repeat[2], repeat[3], repeat[4], repeat[5], repeat[6], repeat[7], repeat[8], repeat[9],
            duration, topk, topp, temperature, cfg_coef, seed, overlap, channel, sr_select)


def info_to_params_a(audio_path):
    """Restore AudioGen UI parameters from a saved render or its JSON sidecar.

    Args:
        audio_path: an object with a ``.name`` path attribute (e.g. an uploaded file),
            or ``None``. ``.wav``/``.mp4`` files are read from their ``COMMENT`` tag
            (JSON written by add_tags); ``.json`` files are read directly.

    Returns:
        A 33-tuple: decoder, struc_prompt, global_prompt, unique_prompts, ten prompt
        texts, ten repeat counts, duration, topk, topp, temperature, cfg_coef, seed,
        overlap, channel, sr_select. Defaults are returned when the input is ``None``,
        has another extension, or carries no ``COMMENT`` tag.
    """
    if audio_path is None:
        return _INFO_PARAMS_A_DEFAULTS
    name = audio_path.name
    if name.endswith(".json"):
        with open(name) as json_file:
            data = json.load(json_file)
    elif name.endswith(".wav") or name.endswith(".mp4"):
        with taglib.File(name, save_on_exit=False) as song:
            if 'COMMENT' not in song.tags:
                return _INFO_PARAMS_A_DEFAULTS
            data = json.loads(song.tags['COMMENT'][0])
    else:
        return _INFO_PARAMS_A_DEFAULTS
    return _parse_info_params_a(data)


def make_pseudo_stereo(filename, sr_select, pan, delay, delay_seconds=0.01):
    """Post-process a WAV file in place to give a mono render a stereo feel.

    Args:
        filename: path of the WAV file; it is overwritten in place.
        sr_select: requested output sample rate as a string (e.g. "48000").
        pan: if truthy, resample to ``sr_select`` when the file's rate differs, then
            overlay a left-panned and a right-panned copy (each 5 dB quieter, the right
            one offset by ``position=5``).
        delay: if truthy, rewrite the file as two channels: channel 0 of the file and
            a copy of it delayed by ``delay_seconds`` (zero-padded at the start).
        delay_seconds: delay of the second channel, 10 ms by default.
    """
    if pan:
        temp = AudioSegment.from_wav(filename)
        target_rate = int(sr_select)
        # Compare with the file's real rate rather than assuming a 32 kHz source, so
        # 16 kHz AudioGen renders are also brought to the requested rate.
        if temp.frame_rate != target_rate:
            temp = temp.set_frame_rate(target_rate)
        left = temp.pan(-0.5) - 5
        right = temp.pan(0.6) - 5
        temp = left.overlay(right, position=5)
        temp.export(filename, format="wav")
    if delay:
        waveform, sample_rate = torchaudio.load(filename)
        mono = waveform[0]
        num_samples = mono.shape[-1]
        # Never delay by more than the clip length, or the channels would differ in size.
        delay_samples = min(int(delay_seconds * sample_rate), num_samples)
        # Explicit end index: ``mono[:-0]`` would be empty and make torch.stack fail.
        delayed = torch.cat((torch.zeros(delay_samples, dtype=mono.dtype), mono[:num_samples - delay_samples]))
        torchaudio.save(filename, torch.stack([mono, delayed]), sample_rate)
    return


def normalize_audio(audio_data):
    """Peak-normalise an audio array so its largest absolute value becomes 1.0.

    Args:
        audio_data: numpy array of samples (any numeric dtype).

    Returns:
        A new float32 array. Empty or all-zero (silent) input is returned as float32
        without scaling, instead of producing NaNs from a division by zero.
    """
    audio_data = audio_data.astype(np.float32)  # astype copies, so the caller's array is untouched
    if audio_data.size == 0:
        return audio_data
    max_value = np.max(np.abs(audio_data))
    if max_value > 0:
        audio_data /= max_value
    return audio_data


def load_diffusion():
    """Create the MultiBand Diffusion decoder once and keep it in the global ``MBD``.

    Does nothing if ``MBD`` is already populated.
    """
    global MBD
    if MBD is not None:
        return
    print("loading MBD")
    MBD = MultiBandDiffusion.get_mbd_musicgen()


def unload_diffusion():
    """Drop the global MultiBand Diffusion decoder reference, if one is held."""
    global MBD
    if MBD is None:
        return
    print("unloading MBD")
    MBD = None


def _do_predictions(gen_type, texts, melodies, sample, trim_start, trim_end, duration, image, height, width, background, bar1, bar2, channel, sr_select, progress=False, **gen_kwargs):
    """Run one generation with the global ``MODEL`` and write the results to temp files.

    Args:
        gen_type: "music" (MusicGen, 32 kHz) or "audio" (AudioGen, 16 kHz).
        texts: descriptions passed to the model.
        melodies: list of ``(sample_rate, np.ndarray)`` melody prompts or ``None`` entries.
        sample: optional ``(sample_rate, np.ndarray)`` audio to continue from.
        trim_start, trim_end: seconds trimmed from the start/end of ``sample``.
        duration: requested output length in seconds (extended to cover ``sample``).
        image, height, width, background, bar1, bar2: forwarded to ``make_waveform``.
        channel: "stereo", "mono" or "stereo effect".
        sr_select: requested output sample rate as a string.
        progress: forwarded to the model's generate calls.
        **gen_kwargs: extra arguments for ``MODEL.set_generation_params``.

    Returns:
        ``(video_files, wav_files, backup_wav_files, input_length)``. Backups are written
        at the model's native sample rate; ``input_length`` is the trimmed length of
        ``sample`` in seconds (0 without a sample). When MultiBand Diffusion is active the
        diffusion renders follow the default-decoder renders in each list.

    Raises:
        ValueError: if ``gen_type`` is neither "music" nor "audio".
    """
    global MODEL
    # Longest prompt fed to the model, and the model's native sample rate.
    if gen_type == "music":
        maximum_size = 29.5
        target_sr = 32000
    elif gen_type == "audio":
        maximum_size = 9.5
        target_sr = 16000
    else:
        raise ValueError(f"Unknown gen_type: {gen_type!r}")
    target_ac = 1

    cut_size = 0
    input_length = 0
    sampleP = None
    if sample is not None:
        globalSR, sampleM = sample[0], sample[1]
        sampleM = normalize_audio(sampleM)
        sampleM = torch.from_numpy(sampleM).t()
        if sampleM.dim() == 1:
            sampleM = sampleM.unsqueeze(0)
        sample_length = sampleM.shape[sampleM.dim() - 1] / globalSR
        # Clamp the trims so that part of the sample always survives.
        if trim_start >= sample_length:
            trim_start = sample_length - 0.5
        if trim_end >= sample_length:
            trim_end = sample_length - 0.5
        if trim_start + trim_end >= sample_length:
            tmp = sample_length - 0.5
            trim_start = tmp / 2
            trim_end = tmp / 2
        sampleM = sampleM[..., int(globalSR * trim_start):int(globalSR * (sample_length - trim_end))]
        sample_length = sample_length - (trim_start + trim_end)
        if sample_length > maximum_size:
            # Only the tail of the sample prompts the model; the head (sampleP) is
            # re-attached to the generated audio afterwards.
            cut_size = sample_length - maximum_size
            sampleP = sampleM[..., :int(globalSR * cut_size)]
            sampleM = sampleM[..., int(globalSR * cut_size):]
        if sample_length >= duration:
            duration = sample_length + 0.5
        input_length = sample_length
    MODEL.set_generation_params(duration=(duration - cut_size), **gen_kwargs)
    print("new batch", len(texts), texts, [None if m is None else (m[0], m[1].shape) for m in melodies], [None if sample is None else (sample[0], sample[1].shape)])
    be = time.time()

    processed_melodies = []
    for melody in melodies:
        if melody is None:
            processed_melodies.append(None)
        else:
            sr, melody = melody[0], torch.from_numpy(melody[1]).to(MODEL.device).float().t()
            if melody.dim() == 1:
                melody = melody[None]
            melody = melody[..., :int(sr * duration)]
            melody = convert_audio(melody, sr, target_sr, target_ac)
            processed_melodies.append(melody)

    # Only the MusicGen calls request tokens (needed for MultiBand Diffusion decoding).
    token_kwargs = {"return_tokens": USE_DIFFUSION} if gen_type == "music" else {}
    if sample is not None:
        outputs = MODEL.generate_continuation(
            prompt=sampleM,
            prompt_sample_rate=globalSR,
            descriptions=texts,
            progress=progress,
            **token_kwargs
        )
    elif any(m is not None for m in processed_melodies):
        outputs = MODEL.generate_with_chroma(
            descriptions=texts,
            melody_wavs=processed_melodies,
            melody_sample_rate=target_sr,
            progress=progress,
            **token_kwargs
        )
    else:
        outputs = MODEL.generate(texts, progress=progress, **token_kwargs)

    if USE_DIFFUSION:
        print("outputs: " + str(outputs))
        outputs_diffusion = MBD.tokens_to_wav(outputs[1])
        outputs = torch.cat([outputs[0], outputs_diffusion], dim=0)

    if sampleP is not None:
        # Prepend the cut-off head of the sample. This happens after diffusion decoding
        # so `outputs` is a tensor (not a (wav, tokens) tuple) and every batch entry,
        # including the diffusion render, receives the same prefix.
        if sampleP.dim() > 1:
            sampleP = convert_audio(sampleP, globalSR, target_sr, outputs.shape[1])
        sampleP = sampleP.to(outputs.device).float().unsqueeze(0)
        outputs = torch.cat([sampleP.expand(outputs.shape[0], -1, -1), outputs], 2)

    outputs = outputs.detach().cpu().float()
    backups = outputs
    if channel == "stereo":
        outputs = convert_audio(outputs, target_sr, int(sr_select), 2)
    elif channel == "mono" and int(sr_select) != target_sr:
        # Resample whenever the requested rate differs from the model's native rate;
        # assuming 32 kHz left 16 kHz AudioGen audio written with a wrong rate.
        outputs = convert_audio(outputs, target_sr, int(sr_select), 1)

    out_files = []
    out_audios = []
    out_backup = []
    for output in outputs:
        with NamedTemporaryFile("wb", suffix=".wav", delete=False) as file:
            audio_write(
                file.name, output, (MODEL.sample_rate if channel == "stereo effect" else int(sr_select)), strategy="loudness",
                loudness_headroom_db=16, loudness_compressor=True, add_suffix=False)

            if channel == "stereo effect":
                make_pseudo_stereo(file.name, sr_select, pan=True, delay=True)

            out_files.append(pool.submit(make_waveform, file.name, bg_image=image, bg_color=background, bars_color=(bar1, bar2), fg_alpha=1.0, bar_count=75, height=height, width=width))
            out_audios.append(file.name)
            file_cleaner.add(file.name)
            print(f'wav: {file.name}')
    for backup in backups:
        with NamedTemporaryFile("wb", suffix=".wav", delete=False) as file:
            audio_write(
                file.name, backup, MODEL.sample_rate, strategy="loudness",
                loudness_headroom_db=16, loudness_compressor=True, add_suffix=False)
            out_backup.append(file.name)
            file_cleaner.add(file.name)
    res = [out_file.result() for out_file in out_files]
    res_audio = out_audios
    res_backup = out_backup
    for file in res:
        file_cleaner.add(file)
        print(f'video: {file}')
    print("batch finished", len(texts), time.time() - be)
    print("Tempfiles currently stored: ", len(file_cleaner.files))
    if MOVE_TO_CPU:
        MODEL.to('cpu')
    if UNLOAD_MODEL:
        MODEL = None
    torch.cuda.empty_cache()
    torch.cuda.ipc_collect()
    return res, res_audio, res_backup, input_length


def predict_batched(texts, melodies, image=None, height=100, width=300, background="#0f0f0f", bar1="#84cc16", bar2="#10B981", channel="mono", sr_select="32000"):
    """Generate one clip per description with the MusicGen melody model.

    Args:
        texts: list of descriptions; each is truncated to 512 characters.
        melodies: list of ``(sample_rate, np.ndarray)`` melody prompts or ``None``,
            one per description.
        image, height, width, background, bar1, bar2: waveform-video styling forwarded
            to ``_do_predictions``. NOTE(review): these defaults are placeholders added
            with this fix — callers should pass the UI's values.
        channel: output channel mode ("mono", "stereo" or "stereo effect").
        sr_select: output sample rate as a string.

    Returns:
        The tuple returned by ``_do_predictions``.
    """
    max_text_length = 512
    texts = [text[:max_text_length] for text in texts]
    # The previous calls, load_model('melody') and
    # _do_predictions(texts, melodies, BATCHED_DURATION), passed too few positional
    # arguments for _do_predictions. The load_model call mirrors predict_full's
    # (model id, custom model, gen_type) convention; load_model itself is defined
    # elsewhere — NOTE(review): confirm.
    load_model('GrandaddyShmax/musicgen-melody', None, 'music')
    res = _do_predictions(
        'music', texts, melodies, None, 0, 0, BATCHED_DURATION,
        image, height, width, background, bar1, bar2, channel, sr_select)
    return res


def add_tags(filename, tags):
    """Store generation settings as JSON in a file's COMMENT tag and in a sidecar file.

    Args:
        filename: media file to tag; skipped if it does not exist.
        tags: sequence of at least 20 values, in the order built by predict_full
            (global_prompt, bpm, key, scale, texts, duration, overlap, seed, audio_mode,
            input_length, channel, sr_select, model, custom_model, decoder, topk, topp,
            temperature, cfg_coef, generator).

    Returns:
        Path of the sidecar JSON file, ``tags[7] + '.json'`` in the working directory.

    Raises:
        IndexError: if ``tags`` has fewer than 20 entries.
    """
    keys = (
        "global_prompt", "bpm", "key", "scale", "texts", "duration", "overlap", "seed",
        "audio_mode", "input_length", "channel", "sr_select", "model", "custom_model",
        "decoder", "topk", "topp", "temperature", "cfg_coef", "generator",
    )
    # Index explicitly (rather than zip) so a short `tags` still fails loudly.
    data = {key: tags[index] for index, key in enumerate(keys)}
    data["version"] = version
    json_string = json.dumps(data)

    if os.path.exists(filename):
        with taglib.File(filename, save_on_exit=True) as song:
            song.tags = {'COMMENT': json_string}

    json_path = tags[7] + '.json'
    with open(json_path, 'w') as json_file:
        json_file.write(json_string)

    return json_path


def save_outputs(mp4, wav_tmp, tags, gen_type):
    """Copy a finished render into ``./output/<YYYYMMDD>/<gen_type>/{wav,mp4,json}/``.

    Files are named after the seed (``tags[7]``). If ``<seed>.wav`` already exists, a
    ``(n)`` suffix is added, and the same stem is used for the mp4 and json copies. Both
    media copies get their metadata embedded via add_tags. The sidecar JSON that
    add_tags writes to the working directory is copied into the json folder and then
    removed.

    Args:
        mp4: video file name, relative to the working directory (absolute paths are
            used as-is).
        wav_tmp: path of the temporary WAV render.
        tags: metadata list as passed to add_tags.
        gen_type: "music" or "audio"; used as a sub-folder name.

    Returns:
        ``(wav_target, mp4_target, json_target)`` absolute paths of the stored copies.
    """
    current_date = datetime.now().strftime("%Y%m%d")
    base_directory = os.path.join(os.getcwd(), 'output', current_date, gen_type)
    wav_directory = os.path.join(base_directory, 'wav')
    mp4_directory = os.path.join(base_directory, 'mp4')
    json_directory = os.path.join(base_directory, 'json')
    for directory in (wav_directory, mp4_directory, json_directory):
        os.makedirs(directory, exist_ok=True)

    seed = str(tags[7])
    stem = seed
    counter = 1
    while os.path.exists(os.path.join(wav_directory, stem + '.wav')):
        stem = seed + f'({counter})'
        counter += 1

    # Build each target from its own directory; string-replacing 'wav'/'mp4' across the
    # whole path would also rewrite matching parts of the working directory.
    wav_target = os.path.join(wav_directory, stem + '.wav')
    mp4_target = os.path.join(mp4_directory, stem + '.mp4')
    json_target = os.path.join(json_directory, stem + '.json')

    shutil.copyfile(wav_tmp, wav_target)
    json_file = add_tags(wav_target, tags)

    mp4_source = mp4 if os.path.isabs(mp4) else os.path.join('.', mp4)
    shutil.copyfile(mp4_source, mp4_target)
    add_tags(mp4_target, tags)

    shutil.copyfile(json_file, json_target)
    os.remove(json_file)

    return wav_target, mp4_target, json_target


def clear_cash():
    """Delete temporary media files created today.

    Removes ``*.mp4`` files in the current working directory, plus ``tmp*.mp4``,
    ``tmp*.wav`` and ``tmp*.png`` files in the temp directory, whose ctime falls on
    today's date. Files that cannot be removed (already gone, locked by another
    process) are skipped.
    """
    import tempfile

    current_date = datetime.now().date()
    # TEMP is usually only set on Windows; fall back to the platform temp dir instead
    # of crashing on os.path.join(None, ...).
    temp_directory = os.environ.get('TEMP') or tempfile.gettempdir()
    patterns = (
        os.path.join(os.getcwd(), '*.mp4'),
        os.path.join(temp_directory, 'tmp*.mp4'),
        os.path.join(temp_directory, 'tmp*.wav'),
        os.path.join(temp_directory, 'tmp*.png'),
    )
    for pattern in patterns:
        for file in glob.glob(pattern):
            try:
                creation_date = datetime.fromtimestamp(os.path.getctime(file)).date()
                if creation_date == current_date:
                    os.remove(file)
            except OSError:
                continue
    return


def s2t(seconds, seconds2):
    """Format a time span as ``"MM:SS - MM:SS"``.

    Args:
        seconds: span start in seconds.
        seconds2: span end in seconds.

    A non-zero start lying before the end is shown one second later, so consecutive
    spans do not print the same boundary twice. The bump is applied before splitting
    into minutes and seconds, so a start of 59 renders as ``01:00`` (not ``00:60``).
    """
    start = seconds + 1 if seconds != 0 and seconds < seconds2 else seconds
    m, s = divmod(start, 60)
    m2, s2 = divmod(seconds2, 60)
    return "%02d:%02d - %02d:%02d" % (m, s, m2, s2)


def calc_time(gen_type, s, duration, overlap, d0, d1, d2, d3, d4, d5, d6, d7, d8, d9):
    """Compute the time range label covered by each of the ten prompt segments.

    Args:
        gen_type: "music" (30 s generation window) or "audio" (10 s window).
        s: number of active prompts.
        duration: total requested length in seconds.
        overlap: seconds of overlap between consecutive generation windows.
        d0..d9: repeat count of each prompt.

    Returns:
        A 10-tuple of ``"MM:SS - MM:SS"`` strings produced by s2t. The last active
        segment, and any segment reaching ``duration``, is clamped to end at ``duration``.
    """
    repeats = [int(d) for d in (d0, d1, d2, d3, d4, d5, d6, d7, d8, d9)]
    last_index = s - 1
    end_time = duration
    if gen_type == "music":
        window = 30
    elif gen_type == "audio":
        window = 10
    else:
        window = 0
    step = window - overlap
    # The first prompt gets a full window, and each further repeat adds one step.
    segment_lengths = [window + (repeats[0] - 1) * step] + [count * step for count in repeats[1:]]

    labels = []
    cursor = 0
    for idx, length in enumerate(segment_lengths):
        if cursor + length >= end_time or idx == last_index:
            labels.append(s2t(cursor, end_time))
            cursor = end_time
        else:
            labels.append(s2t(cursor, cursor + length))
            cursor = cursor + length
    return tuple(labels)


def _compose_structured_prompt(gen_type, bpm, key, scale, global_prompt, text):
    """Join the structured-prompt parts: "<bpm> bpm, <key> <scale>[, global][, text]" for music, "[global][, text]" otherwise."""
    parts = []
    if gen_type == "music":
        parts.append(str(bpm) + " bpm")
        parts.append(str(key) + " " + str(scale))
    if str(global_prompt) != "":
        parts.append(str(global_prompt))
    if str(text) != "":
        parts.append(str(text))
    # Joining (instead of prefixing ", ") avoids a leading comma when only the text is set.
    return ", ".join(parts)


def predict_full(gen_type, model, decoder, custom_model, prompt_amount, struc_prompt, bpm, key, scale, global_prompt, p0, p1, p2, p3, p4, p5, p6, p7, p8, p9, d0, d1, d2, d3, d4, d5, d6, d7, d8, d9, audio, mode, trim_start, trim_end, duration, topk, topp, temperature, cfg_coef, seed, overlap, image, height, width, background, bar1, bar2, channel, sr_select, progress=gr.Progress()):
    """Gradio handler: generate music or audio from the UI inputs and store the result.

    Prompts p0..pN (N = prompt_amount - 1) are each repeated d0..dN times. In
    structured mode they are combined with bpm/key/scale (music only) and the global
    prompt. The render is stored by save_outputs.

    Returns:
        ``(mp4_path, wav_path, backup_wav_path, [mp4_path, wav_path, json_path], seed)``.

    Raises:
        gr.Error: on negative temperature/topk/topp, or when interrupted.
    """
    global INTERRUPTING
    global USE_DIFFUSION
    INTERRUPTING = False

    if gen_type == "audio":
        custom_model = None
        custom_model_shrt = "none"
    elif gen_type == "music":
        custom_model_shrt = custom_model
        # An empty custom-model dropdown yields None; avoid "models/" + None (TypeError).
        custom_model = "models/" + custom_model if custom_model is not None else None

    if temperature < 0:
        raise gr.Error("Temperature must be >= 0.")
    if topk < 0:
        raise gr.Error("Topk must be non-negative.")
    if topp < 0:
        raise gr.Error("Topp must be non-negative.")

    if trim_start < 0:
        trim_start = 0
    if trim_end < 0:
        trim_end = 0

    topk = int(topk)

    if decoder == "MultiBand_Diffusion":
        USE_DIFFUSION = True
        load_diffusion()
    else:
        USE_DIFFUSION = False
        unload_diffusion()

    model_shrt = model
    if gen_type == "music":
        model = "GrandaddyShmax/musicgen-" + model
    elif gen_type == "audio":
        model = "GrandaddyShmax/audiogen-" + model

    if MODEL is None or MODEL.name != (model):
        load_model(model, custom_model, gen_type)
    else:
        if MOVE_TO_CPU:
            MODEL.to('cuda')

    if seed < 0:
        seed = random.randint(0, 0xffff_ffff_ffff)
    torch.manual_seed(seed)

    def _progress(generated, to_generate):
        progress((min(generated, to_generate), to_generate))
        if INTERRUPTING:
            raise gr.Error("Interrupted.")
    MODEL.set_custom_progress_callback(_progress)

    audio_mode = "none"
    melody = None
    sample = None
    if audio:
        audio_mode = mode
        if mode == "sample":
            sample = audio
        elif mode == "melody":
            melody = audio

    custom_model_shrt = "none" if model != "GrandaddyShmax/musicgen-custom" else custom_model_shrt

    text_cat = [p0, p1, p2, p3, p4, p5, p6, p7, p8, p9]
    drag_cat = [d0, d1, d2, d3, d4, d5, d6, d7, d8, d9]
    texts = []
    raw_texts = []
    ind = 0
    while ind < prompt_amount:
        prompt_text = text_cat[ind]
        for _ in range(int(drag_cat[ind])):
            raw_texts.append(prompt_text)
            if struc_prompt:
                texts.append(_compose_structured_prompt(gen_type, bpm, key, scale, global_prompt, prompt_text))
            else:
                texts.append(prompt_text)
                # Structure fields are unused; record them as "none" in the saved tags.
                global_prompt = "none"
                bpm = "none"
                key = "none"
                scale = "none"
        ind = ind + 1

    outs, outs_audio, outs_backup, input_length = _do_predictions(
        gen_type, [texts], [melody], sample, trim_start, trim_end, duration, image, height, width, background, bar1, bar2, channel, sr_select, progress=True,
        top_k=topk, top_p=topp, temperature=temperature, cfg_coef=cfg_coef, extend_stride=MODEL.max_duration-overlap)
    # With MultiBand Diffusion, _do_predictions concatenates the diffusion render after
    # the default-decoder render; keep the one the user selected.
    pick = len(outs) - 1 if USE_DIFFUSION else 0
    tags = [str(global_prompt), str(bpm), str(key), str(scale), str(raw_texts), str(duration), str(overlap), str(seed), str(audio_mode), str(input_length), str(channel), str(sr_select), str(model_shrt), str(custom_model_shrt), str(decoder), str(topk), str(topp), str(temperature), str(cfg_coef), str(gen_type)]
    wav_target, mp4_target, json_target = save_outputs(outs[pick], outs_audio[pick], tags, gen_type)
    # Removes the temporary files.
    for out in outs:
        os.remove(out)
    for out in outs_audio:
        os.remove(out)

    return mp4_target, wav_target, outs_backup[pick], [mp4_target, wav_target, json_target], seed


# Upper bound of the UI's "Prompts:" slider, i.e. the number of prompt segments offered.
max_textboxes: int = 10


#def get_available_models():
    #return sorted([re.sub('.pt$', '', item.name) for item in list(Path('models/').glob('*')) if item.name.endswith('.pt')])


def get_available_folders(models_dir="models"):
    """Return the sorted names of the sub-directories of ``models_dir``.

    Args:
        models_dir: directory to scan; defaults to ``models`` relative to the cwd.

    Returns:
        Sorted list of folder names, or an empty list when ``models_dir`` does not
        exist, so the UI can still be built without a ``models`` folder.
    """
    if not os.path.isdir(models_dir):
        return []
    return sorted(
        entry for entry in os.listdir(models_dir)
        if os.path.isdir(os.path.join(models_dir, entry))
    )


def toggle_audio_src(choice):
    """Return a Gradio update that switches the audio input to microphone or file upload.

    ``choice == "mic"`` selects the microphone; anything else selects upload. The
    current value is always cleared.
    """
    use_mic = choice == "mic"
    return gr.update(
        source="microphone" if use_mic else "upload",
        value=None,
        label="Microphone" if use_mic else "File",
    )


def ui_full(launch_kwargs):
    with gr.Blocks(title='AudioCraft Plus', theme=theme) as interface:
        gr.Markdown(
            """
            # AudioCraft Plus - v2.0.1

            ### An All-in-One AudioCraft WebUI

            Thanks to: facebookresearch, Camenduru, rkfg, oobabooga, AlexHK and GrandaddyShmax
            """
        )
        with gr.Tab("MusicGen"):
            gr.Markdown(
                """
                ### MusicGen
                """
            )
            with gr.Row():
                with gr.Column():
                    with gr.Tab("Generation"):
                        with gr.Accordion("Structure Prompts", open=False):
                            with gr.Column():
                                with gr.Row():
                                    struc_prompts = gr.Checkbox(label="Enable", value=False, interactive=True, container=False)
                                    bpm = gr.Number(label="BPM", value=120, interactive=True, scale=1, precision=0)
                                    key = gr.Dropdown(["C", "C#", "D", "D#", "E", "F", "F#", "G", "G#", "A", "Bb", "B"], label="Key", value="C", interactive=True)
                                    scale = gr.Dropdown(["Major", "Minor"], label="Scale", value="Major", interactive=True)
                                with gr.Row():
                                    global_prompt = gr.Text(label="Global Prompt", interactive=True, scale=3)
                        with gr.Row():
                            s = gr.Slider(1, max_textboxes, value=1, step=1, label="Prompts:", interactive=True, scale=2)
                            #s_mode = gr.Radio(["segmentation", "batch"], value="segmentation", interactive=True, scale=1, label="Generation Mode")
                        with gr.Column():
                            textboxes = []
                            prompts = []
                            repeats = []
                            calcs = []
                            with gr.Row():
                                text0 = gr.Text(label="Input Text", interactive=True, scale=4)
                                prompts.append(text0)
                                drag0 = gr.Number(label="Repeat", value=1, interactive=True, scale=1)
                                repeats.append(drag0)
                                calc0 = gr.Text(interactive=False, value="00:00 - 00:00", scale=1, label="Time")
                                calcs.append(calc0)
                            for i in range(max_textboxes):
                                with gr.Row(visible=False) as t:
                                    text = gr.Text(label="Input Text", interactive=True, scale=3)
                                    repeat = gr.Number(label="Repeat", minimum=1, value=1, interactive=True, scale=1)
                                    calc = gr.Text(interactive=False, value="00:00 - 00:00", scale=1, label="Time")
                                textboxes.append(t)
                                prompts.append(text)
                                repeats.append(repeat)
                                calcs.append(calc)
                            to_calc = gr.Button("Calculate Timings", variant="secondary")
                        with gr.Row():
                            duration = gr.Slider(minimum=1, maximum=300, value=10, step=1, label="Duration", interactive=True)
                        with gr.Row():
                            overlap = gr.Slider(minimum=1, maximum=29, value=12, step=1, label="Overlap", interactive=True)
                        with gr.Row():
                            seed = gr.Number(label="Seed", value=-1, scale=4, precision=0, interactive=True)
                            gr.Button('\U0001f3b2\ufe0f', scale=1).click(fn=lambda: -1, outputs=[seed], queue=False)
                            reuse_seed = gr.Button('\u267b\ufe0f', scale=1)

                    with gr.Tab("Audio"):
                        with gr.Row():
                            with gr.Column():
                                input_type = gr.Radio(["file", "mic"], value="file", label="Input Type (optional)", interactive=True)
                                mode = gr.Radio(["melody", "sample"], label="Input Audio Mode (optional)", value="sample", interactive=True)
                                with gr.Row():
                                    trim_start = gr.Number(label="Trim Start", value=0, interactive=True)
                                    trim_end = gr.Number(label="Trim End", value=0, interactive=True)
                            audio = gr.Audio(source="upload", type="numpy", label="Input Audio (optional)", interactive=True)

                    with gr.Tab("Customization"):
                        with gr.Row():
                            with gr.Column():
                                background = gr.ColorPicker(value="#0f0f0f", label="background color", interactive=True, scale=0)
                                bar1 = gr.ColorPicker(value="#84cc16", label="bar color start", interactive=True, scale=0)
                                bar2 = gr.ColorPicker(value="#10b981", label="bar color end", interactive=True, scale=0)
                            with gr.Column():
                                image = gr.Image(label="Background Image", type="filepath", interactive=True, scale=4)
                                with gr.Row():
                                    height = gr.Number(label="Height", value=512, interactive=True)
                                    width = gr.Number(label="Width", value=768, interactive=True)

                    with gr.Tab("Settings"):
                        with gr.Row():
                            channel = gr.Radio(["mono", "stereo", "stereo effect"], label="Output Audio Channels", value="stereo", interactive=True, scale=1)
                            sr_select = gr.Dropdown(["11025", "16000", "22050", "24000", "32000", "44100", "48000"], label="Output Audio Sample Rate", value="48000", interactive=True)
                        with gr.Row():
                            model = gr.Radio(["melody", "small", "medium", "large", "custom"], label="Model", value="large", interactive=True, scale=1)
                            with gr.Column():
                                dropdown = gr.Dropdown(choices=get_available_folders(), value=("No models found" if len(get_available_folders()) < 1 else get_available_folders()[0]), label='Custom Model (models folder)', elem_classes='slim-dropdown', interactive=True)
                                ui.create_refresh_button(dropdown, lambda: None, lambda: {'choices': get_available_folders()}, 'refresh-button')
                        with gr.Row():
                            decoder = gr.Radio(["Default", "MultiBand_Diffusion"], label="Decoder", value="Default", interactive=True)
                        with gr.Row():
                            topk = gr.Number(label="Top-k", value=250, interactive=True)
                            topp = gr.Number(label="Top-p", value=0, interactive=True)
                            temperature = gr.Number(label="Temperature", value=1.0, interactive=True)
                            cfg_coef = gr.Number(label="Classifier Free Guidance", value=3.0, interactive=True)
                    with gr.Row():
                        submit = gr.Button("Generate", variant="primary")
                        # Adapted from https://github.com/rkfg/audiocraft/blob/long/app.py, MIT license.
                        _ = gr.Button("Interrupt").click(fn=interrupt, queue=False)
                with gr.Column() as c:
                    with gr.Tab("Output"):
                        output = gr.Video(label="Generated Music", scale=0)
                        with gr.Row():
                            audio_only = gr.Audio(type="numpy", label="Audio Only", interactive=False)
                            backup_only = gr.Audio(type="numpy", label="Backup Audio", interactive=False, visible=False)
                            send_audio = gr.Button("Send to Input Audio")
                        seed_used = gr.Number(label='Seed used', value=-1, interactive=False)
                        download = gr.File(label="Generated Files", interactive=False)
                    with gr.Tab("Wiki"):
                        gr.Markdown(
                            """
                            - **[Generate (button)]:**  
                            Generates the music with the given settings and prompts.

                            - **[Interrupt (button)]:**  
                            Stops the music generation as soon as it can, providing an incomplete output.

                            ---

                            ### Generation Tab:

                            #### Structure Prompts:

                            This feature helps reduce repetitive prompts by allowing you to set global prompts  
                            that will be used for all prompt segments.

                            - **[Structure Prompts (checkbox)]:**  
                            Enable/Disable the structure prompts feature.

                            - **[BPM (number)]:**  
                            Beats per minute of the generated music.

                            - **[Key (dropdown)]:**  
                            The key of the generated music.

                            - **[Scale (dropdown)]:**  
                            The scale of the generated music.

                            - **[Global Prompt (text)]:**  
                            Here write the prompt that you wish to be used for all prompt segments.

                            #### Multi-Prompt: 
                            
                            This feature allows you to control the music, adding variation to different time segments.  
                            You have up to 10 prompt segments. The first prompt segment is always 30s long;  
                            each following segment is [30s - overlap] long.  
                            For example, if the overlap is 10s, each segment after the first will be 20s.

                            - **[Prompt Segments (number)]:**  
                            Number of unique prompts to use throughout the music generation.

                            - **[Prompt/Input Text (prompt)]:**  
                            Here describe the music you wish the model to generate.

                            - **[Repeat (number)]:**  
                            Write how many times this prompt will repeat (instead of wasting another prompt segment on the same prompt).

                            - **[Time (text)]:**  
                            The time of the prompt segment.

                            - **[Calculate Timings (button)]:**  
                            Calculates the timings of the prompt segments.

                            - **[Duration (number)]:**  
                            How long you want the generated music to be (in seconds).

                            - **[Overlap (number)]:**  
                            How much each new segment will reference the previous segment (in seconds).  
                            For example, if you choose 20s: Each new segment after the first one will reference the previous segment 20s  
                            and will generate only 10s of new music. The model can only process 30s of music.

                            - **[Seed (number)]:**  
                            Your generated music id. If you wish to generate the exact same music,  
                            place the exact seed with the exact prompts  
                            (This way you can also extend a specific song that was generated too short).

                            - **[Random Seed (button)]:**  
                            Gives "-1" as a seed, which counts as a random seed.

                            - **[Copy Previous Seed (button)]:**  
                            Copies the seed from the output seed (if you don't feel like doing it manually).

                            ---

                            ### Audio Tab:

                            - **[Input Type (selection)]:**  
                            `File` mode allows you to upload an audio file to use as input  
                            `Mic` mode allows you to use your microphone as input

                            - **[Input Audio Mode (selection)]:**  
                            `Melody` mode only works with the melody model: it conditions the music generation to reference the melody  
                            `Sample` mode works with any model: it gives a music sample to the model to generate its continuation.

                            - **[Trim Start and Trim End (numbers)]:**  
                            `Trim Start` sets how much you'd like to trim from the start of the input audio  
                            `Trim End` does the same, but from the end

                            - **[Input Audio (audio file)]:**  
                            Input here the audio you wish to use with "melody" or "sample" mode.

                            ---

                            ### Customization Tab:

                            - **[Background Color (color)]:**  
                            Works only if you don't upload an image. Color of the background of the waveform.

                            - **[Bar Color Start (color)]:**  
                            First color of the waveform bars.

                            - **[Bar Color End (color)]:**  
                            Second color of the waveform bars.

                            - **[Background Image (image)]:**  
                            Background image that you wish to be attached to the generated video along with the waveform.

                            - **[Height and Width (numbers)]:**  
                            Output video resolution, only works with image.  
                            (minimum height and width is 256).
                            
                            ---

                            ### Settings Tab:

                            - **[Output Audio Channels (selection)]:**  
                            With this you can select the amount of channels that you wish for your output audio.  
                            `mono` is a straightforward single channel audio  
                            `stereo` is a dual channel audio but it will sound more or less like mono  
                            `stereo effect` this one is also dual channel but uses tricks to simulate a stereo audio.

                            - **[Output Audio Sample Rate (dropdown)]:**  
                            The output audio sample rate, the model default is 32000.

                            - **[Model (selection)]:**  
                            Here you can choose which model you wish to use:  
                            `melody` model is based on the medium model with a unique feature that lets you use melody conditioning  
                            `small` model has 300M parameters  
                            `medium` model has 1.5B parameters  
                            `large` model has 3.3B parameters  
                            `custom` model runs the custom model that you provided.

                            - **[Custom Model (selection)]:**  
                            This dropdown will show you models that are placed in the `models` folder  
                            you must select `custom` in the model options in order to use it.

                            - **[Refresh (button)]:**  
                            Refreshes the dropdown list for custom model.

                            - **[Decoder (selection)]:**  
                            Choose here the decoder that you wish to use:  
                            `Default` is the default decoder  
                            `MultiBand_Diffusion` is a decoder that uses diffusion to generate the audio.

                            - **[Top-k (number)]:**  
                            is a parameter used in text generation models, including music generation models. It determines the number of most likely next tokens to consider at each step of the generation process. The model ranks all possible tokens based on their predicted probabilities, and then selects the top-k tokens from the ranked list. The model then samples from this reduced set of tokens to determine the next token in the generated sequence. A smaller value of k results in a more focused and deterministic output, while a larger value of k allows for more diversity in the generated music.

                            - **[Top-p (number)]:**  
                            also known as nucleus sampling or probabilistic sampling, is another method used for token selection during text generation. Instead of specifying a fixed number like top-k, top-p considers the cumulative probability distribution of the ranked tokens. It selects the smallest possible set of tokens whose cumulative probability exceeds a certain threshold (usually denoted as p). The model then samples from this set to choose the next token. This approach ensures that the generated output maintains a balance between diversity and coherence, as it allows for a varying number of tokens to be considered based on their probabilities.
                            
                            - **[Temperature (number)]:**  
                            is a parameter that controls the randomness of the generated output. It is applied during the sampling process, where a higher temperature value results in more random and diverse outputs, while a lower temperature value leads to more deterministic and focused outputs. In the context of music generation, a higher temperature can introduce more variability and creativity into the generated music, but it may also lead to less coherent or structured compositions. On the other hand, a lower temperature can produce more repetitive and predictable music.

                            - **[Classifier Free Guidance (number)]:**  
                            controls how strongly the generated music follows your text prompt. With classifier-free guidance, the model is trained both with and without the text conditioning, so no separate classifier network is involved. During generation, the model's conditional and unconditional predictions are combined, and this coefficient scales how far the result is pushed toward the conditional (prompt-following) prediction. Higher values make the music adhere more closely to the prompt, at the cost of diversity and sometimes audio quality. Lower values give the model more freedom.
                            """
                        )
        with gr.Tab("AudioGen"):
            gr.Markdown(
                """
                ### AudioGen
                """
            )
            with gr.Row():
                with gr.Column():
                    with gr.Tab("Generation"):
                        with gr.Accordion("Structure Prompts", open=False):
                            with gr.Row():
                                struc_prompts_a = gr.Checkbox(label="Enable", value=False, interactive=True, container=False)
                                global_prompt_a = gr.Text(label="Global Prompt", interactive=True, scale=3)
                        with gr.Row():
                            s_a = gr.Slider(1, max_textboxes, value=1, step=1, label="Prompts:", interactive=True, scale=2)
                        with gr.Column():
                            textboxes_a = []
                            prompts_a = []
                            repeats_a = []
                            calcs_a = []
                            with gr.Row():
                                text0_a = gr.Text(label="Input Text", interactive=True, scale=4)
                                prompts_a.append(text0_a)
                                drag0_a = gr.Number(label="Repeat", value=1, interactive=True, scale=1)
                                repeats_a.append(drag0_a)
                                calc0_a = gr.Text(interactive=False, value="00:00 - 00:00", scale=1, label="Time")
                                calcs_a.append(calc0_a)
                            for i in range(max_textboxes):
                                with gr.Row(visible=False) as t_a:
                                    text_a = gr.Text(label="Input Text", interactive=True, scale=3)
                                    repeat_a = gr.Number(label="Repeat", minimum=1, value=1, interactive=True, scale=1)
                                    calc_a = gr.Text(interactive=False, value="00:00 - 00:00", scale=1, label="Time")
                                textboxes_a.append(t_a)
                                prompts_a.append(text_a)
                                repeats_a.append(repeat_a)
                                calcs_a.append(calc_a)
                            to_calc_a = gr.Button("Calculate Timings", variant="secondary")
                        with gr.Row():
                            duration_a = gr.Slider(minimum=1, maximum=300, value=10, step=1, label="Duration", interactive=True)
                        with gr.Row():
                            overlap_a = gr.Slider(minimum=1, maximum=9, value=2, step=1, label="Overlap", interactive=True)
                        with gr.Row():
                            seed_a = gr.Number(label="Seed", value=-1, scale=4, precision=0, interactive=True)
                            gr.Button('\U0001f3b2\ufe0f', scale=1).click(fn=lambda: -1, outputs=[seed_a], queue=False)
                            reuse_seed_a = gr.Button('\u267b\ufe0f', scale=1)

                    with gr.Tab("Audio"):
                        with gr.Row():
                            with gr.Column():
                                input_type_a = gr.Radio(["file", "mic"], value="file", label="Input Type (optional)", interactive=True)
                                mode_a = gr.Radio(["sample"], label="Input Audio Mode (optional)", value="sample", interactive=False, visible=False)
                                with gr.Row():
                                    trim_start_a = gr.Number(label="Trim Start", value=0, interactive=True)
                                    trim_end_a = gr.Number(label="Trim End", value=0, interactive=True)
                            audio_a = gr.Audio(source="upload", type="numpy", label="Input Audio (optional)", interactive=True)

                    with gr.Tab("Customization"):
                        with gr.Row():
                            with gr.Column():
                                background_a = gr.ColorPicker(value="#0f0f0f", label="background color", interactive=True, scale=0)
                                bar1_a = gr.ColorPicker(value="#84cc16", label="bar color start", interactive=True, scale=0)
                                bar2_a = gr.ColorPicker(value="#10b981", label="bar color end", interactive=True, scale=0)
                            with gr.Column():
                                image_a = gr.Image(label="Background Image", type="filepath", interactive=True, scale=4)
                                with gr.Row():
                                    height_a = gr.Number(label="Height", value=512, interactive=True)
                                    width_a = gr.Number(label="Width", value=768, interactive=True)

                    with gr.Tab("Settings"):
                        with gr.Row():
                            channel_a = gr.Radio(["mono", "stereo", "stereo effect"], label="Output Audio Channels", value="stereo", interactive=True, scale=1)
                            sr_select_a = gr.Dropdown(["11025", "16000", "22050", "24000", "32000", "44100", "48000"], label="Output Audio Sample Rate", value="48000", interactive=True)
                        with gr.Row():
                            model_a = gr.Radio(["medium"], label="Model", value="medium", interactive=False, visible=False)
                            decoder_a = gr.Radio(["Default"], label="Decoder", value="Default", interactive=False, visible=False)
                        with gr.Row():
                            topk_a = gr.Number(label="Top-k", value=250, interactive=True)
                            topp_a = gr.Number(label="Top-p", value=0, interactive=True)
                            temperature_a = gr.Number(label="Temperature", value=1.0, interactive=True)
                            cfg_coef_a = gr.Number(label="Classifier Free Guidance", value=3.0, interactive=True)
                    with gr.Row():
                        submit_a = gr.Button("Generate", variant="primary")
                        _ = gr.Button("Interrupt").click(fn=interrupt, queue=False)
                with gr.Column():
                    with gr.Tab("Output"):
                        output_a = gr.Video(label="Generated Audio", scale=0)
                        with gr.Row():
                            audio_only_a = gr.Audio(type="numpy", label="Audio Only", interactive=False)
                            backup_only_a = gr.Audio(type="numpy", label="Backup Audio", interactive=False, visible=False)
                            send_audio_a = gr.Button("Send to Input Audio")
                        seed_used_a = gr.Number(label='Seed used', value=-1, interactive=False)
                        download_a = gr.File(label="Generated Files", interactive=False)
                    with gr.Tab("Wiki"):
                        gr.Markdown(
                            """
                            - **[Generate (button)]:**  
                            Generates the audio with the given settings and prompts.

                            - **[Interrupt (button)]:**  
                            Stops the audio generation as soon as it can, providing an incomplete output.

                            ---

                            ### Generation Tab:

                            #### Structure Prompts:

                            This feature helps reduce repetitive prompts by allowing you to set global prompts  
                            that will be used for all prompt segments.

                            - **[Structure Prompts (checkbox)]:**  
                            Enable/Disable the structure prompts feature.

                            - **[Global Prompt (text)]:**  
                            Here write the prompt that you wish to be used for all prompt segments.

                            #### Multi-Prompt: 
                            
                            This feature allows you to control the audio, adding variation to different time segments.  
                            You have up to 10 prompt segments. The first prompt segment is always 10s long;  
                            each following segment is [10s - overlap] long.  
                            For example, if the overlap is 2s, each segment after the first will be 8s.

                            - **[Prompt Segments (number)]:**  
                            Number of unique prompts to use throughout the audio generation.

                            - **[Prompt/Input Text (prompt)]:**  
                            Here describe the audio you wish the model to generate.

                            - **[Repeat (number)]:**  
                            Write how many times this prompt will repeat (instead of wasting another prompt segment on the same prompt).

                            - **[Time (text)]:**  
                            The time of the prompt segment.

                            - **[Calculate Timings (button)]:**  
                            Calculates the timings of the prompt segments.

                            - **[Duration (number)]:**  
                            How long you want the generated audio to be (in seconds).

                            - **[Overlap (number)]:**  
                            How much each new segment will reference the previous segment (in seconds).  
                            For example, if you choose 2s: Each new segment after the first one will reference the previous segment 2s  
                            and will generate only 8s of new audio. The model can only process 10s of audio at a time.

                            - **[Seed (number)]:**  
                            Your generated audio id. If you wish to generate the exact same audio,  
                            place the exact seed with the exact prompts  
                            (This way you can also extend a specific audio clip that was generated too short).

                            - **[Random Seed (button)]:**  
                            Gives "-1" as a seed, which counts as a random seed.

                            - **[Copy Previous Seed (button)]:**  
                            Copies the seed from the output seed (if you don't feel like doing it manually).

                            ---

                            ### Audio Tab:

                            - **[Input Type (selection)]:**  
                            `File` mode allows you to upload an audio file to use as input  
                            `Mic` mode allows you to use your microphone as input

                            - **[Trim Start and Trim End (numbers)]:**  
                            `Trim Start` sets how much you'd like to trim from the start of the input audio  
                            `Trim End` does the same, but from the end

                            - **[Input Audio (audio file)]:**  
                            Input here the audio you wish to use.

                            ---

                            ### Customization Tab:

                            - **[Background Color (color)]:**  
                            Works only if you don't upload an image. Color of the background of the waveform.

                            - **[Bar Color Start (color)]:**  
                            First color of the waveform bars.

                            - **[Bar Color End (color)]:**  
                            Second color of the waveform bars.

                            - **[Background Image (image)]:**  
                            Background image that you wish to be attached to the generated video along with the waveform.

                            - **[Height and Width (numbers)]:**  
                            Output video resolution, only works with image.  
                            (minimum height and width is 256).
                            
                            ---

                            ### Settings Tab:

                            - **[Output Audio Channels (selection)]:**  
                            With this you can select the amount of channels that you wish for your output audio.  
                            `mono` is a straightforward single channel audio  
                            `stereo` is a dual channel audio but it will sound more or less like mono  
                            `stereo effect` this one is also dual channel but uses tricks to simulate a stereo audio.

                            - **[Output Audio Sample Rate (dropdown)]:**  
                            The output audio sample rate; the model's native rate is 16000.

                            - **[Top-k (number)]:**  
                            is a parameter used in text generation models, including music generation models. It determines the number of most likely next tokens to consider at each step of the generation process. The model ranks all possible tokens based on their predicted probabilities, and then selects the top-k tokens from the ranked list. The model then samples from this reduced set of tokens to determine the next token in the generated sequence. A smaller value of k results in a more focused and deterministic output, while a larger value of k allows for more diversity in the generated music.

                            - **[Top-p (number)]:**  
                            also known as nucleus sampling or probabilistic sampling, is another method used for token selection during text generation. Instead of specifying a fixed number like top-k, top-p considers the cumulative probability distribution of the ranked tokens. It selects the smallest possible set of tokens whose cumulative probability exceeds a certain threshold (usually denoted as p). The model then samples from this set to choose the next token. This approach ensures that the generated output maintains a balance between diversity and coherence, as it allows for a varying number of tokens to be considered based on their probabilities.
                            
                            - **[Temperature (number)]:**  
                            is a parameter that controls the randomness of the generated output. It is applied during the sampling process, where a higher temperature value results in more random and diverse outputs, while a lower temperature value leads to more deterministic and focused outputs. In the context of music generation, a higher temperature can introduce more variability and creativity into the generated music, but it may also lead to less coherent or structured compositions. On the other hand, a lower temperature can produce more repetitive and predictable music.

                            - **[Classifier Free Guidance (number)]:**  
                            refers to a technique where the model is run both with and without the conditioning (e.g. your text prompt) and the two predictions are combined, pushing the output further toward what the conditioning describes. No separate classifier network is involved. Higher values make the generated music follow the prompt more closely, at the cost of diversity; lower values give the model more freedom to deviate from the prompt.
                            """
                        )
        with gr.Tab("Audio Info"):
            gr.Markdown(
                """
                ### Audio Info
                """
            )
            with gr.Row():
                with gr.Column():
                    in_audio = gr.File(type="file", label="Input Any Audio", interactive=True)
                    with gr.Row():
                        send_gen = gr.Button("Send to MusicGen", variant="primary")
                        send_gen_a = gr.Button("Send to AudioGen", variant="primary")
                with gr.Column():
                    info = gr.Textbox(label="Audio Info", lines=10, interactive=False)
        with gr.Tab("Changelog"):
            gr.Markdown(
                            """
                            ## Changelog:

                            ### v2.0.1

                            - Changed custom model loading to support the official trained models

                            - Additional changes from the main facebookresearch repo



                            ### v2.0.0a

                            - Forgot to move all the update to app.py from temp2.py... oops



                            ### v2.0.0

                            - Changed name from MusicGen+ to AudioCraft Plus
                            
                            - Complete overhaul of the repo "backend" with the latest changes from the main facebookresearch repo

                            - Added a new decoder: MultiBand_Diffusion

                            - Added AudioGen: a new tab for generating audio



                            ### v1.2.8c

                            - Implemented backward compatibility for audio info tab with previous versions



                            ### v1.2.8b

                            - Fixed the error when loading default models



                            ### v1.2.8a

                            - Adapted Audio info tab to work with the new structure prompts feature

                            - Now custom models actually work, make sure you select the correct base model



                            ### v1.2.8

                            - Now you will also receive a JSON file with metadata of generated audio

                            - Added error messages in Audio Info tab

                            - Added structure prompts: you can select bpm, key and global prompt for all prompts

                            - Added time display next to each prompt, can be calculated with "Calculate Timings" button



                            ### v1.2.7

                            - When sending generated audio to Input Audio, it will send a backup audio with default settings  
                            (best for continuous generation)

                            - Added Metadata to generated audio (Thanks to AlexHK ♥)

                            - Added Audio Info tab that will display the metadata of the input audio

                            - Added "send to Text2Audio" button in Audio Info tab

                            - Generated audio is now stored in the "output" folder (Thanks to AlexHK ♥)

                            - Added an output area with generated files and download buttons

                            - Enhanced Stereo effect (Thanks to AlexHK ♥)



                            ### v1.2.6

                            - Added option to generate in stereo (instead of only mono)

                            - Added dropdown for selecting output sample rate (model default is 32000)



                            ### v1.2.5a

                            - Added file cleaner (This comes from the main facebookresearch repo)

                            - Reorganized a little, moved audio to a separate tab



                            ### v1.2.5

                            - Gave a unique lime theme to the webui
                            
                            - Added additional output for audio only

                            - Added button to send generated audio to Input Audio

                            - Added option to trim Input Audio



                            ### v1.2.4

                            - Added mic input (This comes from the main facebookresearch repo)



                            ### v1.2.3

                            - Added option to change video size to fit the image you upload



                            ### v1.2.2

                            - Added Wiki, Changelog and About tabs



                            ### v1.2.1

                            - Added tabs and organized the entire interface

                            - Added option to attach image to the output video

                            - Added option to load fine-tuned models (Yet to be tested)



                            ### v1.2.0

                            - Added Multi-Prompt



                            ### v1.1.3

                            - Added customization options for generated waveform



                            ### v1.1.2

                            - Removed sample length limit: now you can input audio of any length as music sample



                            ### v1.1.1

                            - Improved music sample audio quality when using music continuation



                            ### v1.1.0

                            - Rebuilt the repo on top of the latest structure of the main MusicGen repo
                            
                            - Improved Music continuation feature



                            ### v1.0.0 - Stable Version

                            - Added Music continuation
                            """
                        )
        with gr.Tab("About"):
            gen_type = gr.Text(value="music", interactive=False, visible=False)
            gen_type_a = gr.Text(value="audio", interactive=False, visible=False)
            gr.Markdown(
                            """
                            This is your private demo for [MusicGen](https://github.com/facebookresearch/audiocraft), a simple and controllable model for music generation
                            presented at: ["Simple and Controllable Music Generation"](https://huggingface.co/papers/2306.05284)
                            
                            ## AudioCraft Plus (formerly MusicGen+) is an extended version of the original AudioCraft by facebookresearch.
                            
                            ### Repo: https://github.com/GrandaddyShmax/audiocraft_plus/tree/plus

                            ---
                            
                            ### This project was possible thanks to:

                            #### GrandaddyShmax - https://github.com/GrandaddyShmax

                            #### Camenduru - https://github.com/camenduru

                            #### rkfg - https://github.com/rkfg

                            #### oobabooga - https://github.com/oobabooga
                            
                            #### AlexHK - https://github.com/alanhk147
                            """
                        )

        send_gen.click(info_to_params, inputs=[in_audio], outputs=[decoder, struc_prompts, global_prompt, bpm, key, scale, model, dropdown, s, prompts[0], prompts[1], prompts[2], prompts[3], prompts[4], prompts[5], prompts[6], prompts[7], prompts[8], prompts[9], repeats[0], repeats[1], repeats[2], repeats[3], repeats[4], repeats[5], repeats[6], repeats[7], repeats[8], repeats[9], mode, duration, topk, topp, temperature, cfg_coef, seed, overlap, channel, sr_select], queue=False)
        reuse_seed.click(fn=lambda x: x, inputs=[seed_used], outputs=[seed], queue=False)
        send_audio.click(fn=lambda x: x, inputs=[backup_only], outputs=[audio], queue=False)
        submit.click(predict_full, inputs=[gen_type, model, decoder, dropdown, s, struc_prompts, bpm, key, scale, global_prompt, prompts[0], prompts[1], prompts[2], prompts[3], prompts[4], prompts[5], prompts[6], prompts[7], prompts[8], prompts[9], repeats[0], repeats[1], repeats[2], repeats[3], repeats[4], repeats[5], repeats[6], repeats[7], repeats[8], repeats[9], audio, mode, trim_start, trim_end, duration, topk, topp, temperature, cfg_coef, seed, overlap, image, height, width, background, bar1, bar2, channel, sr_select], outputs=[output, audio_only, backup_only, download, seed_used])
        input_type.change(toggle_audio_src, input_type, [audio], queue=False, show_progress=False)
        to_calc.click(calc_time, inputs=[gen_type, s, duration, overlap, repeats[0], repeats[1], repeats[2], repeats[3], repeats[4], repeats[5], repeats[6], repeats[7], repeats[8], repeats[9]], outputs=[calcs[0], calcs[1], calcs[2], calcs[3], calcs[4], calcs[5], calcs[6], calcs[7], calcs[8], calcs[9]], queue=False)

        send_gen_a.click(info_to_params_a, inputs=[in_audio], outputs=[decoder_a, struc_prompts_a, global_prompt_a, s_a, prompts_a[0], prompts_a[1], prompts_a[2], prompts_a[3], prompts_a[4], prompts_a[5], prompts_a[6], prompts_a[7], prompts_a[8], prompts_a[9], repeats_a[0], repeats_a[1], repeats_a[2], repeats_a[3], repeats_a[4], repeats_a[5], repeats_a[6], repeats_a[7], repeats_a[8], repeats_a[9], duration_a, topk_a, topp_a, temperature_a, cfg_coef_a, seed_a, overlap_a, channel_a, sr_select_a], queue=False)
        reuse_seed_a.click(fn=lambda x: x, inputs=[seed_used_a], outputs=[seed_a], queue=False)
        send_audio_a.click(fn=lambda x: x, inputs=[backup_only_a], outputs=[audio_a], queue=False)
        submit_a.click(predict_full, inputs=[gen_type_a, model_a, decoder_a, dropdown, s_a, struc_prompts_a, bpm, key, scale, global_prompt_a, prompts_a[0], prompts_a[1], prompts_a[2], prompts_a[3], prompts_a[4], prompts_a[5], prompts_a[6], prompts_a[7], prompts_a[8], prompts_a[9], repeats_a[0], repeats_a[1], repeats_a[2], repeats_a[3], repeats_a[4], repeats_a[5], repeats_a[6], repeats_a[7], repeats_a[8], repeats_a[9], audio_a, mode_a, trim_start_a, trim_end_a, duration_a, topk_a, topp_a, temperature_a, cfg_coef_a, seed_a, overlap_a, image_a, height_a, width_a, background_a, bar1_a, bar2_a, channel_a, sr_select_a], outputs=[output_a, audio_only_a, backup_only_a, download_a, seed_used_a])
        input_type_a.change(toggle_audio_src, input_type_a, [audio_a], queue=False, show_progress=False)
        to_calc_a.click(calc_time, inputs=[gen_type_a, s_a, duration_a, overlap_a, repeats_a[0], repeats_a[1], repeats_a[2], repeats_a[3], repeats_a[4], repeats_a[5], repeats_a[6], repeats_a[7], repeats_a[8], repeats_a[9]], outputs=[calcs_a[0], calcs_a[1], calcs_a[2], calcs_a[3], calcs_a[4], calcs_a[5], calcs_a[6], calcs_a[7], calcs_a[8], calcs_a[9]], queue=False)

        in_audio.change(get_audio_info, in_audio, outputs=[info])

        def variable_outputs(k):
            k = int(k) - 1
            return [gr.Textbox.update(visible=True)]*k + [gr.Textbox.update(visible=False)]*(max_textboxes-k)
        def get_size(image):
            if image is not None:
                img = Image.open(image)
                img_height = img.height
                img_width = img.width
                if (img_height%2) != 0:
                    img_height = img_height + 1
                if (img_width%2) != 0:
                    img_width = img_width + 1
                return img_height, img_width
            else:
                return 512, 768

        image.change(get_size, image, outputs=[height, width])
        image_a.change(get_size, image_a, outputs=[height_a, width_a])
        s.change(variable_outputs, s, textboxes)
        s_a.change(variable_outputs, s_a, textboxes_a)
        interface.queue().launch(**launch_kwargs)


def ui_batched(launch_kwargs):
    """Build and launch the batched MusicGen demo page.

    The page has a text description box, an optional melody input (uploaded
    file or microphone, switched with a radio button), a Generate button, and
    video / wav outputs. Generation is routed through `predict_batched` with
    Gradio request batching enabled.

    Args:
        launch_kwargs (dict): Keyword arguments forwarded unchanged to
            `launch()` (e.g. server name, port, auth, share).
    """
    # Example prompts shown under the form: [description, melody file or None].
    example_rows = [
        [
            "An 80s driving pop song with heavy drums and synth pads in the background",
            "./assets/bach.mp3",
        ],
        [
            "A cheerful country song with acoustic guitars",
            "./assets/bolero_ravel.mp3",
        ],
        [
            "90s rock song with electric guitar and heavy drums",
            None,
        ],
        [
            "a light and cheerly EDM track, with syncopated drums, aery pads, and strong emotions bpm: 130",
            "./assets/bach.mp3",
        ],
        [
            "lofi slow bpm electro chill with organic samples",
            None,
        ],
    ]

    with gr.Blocks() as app:
        gr.Markdown(
            """
            # MusicGen

            This is the demo for [MusicGen](https://github.com/facebookresearch/audiocraft),
            a simple and controllable model for music generation
            presented at: ["Simple and Controllable Music Generation"](https://huggingface.co/papers/2306.05284).
            <br/>
            <a href="https://huggingface.co/spaces/facebook/MusicGen?duplicate=true"
                style="display: inline-block;margin-top: .5em;margin-right: .25em;" target="_blank">
            <img style="margin-bottom: 0em;display: inline;margin-top: -.25em;"
                src="https://bit.ly/3gLdBN6" alt="Duplicate Space"></a>
            for longer sequences, more control and no queue.</p>
            """
        )
        with gr.Row():
            # Left column: inputs.
            with gr.Column():
                with gr.Row():
                    description = gr.Text(label="Describe your music", lines=2, interactive=True)
                    with gr.Column():
                        melody_source = gr.Radio(["file", "mic"], value="file",
                                                 label="Condition on a melody (optional) File or Mic")
                        melody_input = gr.Audio(source="upload", type="numpy", label="File",
                                                interactive=True, elem_id="melody-input")
                with gr.Row():
                    generate_btn = gr.Button("Generate")
            # Right column: outputs.
            with gr.Column():
                video_out = gr.Video(label="Generated Music")
                wav_out = gr.Audio(label="Generated Music (wav)", type='filepath')

        # Event wiring; registration order is kept stable.
        generate_btn.click(predict_batched, inputs=[description, melody_input],
                           outputs=[video_out, wav_out], batch=True, max_batch_size=MAX_BATCH_SIZE)
        melody_source.change(toggle_audio_src, melody_source, [melody_input], queue=False, show_progress=False)
        gr.Examples(
            fn=predict_batched,
            examples=example_rows,
            inputs=[description, melody_input],
            outputs=[video_out]
        )
        gr.Markdown("""
        ### More details

        The model will generate 12 seconds of audio based on the description you provided.
        You can optionally provide a reference audio from which a broad melody will be extracted.
        The model will then try to follow both the description and melody provided.
        All samples are generated with the `melody` model.

        You can also use your own GPU or a Google Colab by following the instructions on our repo.

        See [github.com/facebookresearch/audiocraft](https://github.com/facebookresearch/audiocraft)
        for more details.
        """)

        # Queue capacity: 8 * 4 pending requests.
        app.queue(max_size=32).launch(**launch_kwargs)


if __name__ == "__main__":
    # Command-line entry point: parse options, copy them into module-level
    # settings and Gradio `launch()` keyword arguments, then start the UI.
    parser = argparse.ArgumentParser()
    parser.add_argument(
        '--listen',
        type=str,
        default='0.0.0.0' if 'SPACE_ID' in os.environ else '127.0.0.1',
        help='IP to listen on for connections to Gradio',
    )
    parser.add_argument(
        '--username', type=str, default='', help='Username for authentication'
    )
    parser.add_argument(
        '--password', type=str, default='', help='Password for authentication'
    )
    parser.add_argument(
        '--server_port',
        type=int,
        default=0,
        help='Port to run the server listener on',
    )
    parser.add_argument(
        '--inbrowser', action='store_true', help='Open in browser'
    )
    parser.add_argument(
        '--share', action='store_true', help='Share the gradio UI'
    )
    parser.add_argument(
        '--unload_model', action='store_true', help='Unload the model after every generation to save GPU memory'
    )

    parser.add_argument(
        '--unload_to_cpu', action='store_true', help='Move the model to main RAM after every generation to save GPU memory but reload faster than after full unload (see above)'
    )

    parser.add_argument(
        '--cache', action='store_true', help='Cache models in RAM to quickly switch between them'
    )

    args = parser.parse_args()
    # These assignments execute at module scope, so they rebind module-level
    # globals (presumably consulted by the generation helpers elsewhere in this file).
    UNLOAD_MODEL = args.unload_model
    MOVE_TO_CPU = args.unload_to_cpu
    if args.cache:
        # NOTE(review): presumably a non-None MODELS enables the in-RAM model cache — confirm.
        MODELS = {}

    launch_kwargs = {}
    launch_kwargs['server_name'] = args.listen

    # Authentication is only enabled when both credentials are supplied.
    if args.username and args.password:
        launch_kwargs['auth'] = (args.username, args.password)
    if args.server_port:
        launch_kwargs['server_port'] = args.server_port
    if args.inbrowser:
        launch_kwargs['inbrowser'] = args.inbrowser
    if args.share:
        launch_kwargs['share'] = args.share

    # Show the interface
    if IS_BATCHED:
        # No `global` statement is needed at module scope: this assignment
        # already binds the module-level name. A `global` here would be a no-op,
        # and a SyntaxError if USE_DIFFUSION were bound earlier at module level.
        USE_DIFFUSION = False
        ui_batched(launch_kwargs)
    else:
        ui_full(launch_kwargs)

================================================
FILE: audiocraft/__init__.py
================================================
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
"""
AudioCraft is a general framework for training audio generative models.
At the moment we provide the training code for:

- [MusicGen](https://arxiv.org/abs/2306.05284), a state-of-the-art
    text-to-music and melody+text autoregressive generative model.
    For the solver, see `audiocraft.solvers.musicgen.MusicGenSolver`, and for the model,
    `audiocraft.models.musicgen.MusicGen`.
- [AudioGen](https://arxiv.org/abs/2209.15352), a state-of-the-art
    text-to-general-audio generative model.
- [EnCodec](https://arxiv.org/abs/2210.13438), efficient and high fidelity
    neural audio codec which provides an excellent tokenizer for autoregressive language models.
    See `audiocraft.solvers.compression.CompressionSolver`, and `audiocraft.models.encodec.EncodecModel`.
- [MultiBandDiffusion](https://arxiv.org/abs/2308.02560), an alternative diffusion-based decoder compatible with EnCodec that
    improves the perceived quality and reduces the artifacts coming from adversarial decoders.
"""

# flake8: noqa
from . import data, modules, models

__version__ = '1.0.0'


================================================
FILE: audiocraft/adversarial/__init__.py
================================================
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
"""Adversarial losses and discriminator architectures."""

# flake8: noqa
from .discriminators import (
    MultiPeriodDiscriminator,
    MultiScaleDiscriminator,
    MultiScaleSTFTDiscriminator
)
from .losses import (
    AdversarialLoss,
    AdvLossType,
    get_adv_criterion,
    get_fake_criterion,
    get_real_criterion,
    FeatLossType,
    FeatureMatchingLoss
)


================================================
FILE: audiocraft/adversarial/discriminators/__init__.py
================================================
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

# flake8: noqa
from .mpd import MultiPeriodDiscriminator
from .msd import MultiScaleDiscriminator
from .msstftd import MultiScaleSTFTDiscriminator


================================================
FILE: audiocraft/adversarial/discriminators/base.py
================================================
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

from abc import ABC, abstractmethod
import typing as tp

import torch
import torch.nn as nn


# Intermediate activations collected by a single sub-discriminator.
FeatureMapType = tp.List[torch.Tensor]
# Output tensor produced by a single sub-discriminator.
LogitsType = torch.Tensor
# (one logits tensor per sub-discriminator, one feature-map list per sub-discriminator)
MultiDiscriminatorOutputType = tp.Tuple[tp.List[LogitsType], tp.List[FeatureMapType]]


class MultiDiscriminator(ABC, nn.Module):
    """Abstract base for discriminators made of several sub-discriminators acting at different scales.

    Concrete subclasses must provide `forward` and the `num_discriminators` property;
    the class cannot be instantiated until both are implemented.
    """
    def __init__(self):
        super().__init__()

    @abstractmethod
    def forward(self, x: torch.Tensor) -> MultiDiscriminatorOutputType:
        """Return ``(logits, feature_maps)`` with one entry per sub-discriminator."""

    @property
    @abstractmethod
    def num_discriminators(self) -> int:
        """How many sub-discriminators this module holds."""


================================================
FILE: audiocraft/adversarial/discriminators/mpd.py
================================================
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import typing as tp

import torch
import torch.nn as nn
import torch.nn.functional as F

from ...modules import NormConv2d
from .base import MultiDiscriminator, MultiDiscriminatorOutputType


def get_padding(kernel_size: int, dilation: int = 1) -> int:
    """Return the one-sided padding ``dilation * (kernel_size - 1) / 2`` as an int.

    The halving uses exact integer arithmetic. The previous ``int(x / 2)`` went
    through float division and could return wrong values once the span exceeded
    2**53. The result is truncated toward zero, exactly like ``int(x / 2)``.

    Args:
        kernel_size (int): Convolution kernel size.
        dilation (int): Convolution dilation.
    Returns:
        int: Padding to apply on each side.
    """
    span = dilation * (kernel_size - 1)
    half = abs(span) // 2
    # Mirror int()'s truncation toward zero for (invalid) negative spans.
    return half if span >= 0 else -half


class PeriodDiscriminator(nn.Module):
    """Period sub-discriminator.

    The 1D waveform is right-padded (reflect mode) up to a multiple of `period`,
    folded into a 2D map of shape ``(B, C, T // period, period)``, and processed
    by a stack of ``(k, 1)`` convolutions that only mix along the folded time axis.

    Args:
        period (int): Period between samples of audio.
        in_channels (int): Number of input channels.
        out_channels (int): Number of output channels.
        n_layers (int): Number of convolutional layers before the output conv.
        kernel_sizes (list of int): ``[hidden kernel, output kernel]`` along the time axis.
        stride (int): Time-axis stride of every hidden conv except the last one (which uses 1).
        filters (int): Base number of filters; layer ``d`` (1-based) uses ``filters * filters_scale ** d``.
        filters_scale (int): Multiplier of number of filters as we increase depth.
        max_filters (int): Upper bound on the number of filters.
        norm (str): Normalization method.
        activation (str): Name of the `torch.nn` activation class.
        activation_params (dict): Keyword arguments for the activation class.
    """
    def __init__(self, period: int, in_channels: int = 1, out_channels: int = 1,
                 n_layers: int = 5, kernel_sizes: tp.List[int] = [5, 3], stride: int = 3,
                 filters: int = 8, filters_scale: int = 4, max_filters: int = 1024,
                 norm: str = 'weight_norm', activation: str = 'LeakyReLU',
                 activation_params: dict = {'negative_slope': 0.2}):
        super().__init__()
        self.period = period
        self.n_layers = n_layers
        self.activation = getattr(torch.nn, activation)(**activation_params)
        self.convs = nn.ModuleList()

        channels = in_channels
        for depth in range(1, n_layers + 1):
            hidden_kernel = kernel_sizes[0]
            width = min(filters * filters_scale ** depth, max_filters)
            # The deepest hidden layer keeps full time resolution.
            layer_stride = stride if depth < n_layers else 1
            self.convs.append(NormConv2d(channels, width, kernel_size=(hidden_kernel, 1), stride=(layer_stride, 1),
                                         padding=((hidden_kernel - 1) // 2, 0), norm=norm))
            channels = width

        post_kernel = kernel_sizes[1]
        self.conv_post = NormConv2d(channels, out_channels, kernel_size=(post_kernel, 1), stride=1,
                                    padding=((post_kernel - 1) // 2, 0), norm=norm)

    def forward(self, x: torch.Tensor):
        """Return ``(logits, feature_maps)``; the logits are also the last feature map."""
        batch, chans, length = x.shape
        remainder = length % self.period
        if remainder:
            # Reflect-pad on the right so the length divides evenly by the period.
            extra = self.period - remainder
            x = F.pad(x, (0, extra), 'reflect')
            length += extra
        x = x.view(batch, chans, length // self.period, self.period)

        feature_maps = []
        for conv in self.convs:
            x = self.activation(conv(x))
            feature_maps.append(x)
        x = self.conv_post(x)
        feature_maps.append(x)
        return x, feature_maps


class MultiPeriodDiscriminator(MultiDiscriminator):
    """Multi-Period (MPD) Discriminator.

    Holds one `PeriodDiscriminator` per value in `periods`, and every one of
    them receives the same input waveform.

    Args:
        in_channels (int): Number of input channels.
        out_channels (int): Number of output channels.
        periods (Sequence[int]): Period of each sub-discriminator (one sub-discriminator per entry).
        **kwargs: Additional args for `PeriodDiscriminator`
    """
    def __init__(self, in_channels: int = 1, out_channels: int = 1,
                 periods: tp.Sequence[int] = [2, 3, 5, 7, 11], **kwargs):
        super().__init__()
        subs = [PeriodDiscriminator(period, in_channels, out_channels, **kwargs) for period in periods]
        self.discriminators = nn.ModuleList(subs)

    @property
    def num_discriminators(self):
        """Number of period sub-discriminators."""
        return len(self.discriminators)

    def forward(self, x: torch.Tensor) -> MultiDiscriminatorOutputType:
        """Return ``(logits, fmaps)``, one entry per sub-discriminator in construction order."""
        outputs = [disc(x) for disc in self.discriminators]
        logits = [logit for logit, _ in outputs]
        fmaps = [fmap for _, fmap in outputs]
        return logits, fmaps


================================================
FILE: audiocraft/adversarial/discriminators/msd.py
================================================
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import typing as tp

import numpy as np
import torch
import torch.nn as nn

from ...modules import NormConv1d
from .base import MultiDiscriminator, MultiDiscriminatorOutputType


class ScaleDiscriminator(nn.Module):
    """Waveform sub-discriminator.

    Layer stack (all stored in ``self.convs`` except the final output conv):
      1. a stem: padding module + conv whose kernel is the product of `kernel_sizes`;
      2. one conv per entry of `downsample_scales`, growing the channel count by that
         scale (capped at `max_filters`);
      3. a conv with kernel ``kernel_sizes[0]`` doubling the channels (capped);
    followed by ``self.conv_post`` with kernel ``kernel_sizes[1]`` producing `out_channels`.

    Args:
        in_channels (int): Number of input channels.
        out_channels (int): Number of output channels.
        kernel_sizes (Sequence[int]): Two odd kernel sizes used by the stem (their product)
            and the last two convolutions.
        filters (int): Number of filters produced by the stem.
        max_filters (int): Maximum number of filters.
        downsample_scales (Sequence[int]): Per-layer scale for the strided inner convolutions.
        inner_kernel_sizes (Sequence[int] or None): Kernel sizes for inner convolutions
            (default ``scale * 10 + 1``).
        groups (Sequence[int] or None): Groups for inner convolutions (default ``in_channels // 4``).
        strides (Sequence[int] or None): Strides for inner convolutions (default ``scale``).
        paddings (Sequence[int] or None): Paddings for inner convolutions (default half the default kernel).
        norm (str): Normalization method.
        activation (str): Name of the `torch.nn` activation class.
        activation_params (dict): Keyword arguments for the activation class.
        pad (str): Name of the `torch.nn` padding class used by the stem.
        pad_params (dict): Keyword arguments for the padding class.
    """
    def __init__(self, in_channels=1, out_channels=1, kernel_sizes: tp.Sequence[int] = [5, 3],
                 filters: int = 16, max_filters: int = 1024, downsample_scales: tp.Sequence[int] = [4, 4, 4, 4],
                 inner_kernel_sizes: tp.Optional[tp.Sequence[int]] = None, groups: tp.Optional[tp.Sequence[int]] = None,
                 strides: tp.Optional[tp.Sequence[int]] = None, paddings: tp.Optional[tp.Sequence[int]] = None,
                 norm: str = 'weight_norm', activation: str = 'LeakyReLU',
                 activation_params: dict = {'negative_slope': 0.2}, pad: str = 'ReflectionPad1d',
                 pad_params: dict = {}):
        super().__init__()
        assert len(kernel_sizes) == 2
        assert kernel_sizes[0] % 2 == 1
        assert kernel_sizes[1] % 2 == 1
        # Every per-layer override, when given, must match the number of inner layers.
        for override in (inner_kernel_sizes, groups, strides, paddings):
            assert (override is None or len(override) == len(downsample_scales))
        self.activation = getattr(torch.nn, activation)(**activation_params)
        self.convs = nn.ModuleList()

        # Stem: explicit padding module followed by an unpadded conv.
        stem_kernel = np.prod(kernel_sizes)
        stem_pad = getattr(torch.nn, pad)((stem_kernel - 1) // 2, **pad_params)
        stem_conv = NormConv1d(in_channels, filters, kernel_size=stem_kernel, stride=1, norm=norm)
        self.convs.append(nn.Sequential(stem_pad, stem_conv))

        channels = filters
        for idx, scale in enumerate(downsample_scales):
            next_channels = min(channels * scale, max_filters)
            fallback_kernel = scale * 10 + 1
            layer_kernel = inner_kernel_sizes[idx] if inner_kernel_sizes else fallback_kernel
            layer_stride = strides[idx] if strides else scale
            layer_groups = groups[idx] if groups else channels // 4
            layer_padding = paddings[idx] if paddings else (fallback_kernel - 1) // 2
            self.convs.append(NormConv1d(channels, next_channels, kernel_size=layer_kernel, stride=layer_stride,
                                         groups=layer_groups, padding=layer_padding, norm=norm))
            channels = next_channels

        head_channels = min(channels * 2, max_filters)
        head_kernel = kernel_sizes[0]
        self.convs.append(NormConv1d(channels, head_channels, kernel_size=head_kernel, stride=1,
                                     padding=(head_kernel - 1) // 2, norm=norm))
        post_kernel = kernel_sizes[1]
        self.conv_post = NormConv1d(head_channels, out_channels, kernel_size=post_kernel, stride=1,
                                    padding=(post_kernel - 1) // 2, norm=norm)

    def forward(self, x: torch.Tensor):
        """Return ``(logits, feature_maps)``; the logits are also the last feature map."""
        feature_maps = []
        for conv in self.convs:
            x = self.activation(conv(x))
            feature_maps.append(x)
        logits = self.conv_post(x)
        feature_maps.append(logits)
        return logits, feature_maps


class MultiScaleDiscriminator(MultiDiscriminator):
    """Multi-Scale (MSD) Discriminator.

    The first sub-discriminator receives the input signal as is. Before each following
    sub-discriminator the signal is average-pooled once more by ``downsample_factor``, so
    sub-discriminator ``i`` operates on the input downsampled ``i`` times.

    Args:
        in_channels (int): Number of input channels.
        out_channels (int): Number of output channels.
        downsample_factor (int): Downsampling factor between the different scales.
        scale_norms (Sequence[str]): Normalization for each sub-discriminator. One
            sub-discriminator is created per entry.
        **kwargs: Additional args for ScaleDiscriminator.
    """
    def __init__(self, in_channels: int = 1, out_channels: int = 1, downsample_factor: int = 2,
                 scale_norms: tp.Sequence[str] = ['weight_norm', 'weight_norm', 'weight_norm'], **kwargs):
        super().__init__()
        self.discriminators = nn.ModuleList([
            ScaleDiscriminator(in_channels, out_channels, norm=norm, **kwargs) for norm in scale_norms
        ])
        self.downsample = nn.AvgPool1d(downsample_factor * 2, downsample_factor, padding=downsample_factor)

    @property
    def num_discriminators(self):
        return len(self.discriminators)

    def forward(self, x: torch.Tensor) -> MultiDiscriminatorOutputType:
        """Run every sub-discriminator on its own scale.

        Returns:
            tuple: List of logits and list of feature-map lists, one entry per sub-discriminator.
        """
        logits = []
        fmaps = []
        for i, disc in enumerate(self.discriminators):
            if i != 0:
                # The pooled signal must be kept; pooling is cumulative across scales.
                x = self.downsample(x)
            logit, fmap = disc(x)
            logits.append(logit)
            fmaps.append(fmap)
        return logits, fmaps


================================================
FILE: audiocraft/adversarial/discriminators/msstftd.py
================================================
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import typing as tp

import torchaudio
import torch
from torch import nn
from einops import rearrange

from ...modules import NormConv2d
from .base import MultiDiscriminator, MultiDiscriminatorOutputType


def get_2d_padding(kernel_size: tp.Tuple[int, int], dilation: tp.Tuple[int, int] = (1, 1)):
    """Padding that keeps the spatial size of a stride-1 2D convolution ('same'-style).

    For each of the two spatial dimensions the padding is ``((k - 1) * d) // 2``.
    """
    k0, k1 = kernel_size[0], kernel_size[1]
    d0, d1 = dilation[0], dilation[1]
    pad0 = (k0 - 1) * d0 // 2
    pad1 = (k1 - 1) * d1 // 2
    return (pad0, pad1)


class DiscriminatorSTFT(nn.Module):
    """STFT sub-discriminator.

    The input waveform is turned into a complex spectrogram whose real and imaginary
    parts are stacked along the channel axis, then processed by a stack of 2D convolutions.

    Args:
        filters (int): Number of filters in convolutions.
        in_channels (int): Number of input channels.
        out_channels (int): Number of output channels.
        n_fft (int): Size of FFT for each scale.
        hop_length (int): Length of hop between STFT windows for each scale.
        win_length (int): Window size for each scale.
        max_filters (int): Upper bound on the number of output channels of the inner convolutions.
        filters_scale (int): Multiplicative growth factor of the number of filters from one
            dilated convolution to the next.
        kernel_size (tuple of int): Inner Conv2d kernel sizes.
        dilations (list of int): Inner Conv2d dilation on the time dimension. One dilated
            convolution is created per entry.
        stride (tuple of int): Inner Conv2d strides.
        normalized (bool): Whether to normalize by magnitude after stft.
        norm (str): Normalization method.
        activation (str): Activation function (name of a ``torch.nn`` class).
        activation_params (dict): Parameters to provide to the activation function.
    """
    def __init__(self, filters: int, in_channels: int = 1, out_channels: int = 1,
                 n_fft: int = 1024, hop_length: int = 256, win_length: int = 1024, max_filters: int = 1024,
                 filters_scale: int = 1, kernel_size: tp.Tuple[int, int] = (3, 9), dilations: tp.List = [1, 2, 4],
                 stride: tp.Tuple[int, int] = (1, 2), normalized: bool = True, norm: str = 'weight_norm',
                 activation: str = 'LeakyReLU', activation_params: dict = {'negative_slope': 0.2}):
        super().__init__()
        assert len(kernel_size) == 2
        assert len(stride) == 2
        self.filters = filters
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.n_fft = n_fft
        self.hop_length = hop_length
        self.win_length = win_length
        self.normalized = normalized
        self.activation = getattr(torch.nn, activation)(**activation_params)
        self.spec_transform = torchaudio.transforms.Spectrogram(
            n_fft=self.n_fft, hop_length=self.hop_length, win_length=self.win_length, window_fn=torch.hann_window,
            normalized=self.normalized, center=False, pad_mode=None, power=None)
        # Real and imaginary parts of every input channel become separate channels.
        spec_channels = 2 * self.in_channels
        self.convs = nn.ModuleList()
        # NOTE(review): this first convolution does not receive `norm`, so NormConv2d's default
        # applies. Kept as is; presumably changing it would alter state-dict keys of existing checkpoints.
        self.convs.append(
            NormConv2d(spec_channels, self.filters, kernel_size=kernel_size, padding=get_2d_padding(kernel_size))
        )
        # The first convolution outputs exactly `self.filters` channels, so the dilated stack
        # must take that many input channels (any other value breaks when filters_scale != 1
        # or filters > max_filters).
        in_chs = self.filters
        for i, dilation in enumerate(dilations):
            out_chs = min((filters_scale ** (i + 1)) * self.filters, max_filters)
            self.convs.append(NormConv2d(in_chs, out_chs, kernel_size=kernel_size, stride=stride,
                                         dilation=(dilation, 1), padding=get_2d_padding(kernel_size, (dilation, 1)),
                                         norm=norm))
            in_chs = out_chs
        out_chs = min((filters_scale ** (len(dilations) + 1)) * self.filters, max_filters)
        self.convs.append(NormConv2d(in_chs, out_chs, kernel_size=(kernel_size[0], kernel_size[0]),
                                     padding=get_2d_padding((kernel_size[0], kernel_size[0])),
                                     norm=norm))
        self.conv_post = NormConv2d(out_chs, self.out_channels,
                                    kernel_size=(kernel_size[0], kernel_size[0]),
                                    padding=get_2d_padding((kernel_size[0], kernel_size[0])),
                                    norm=norm)

    def forward(self, x: torch.Tensor):
        """Compute logits and feature maps for a waveform batch.

        Returns:
            tuple: Logits from ``conv_post`` and the list of activated outputs of every layer in
            ``self.convs``. The ``conv_post`` output is not part of the feature maps.
        """
        fmap = []
        # Assuming x is [B, C, T] (TODO confirm): complex spectrogram of shape [B, C, Freq, Frames].
        z = self.spec_transform(x)
        z = torch.cat([z.real, z.imag], dim=1)  # [B, 2 * C, Freq, Frames]
        z = rearrange(z, 'b c w t -> b c t w')  # [B, 2 * C, Frames, Freq]
        for layer in self.convs:
            z = layer(z)
            z = self.activation(z)
            fmap.append(z)
        z = self.conv_post(z)
        return z, fmap


class MultiScaleSTFTDiscriminator(MultiDiscriminator):
    """Multi-Scale STFT (MS-STFT) discriminator.

    Args:
        filters (int): Number of filters in convolutions.
        in_channels (int): Number of input channels.
        out_channels (int): Number of output channels.
        sep_channels (bool): Separate channels to distinct samples for stereo support. When True,
            the input ``[B, C, T]`` is folded into ``[B * C, 1, T]`` before being discriminated, and
            the sub-discriminators are built for single-channel input. The returned logits and
            feature maps then have batch size ``B * C``.
        n_ffts (Sequence[int]): Size of FFT for each scale.
        hop_lengths (Sequence[int]): Length of hop between STFT windows for each scale.
        win_lengths (Sequence[int]): Window size for each scale.
        **kwargs: Additional args for STFTDiscriminator.
    """
    def __init__(self, filters: int, in_channels: int = 1, out_channels: int = 1, sep_channels: bool = False,
                 n_ffts: tp.List[int] = [1024, 2048, 512], hop_lengths: tp.List[int] = [256, 512, 128],
                 win_lengths: tp.List[int] = [1024, 2048, 512], **kwargs):
        super().__init__()
        assert len(n_ffts) == len(hop_lengths) == len(win_lengths)
        self.sep_channels = sep_channels
        # With separated channels every sub-discriminator only ever sees mono signals.
        disc_in_channels = 1 if sep_channels else in_channels
        self.discriminators = nn.ModuleList([
            DiscriminatorSTFT(filters, in_channels=disc_in_channels, out_channels=out_channels,
                              n_fft=n_ffts[i], win_length=win_lengths[i], hop_length=hop_lengths[i], **kwargs)
            for i in range(len(n_ffts))
        ])

    @property
    def num_discriminators(self):
        return len(self.discriminators)

    def _separate_channels(self, x: torch.Tensor) -> torch.Tensor:
        """Fold channels into the batch dimension: ``[B, C, T] -> [B * C, 1, T]``."""
        _, _, T = x.shape
        # `reshape` (unlike `view`) also accepts non-contiguous inputs.
        return x.reshape(-1, 1, T)

    def forward(self, x: torch.Tensor) -> MultiDiscriminatorOutputType:
        """Run every STFT sub-discriminator on the input.

        Returns:
            tuple: List of logits and list of feature-map lists, one entry per sub-discriminator.
        """
        if self.sep_channels:
            x = self._separate_channels(x)
        logits = []
        fmaps = []
        for disc in self.discriminators:
            logit, fmap = disc(x)
            logits.append(logit)
            fmaps.append(fmap)
        return logits, fmaps


================================================
FILE: audiocraft/adversarial/losses.py
================================================
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

"""
Utility module to handle adversarial losses without requiring to mess up the main training loop.
"""

import typing as tp

import flashy
import torch
import torch.nn as nn
import torch.nn.functional as F


# Names of the supported adversarial loss types, validated by the `get_*_criterion` helpers.
ADVERSARIAL_LOSSES = ['mse', 'hinge', 'hinge2']


# Loss applied to a tensor of discriminator logits, returning a loss tensor
# (used for `AdversarialLoss.loss`, `loss_real` and `loss_fake`).
AdvLossType = tp.Union[nn.Module, tp.Callable[[torch.Tensor], torch.Tensor]]
# Feature matching loss, called as `loss_feat(fmap_fake, fmap_real)` by `AdversarialLoss.forward`.
FeatLossType = tp.Union[nn.Module, tp.Callable[[torch.Tensor, torch.Tensor], torch.Tensor]]


class AdversarialLoss(nn.Module):
    """Adversary training wrapper.

    Args:
        adversary (nn.Module): The adversary module will be used to estimate the logits given the fake and real samples.
            We assume here the adversary output is ``Tuple[List[torch.Tensor], List[List[torch.Tensor]]]``
            where the first item is a list of logits and the second item is a list of feature maps.
        optimizer (torch.optim.Optimizer): Optimizer used for training the given module.
        loss (AdvLossType): Loss function for generator training.
        loss_real (AdvLossType): Loss function for adversarial training on logits from real samples.
        loss_fake (AdvLossType): Loss function for adversarial training on logits from fake samples.
        loss_feat (FeatLossType, optional): Feature matching loss function for generator training.
            If None, no feature matching loss is computed and the returned feature loss is zero.
        normalize (bool): Whether to normalize by number of sub-discriminators.

    Example of usage:
        adv_loss = AdversarialLoss(adversaries, optimizer, loss, loss_real, loss_fake)
        for real in loader:
            noise = torch.randn(...)
            fake = model(noise)
            adv_loss.train_adv(fake, real)
            loss, _ = adv_loss(fake, real)
            loss.backward()
    """
    def __init__(self,
                 adversary: nn.Module,
                 optimizer: torch.optim.Optimizer,
                 loss: AdvLossType,
                 loss_real: AdvLossType,
                 loss_fake: AdvLossType,
                 loss_feat: tp.Optional[FeatLossType] = None,
                 normalize: bool = True):
        super().__init__()
        self.adversary: nn.Module = adversary
        # Make sure every distributed worker starts from the same adversary weights.
        flashy.distrib.broadcast_model(self.adversary)
        self.optimizer = optimizer
        self.loss = loss
        self.loss_real = loss_real
        self.loss_fake = loss_fake
        self.loss_feat = loss_feat
        self.normalize = normalize

    def _save_to_state_dict(self, destination, prefix, keep_vars):
        # Add the optimizer state dict inside our own, so that saving this module
        # also saves the adversary optimizer state.
        super()._save_to_state_dict(destination, prefix, keep_vars)
        destination[prefix + 'optimizer'] = self.optimizer.state_dict()
        return destination

    def _load_from_state_dict(self, state_dict, prefix, *args, **kwargs):
        # Load optimizer state. The key is popped (not just read), presumably so that the parent
        # implementation does not report it as an unexpected key.
        self.optimizer.load_state_dict(state_dict.pop(prefix + 'optimizer'))
        super()._load_from_state_dict(state_dict, prefix, *args, **kwargs)

    def get_adversary_pred(self, x):
        """Run adversary model, validating expected output format."""
        logits, fmaps = self.adversary(x)
        assert isinstance(logits, list) and all(isinstance(t, torch.Tensor) for t in logits), \
            f'Expecting a list of tensors as logits but {type(logits)} found.'
        assert isinstance(fmaps, list), f'Expecting a list of features maps but {type(fmaps)} found.'
        for fmap in fmaps:
            assert isinstance(fmap, list) and all(isinstance(f, torch.Tensor) for f in fmap), \
                f'Expecting a list of tensors as feature maps but {type(fmap)} found.'
        return logits, fmaps

    def train_adv(self, fake: torch.Tensor, real: torch.Tensor) -> torch.Tensor:
        """Train the adversary with the given fake and real example.

        We assume the adversary output is the following format: Tuple[List[torch.Tensor], List[List[torch.Tensor]]].
        The first item being the logits and second item being a list of feature maps for each sub-discriminator.

        This will automatically synchronize gradients (with `flashy.distrib.eager_sync_model`)
        and call the optimizer.

        Returns:
            torch.Tensor: The adversary loss that was optimized.
        """
        loss = torch.tensor(0., device=fake.device)
        # Inputs are detached: only the adversary is trained here, not the generator.
        all_logits_fake_is_fake, _ = self.get_adversary_pred(fake.detach())
        all_logits_real_is_fake, _ = self.get_adversary_pred(real.detach())
        n_sub_adversaries = len(all_logits_fake_is_fake)
        for logit_fake_is_fake, logit_real_is_fake in zip(all_logits_fake_is_fake, all_logits_real_is_fake):
            loss += self.loss_fake(logit_fake_is_fake) + self.loss_real(logit_real_is_fake)

        if self.normalize:
            loss /= n_sub_adversaries

        self.optimizer.zero_grad()
        with flashy.distrib.eager_sync_model(self.adversary):
            loss.backward()
        self.optimizer.step()

        return loss

    def forward(self, fake: torch.Tensor, real: torch.Tensor) -> tp.Tuple[torch.Tensor, torch.Tensor]:
        """Return the loss for the generator, i.e. trying to fool the adversary,
        and feature matching loss if provided.

        Returns:
            tuple of torch.Tensor: Adversarial loss and feature matching loss (zero when
            ``loss_feat`` is None).
        """
        adv = torch.tensor(0., device=fake.device)
        feat = torch.tensor(0., device=fake.device)
        with flashy.utils.readonly(self.adversary):
            all_logits_fake_is_fake, all_fmap_fake = self.get_adversary_pred(fake)
            all_logits_real_is_fake, all_fmap_real = self.get_adversary_pred(real)
            n_sub_adversaries = len(all_logits_fake_is_fake)
            for logit_fake_is_fake in all_logits_fake_is_fake:
                adv += self.loss(logit_fake_is_fake)
            # Explicit None check: module truthiness may depend on `__len__` (e.g. containers).
            if self.loss_feat is not None:
                for fmap_fake, fmap_real in zip(all_fmap_fake, all_fmap_real):
                    feat += self.loss_feat(fmap_fake, fmap_real)

        if self.normalize:
            adv /= n_sub_adversaries
            feat /= n_sub_adversaries

        return adv, feat


def get_adv_criterion(loss_type: str) -> tp.Callable:
    """Return the generator-side adversarial loss for the given loss type name.

    Raises:
        AssertionError: If ``loss_type`` is not one of ``ADVERSARIAL_LOSSES``.
        ValueError: If ``loss_type`` has no associated criterion.
    """
    assert loss_type in ADVERSARIAL_LOSSES
    candidates = (('mse', mse_loss), ('hinge', hinge_loss), ('hinge2', hinge2_loss))
    for name, criterion in candidates:
        if loss_type == name:
            return criterion
    raise ValueError('Unsupported loss')


def get_fake_criterion(loss_type: str) -> tp.Callable:
    """Return the adversary loss applied to logits computed on fake samples.

    'mse' maps to `mse_fake_loss`; both hinge variants share `hinge_fake_loss`.
    """
    assert loss_type in ADVERSARIAL_LOSSES
    is_hinge = loss_type in ['hinge', 'hinge2']
    if loss_type != 'mse' and not is_hinge:
        raise ValueError('Unsupported loss')
    return mse_fake_loss if loss_type == 'mse' else hinge_fake_loss


def get_real_criterion(loss_type: str) -> tp.Callable:
    """Return the adversary loss applied to logits computed on real samples.

    'mse' maps to `mse_real_loss`; both hinge variants share `hinge_real_loss`.
    """
    assert loss_type in ADVERSARIAL_LOSSES
    is_hinge = loss_type in ['hinge', 'hinge2']
    if loss_type != 'mse' and not is_hinge:
        raise ValueError('Unsupported loss')
    return mse_real_loss if loss_type == 'mse' else hinge_real_loss


def mse_real_loss(x: torch.Tensor) -> torch.Tensor:
    """Mean squared error between the logits and the target 1 ("real")."""
    target = torch.tensor(1., device=x.device).expand_as(x)
    return F.mse_loss(x, target)


def mse_fake_loss(x: torch.Tensor) -> torch.Tensor:
    """Mean squared error between the logits and the target 0 ("fake")."""
    target = torch.tensor(0., device=x.device).expand_as(x)
    return F.mse_loss(x, target)


def hinge_real_loss(x: torch.Tensor) -> torch.Tensor:
    """Hinge loss on real logits: mean of ``max(0, 1 - x)``, written as ``-mean(min(x - 1, 0))``."""
    zero = torch.tensor(0., device=x.device).expand_as(x)
    clipped = torch.min(x - 1, zero)
    return clipped.mean().neg()


def hinge_fake_loss(x: torch.Tensor) -> torch.Tensor:
    """Hinge loss on fake logits: mean of ``max(0, 1 + x)``, written as ``-mean(min(-x - 1, 0))``."""
    zero = torch.tensor(0., device=x.device).expand_as(x)
    clipped = torch.min(-x - 1, zero)
    return clipped.mean().neg()


def mse_loss(x: torch.Tensor) -> torch.Tensor:
    """Generator-side MSE adversarial loss: pushes the logits towards 1.

    Args:
        x (torch.Tensor): Discriminator logits computed on generated samples.
    Returns:
        torch.Tensor: Scalar (0-dim) loss. An empty ``x`` yields a 0-dim zero on ``x.device``,
        the same shape as the non-empty case, so the result can be accumulated in place
        into a scalar tensor.
    """
    if x.numel() == 0:
        return torch.tensor(0.0, device=x.device)
    return F.mse_loss(x, torch.tensor(1., device=x.device).expand_as(x))


def hinge_loss(x: torch.Tensor) -> torch.Tensor:
    """Generator-side hinge adversarial loss: ``-mean(x)``.

    Args:
        x (torch.Tensor): Discriminator logits computed on generated samples.
    Returns:
        torch.Tensor: Scalar (0-dim) loss. An empty ``x`` yields a 0-dim zero on ``x.device``,
        the same shape as the non-empty case, so the result can be accumulated in place
        into a scalar tensor.
    """
    if x.numel() == 0:
        return torch.tensor(0.0, device=x.device)
    return -x.mean()


def hinge2_loss(x: torch.Tensor) -> torch.Tensor:
    """Generator-side clipped hinge loss: ``-mean(min(x - 1, 0))``.

    Args:
        x (torch.Tensor): Discriminator logits computed on generated samples.
    Returns:
        torch.Tensor: Scalar (0-dim) loss. An empty ``x`` yields a 0-dim zero on ``x.device``
        (same device and shape as the non-empty case), so the result can be accumulated
        in place into a scalar tensor.
    """
    if x.numel() == 0:
        return torch.tensor(0.0, device=x.device)
    return -torch.mean(torch.min(x - 1, torch.tensor(0., device=x.device).expand_as(x)))


class FeatureMatchingLoss(nn.Module):
    """Feature matching loss for adversarial training.

    Sums ``loss(feat_fake, feat_real)`` over pairs of corresponding feature maps.

    Args:
        loss (nn.Module): Loss to use for feature matching (default=torch.nn.L1Loss).
        normalize (bool): Whether to normalize the loss by the number of feature maps.
    """
    def __init__(self, loss: nn.Module = torch.nn.L1Loss(), normalize: bool = True):
        super().__init__()
        self.loss = loss
        self.normalize = normalize

    def forward(self, fmap_fake: tp.List[torch.Tensor], fmap_real: tp.List[torch.Tensor]) -> torch.Tensor:
        """Compute the feature matching loss.

        Args:
            fmap_fake (list of torch.Tensor): Feature maps computed on generated samples.
            fmap_real (list of torch.Tensor): Feature maps computed on real samples, with the
                same length and per-item shapes as ``fmap_fake``.
        Returns:
            torch.Tensor: The (optionally averaged) feature matching loss.
        """
        assert len(fmap_fake) == len(fmap_real) and len(fmap_fake) > 0
        feat_loss = torch.tensor(0., device=fmap_fake[0].device)
        n_fmaps = 0
        for (feat_fake, feat_real) in zip(fmap_fake, fmap_real):
            assert feat_fake.shape == feat_real.shape
            n_fmaps += 1
            feat_loss += self.loss(feat_fake, feat_real)

        if self.normalize:
            feat_loss /= n_fmaps

        return feat_loss


================================================
FILE: audiocraft/data/__init__.py
================================================
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
"""Audio loading and writing support. Datasets for raw audio
or also including some metadata."""

# flake8: noqa
from . import audio, audio_dataset, info_audio_dataset, music_dataset, sound_dataset


================================================
FILE: audiocraft/data/audio.py
================================================
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

"""
Audio IO methods are defined in this module (info, read, write),
We rely on av library for faster read when possible, otherwise on torchaudio.
"""

from dataclasses import dataclass
from pathlib import Path
import logging
import typing as tp

import numpy as np
import soundfile
import torch
from torch.nn import functional as F
import torchaudio as ta

import av

from .audio_utils import f32_pcm, i16_pcm, normalize_audio


# Set to True by `_init_av` after its first call, so the logger setup is only done once.
_av_initialized = False


def _init_av():
    """Silence the 'libav.mp3' logger below ERROR level; only acts on the first call."""
    global _av_initialized
    if not _av_initialized:
        logging.getLogger('libav.mp3').setLevel(logging.ERROR)
        _av_initialized = True


@dataclass(frozen=True)
class AudioFileInfo:
    """Immutable basic metadata of an audio file, as returned by `audio_info`."""
    sample_rate: int  # sample rate reported by the reading backend
    duration: float  # duration in seconds
    channels: int  # number of audio channels


def _av_info(filepath: tp.Union[str, Path]) -> AudioFileInfo:
    """Read metadata of the first audio stream of a file with PyAV.

    The duration comes from the stream when it reports one. Otherwise the container-level
    duration (expressed in ``av.time_base`` units) is used.

    Raises:
        ValueError: If neither the stream nor the container report a duration.
    """
    _init_av()
    with av.open(str(filepath)) as af:
        stream = af.streams.audio[0]
        sample_rate = stream.codec_context.sample_rate
        if stream.duration is not None:
            duration = float(stream.duration * stream.time_base)
        elif af.duration is not None:
            # Some streams do not report their own duration; fall back to the container's.
            duration = float(af.duration) / av.time_base
        else:
            raise ValueError(f"Could not determine the duration of audio file {filepath}")
        channels = stream.channels
        return AudioFileInfo(sample_rate, duration, channels)


def _soundfile_info(filepath: tp.Union[str, Path]) -> AudioFileInfo:
    """Read audio metadata through soundfile."""
    sf_info = soundfile.info(filepath)
    return AudioFileInfo(
        sample_rate=sf_info.samplerate,
        duration=sf_info.duration,
        channels=sf_info.channels,
    )


def audio_info(filepath: tp.Union[str, Path]) -> AudioFileInfo:
    """Return sample rate, duration and number of channels of an audio file.

    FLAC and OGG files are inspected with soundfile, every other format with PyAV.
    The extension is compared case-insensitively, consistent with `find_audio_files`,
    which matches lower-cased suffixes.
    """
    # torchaudio no longer returns useful duration informations for some formats like mp3s.
    filepath = Path(filepath)
    if filepath.suffix.lower() in ['.flac', '.ogg']:  # TODO: Validate .ogg can be safely read with av_info
        # ffmpeg has some weird issue with flac.
        return _soundfile_info(filepath)
    else:
        return _av_info(filepath)


def _av_read(filepath: tp.Union[str, Path], seek_time: float = 0, duration: float = -1.) -> tp.Tuple[torch.Tensor, int]:
    """FFMPEG-based audio file reading using PyAV bindings.
    Soundfile cannot read mp3 and av_read is more efficient than torchaudio.

    Args:
        filepath (str or Path): Path to audio file to read.
        seek_time (float): Time at which to start reading in the file.
        duration (float): Duration to read from the file. If set to -1, the whole file is read.
    Returns:
        tuple of torch.Tensor, int: Tuple containing audio data (channels first, converted
        with `f32_pcm`) and sample rate.
    """
    _init_av()
    with av.open(str(filepath)) as af:
        stream = af.streams.audio[0]
        sr = stream.codec_context.sample_rate
        num_frames = int(sr * duration) if duration >= 0 else -1
        frame_offset = int(sr * seek_time)
        # we need a small negative offset otherwise we get some edge artifact
        # from the mp3 decoder. The target is converted to `stream.time_base` units first and
        # only then truncated to an int, so the seek position keeps sub-second precision.
        af.seek(int(max(0, (seek_time - 0.1)) / stream.time_base), stream=stream)
        frames = []
        length = 0
        for frame in af.decode(streams=stream.index):
            # Sample offset of this frame from the start of the file.
            current_offset = int(frame.rate * frame.pts * frame.time_base)
            # Drop samples decoded before the requested start (we seeked slightly earlier).
            strip = max(0, frame_offset - current_offset)
            buf = torch.from_numpy(frame.to_ndarray())
            if buf.shape[0] != stream.channels:
                buf = buf.view(-1, stream.channels).t()
            buf = buf[:, strip:]
            frames.append(buf)
            length += buf.shape[1]
            if num_frames > 0 and length >= num_frames:
                break
        assert frames
        # If the above assert fails, it is likely because we seeked past the end of file point,
        # in which case ffmpeg returns a single frame with only zeros, and a weird timestamp.
        # This will need proper debugging, in due time.
        wav = torch.cat(frames, dim=1)
        assert wav.shape[0] == stream.channels
        if num_frames > 0:
            wav = wav[:, :num_frames]
        return f32_pcm(wav), sr


def audio_read(filepath: tp.Union[str, Path], seek_time: float = 0.,
               duration: float = -1., pad: bool = False) -> tp.Tuple[torch.Tensor, int]:
    """Read audio by picking the most appropriate backend tool based on the audio format.

    FLAC/OGG files are read with soundfile. Whole WAV/MP3 files are read with torchaudio
    when its sox backend lists the format. Everything else, including partial WAV/MP3
    reads, goes through PyAV. Extensions are matched case-insensitively.

    Args:
        filepath (str or Path): Path to audio file to read.
        seek_time (float): Time at which to start reading in the file.
        duration (float): Duration to read from the file. If set to -1, the whole file is read.
        pad (bool): Pad output audio if not reaching expected duration.
    Returns:
        tuple of torch.Tensor, int: Tuple containing audio data and sample rate.
    """
    fp = Path(filepath)
    suffix = fp.suffix.lower()
    if suffix in ['.flac', '.ogg']:  # TODO: check if we can safely use av_read for .ogg
        # There is some bug with ffmpeg and reading flac
        info = _soundfile_info(filepath)
        frames = -1 if duration <= 0 else int(duration * info.sample_rate)
        frame_offset = int(seek_time * info.sample_rate)
        wav, sr = soundfile.read(filepath, start=frame_offset, frames=frames, dtype=np.float32)
        assert info.sample_rate == sr, f"Mismatch of sample rates {info.sample_rate} {sr}"
        # soundfile returns [frames] or [frames, channels]; put channels first.
        wav = torch.from_numpy(wav).t().contiguous()
        if len(wav.shape) == 1:
            wav = torch.unsqueeze(wav, 0)
    elif (
        suffix in ['.wav', '.mp3'] and suffix[1:] in ta.utils.sox_utils.list_read_formats()
        and duration <= 0 and seek_time == 0
    ):
        # Torchaudio is faster if we load an entire file at once.
        wav, sr = ta.load(fp)
    else:
        wav, sr = _av_read(filepath, seek_time, duration)
    if pad and duration > 0:
        # Pad with zeros (or trim, since F.pad accepts negative padding) to the expected length.
        expected_frames = int(duration * sr)
        wav = F.pad(wav, (0, expected_frames - wav.shape[-1]))
    return wav, sr


def audio_write(stem_name: tp.Union[str, Path],
                wav: torch.Tensor, sample_rate: int,
                format: str = 'wav', mp3_rate: int = 320, normalize: bool = True,
                strategy: str = 'peak', peak_clip_headroom_db: float = 1,
                rms_headroom_db: float = 18, loudness_headroom_db: float = 14,
                loudness_compressor: bool = False,
                log_clipping: bool = True, make_parent_dir: bool = True,
                add_suffix: bool = True) -> Path:
    """Convenience function for saving audio to disk. Returns the filename the audio was written to.

    Args:
        stem_name (str or Path): Filename without extension; the extension matching ``format``
            is appended unless ``add_suffix`` is False.
        wav (torch.Tensor): Floating point audio, either 1D or 2D channels-first ``[C, T]``
            (as passed to ``torchaudio.save``). A 1D input is treated as a single channel.
        sample_rate (int): Sample rate of ``wav``.
        format (str): Either "wav" (written as 16-bit PCM) or "mp3".
        mp3_rate (int): kbps when using mp3s.
        normalize (bool): if `True` (default), normalizes according to the prescribed
            strategy (see after). If `False`, the strategy is only used in case clipping
            would happen.
        strategy (str): Can be either 'clip', 'peak', 'rms' or 'loudness'. Default is 'peak',
            i.e. audio is normalized by its largest value. RMS normalizes by root-mean-square
            with extra headroom to avoid clipping. 'clip' just clips.
        peak_clip_headroom_db (float): Headroom in dB when doing 'peak' or 'clip' strategy.
        rms_headroom_db (float): Headroom in dB when doing 'rms' strategy. This must be much larger
            than the `peak_clip` one to avoid further clipping.
        loudness_headroom_db (float): Target loudness for loudness normalization.
        loudness_compressor (bool): Uses tanh for soft clipping when strategy is 'loudness'.
        log_clipping (bool): If True, basic logging on stderr when clipping still
            occurs despite strategy (only for 'rms').
        make_parent_dir (bool): Make parent directory if it doesn't exist.
        add_suffix (bool): Whether to append '.wav' / '.mp3' to ``stem_name``.
    Returns:
        Path: Path of the saved audio.
    Raises:
        ValueError: If ``wav`` has more than 2 dimensions.
        RuntimeError: If ``format`` is neither "wav" nor "mp3".
    """
    assert wav.dtype.is_floating_point, "wav is not floating point"
    if wav.dim() == 1:
        wav = wav[None]
    elif wav.dim() > 2:
        raise ValueError("Input wav should be at most 2 dimension.")
    assert wav.isfinite().all()
    wav = normalize_audio(wav, normalize, strategy, peak_clip_headroom_db,
                          rms_headroom_db, loudness_headroom_db, loudness_compressor,
                          log_clipping=log_clipping, sample_rate=sample_rate,
                          stem_name=str(stem_name))
    kwargs: dict = {}
    if format == 'mp3':
        suffix = '.mp3'
        kwargs.update({"compression": mp3_rate})
    elif format == 'wav':
        wav = i16_pcm(wav)
        suffix = '.wav'
        kwargs.update({"encoding": "PCM_S", "bits_per_sample": 16})
    else:
        raise RuntimeError(f"Invalid format {format}. Only wav or mp3 are supported.")
    if not add_suffix:
        suffix = ''
    path = Path(str(stem_name) + suffix)
    if make_parent_dir:
        path.parent.mkdir(exist_ok=True, parents=True)
    try:
        ta.save(path, wav, sample_rate, **kwargs)
    except Exception:
        if path.exists():
            # we do not want to leave half written files around.
            path.unlink()
        raise
    return path


================================================
FILE: audiocraft/data/audio_dataset.py
================================================
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
"""AudioDataset support. In order to handle a larger number of files
without having to scan again the folders, we precompute some metadata
(filename, sample rate, duration), and use that to efficiently sample audio segments.
"""
import argparse
import copy
from concurrent.futures import ThreadPoolExecutor, Future
from dataclasses import dataclass, fields
from contextlib import ExitStack
from functools import lru_cache
import gzip
import json
import logging
import os
from pathlib import Path
import random
import sys
import typing as tp

import torch
import torch.nn.functional as F

from .audio import audio_read, audio_info
from .audio_utils import convert_audio
from .zip import PathInZip

try:
    import dora
except ImportError:
    dora = None  # type: ignore


@dataclass(order=True)
class BaseInfo:
    """Base dataclass with conversion helpers from and to plain dictionaries."""

    @classmethod
    def _dict2fields(cls, dictionary: dict):
        # Keep only the entries matching a declared dataclass field, in field declaration order.
        selected = {}
        for f in fields(cls):
            if f.name in dictionary:
                selected[f.name] = dictionary[f.name]
        return selected

    @classmethod
    def from_dict(cls, dictionary: dict):
        # Unknown keys are ignored; missing keys fall back to the field defaults.
        return cls(**cls._dict2fields(dictionary))

    def to_dict(self):
        return dict((f.name, self.__getattribute__(f.name)) for f in fields(self))


@dataclass(order=True)
class AudioMeta(BaseInfo):
    path: str
    duration: float
    sample_rate: int
    amplitude: tp.Optional[float] = None
    weight: tp.Optional[float] = None
    # info_path is used to load additional information about the audio file that is stored in zip files.
    info_path: tp.Optional[PathInZip] = None

    @classmethod
    def from_dict(cls, dictionary: dict):
        """Build an AudioMeta from a dict, turning a non-None ``info_path`` into a PathInZip."""
        kwargs = cls._dict2fields(dictionary)
        raw_info_path = kwargs.get('info_path')
        if raw_info_path is not None:
            kwargs['info_path'] = PathInZip(raw_info_path)
        return cls(**kwargs)

    def to_dict(self):
        """Serialize to a dict, turning a non-None ``info_path`` back into a string."""
        serialized = super().to_dict()
        info_path = serialized['info_path']
        if info_path is not None:
            serialized['info_path'] = str(info_path)
        return serialized


@dataclass(order=True)
class SegmentInfo(BaseInfo):
    """Description of an audio segment taken from the file described by ``meta``."""
    meta: AudioMeta
    seek_time: float  # start offset of the segment in the file (presumably seconds, as in `audio_read`) — TODO confirm
    # The following values are given once the audio is processed, e.g.
    # at the target sample rate and target number of channels.
    n_frames: int      # actual number of frames without padding
    total_frames: int  # total number of frames, padding included
    sample_rate: int   # actual sample rate
    channels: int      # number of audio channels.


# Audio file extensions collected by default by `find_audio_files`, which compares them
# against lower-cased file suffixes (so entries must be lower case).
DEFAULT_EXTS = ['.wav', '.mp3', '.flac', '.ogg', '.m4a']

logger = logging.getLogger(__name__)


def _get_audio_meta(file_path: str, minimal: bool = True) -> AudioMeta:
    """Build the AudioMeta of a single audio file.

    Args:
        file_path (str): Resolved path of valid audio file.
        minimal (bool): If True, only header information is read. If False, the whole file
            is also decoded to compute its peak absolute amplitude (slower).
    Returns:
        AudioMeta: Audio file path and its metadata.
    """
    info = audio_info(file_path)
    amplitude: tp.Optional[float]
    if minimal:
        amplitude = None
    else:
        samples, _ = audio_read(file_path)
        amplitude = samples.abs().max().item()
    return AudioMeta(file_path, info.duration, info.sample_rate, amplitude)


def _resolve_audio_meta(m: AudioMeta, fast: bool = True) -> AudioMeta:
    """If Dora is available as a dependency, try to resolve potential relative paths
    in list of AudioMeta. This method is expected to be used when loading meta from file.

    The given AudioMeta is modified in place and returned.

    Args:
        m (AudioMeta): Audio meta to resolve.
        fast (bool): If True, uses a really fast check for determining if a file
            is already absolute or not. Only valid on Linux/Mac.
    Returns:
        AudioMeta: Audio meta with resolved path.
    """
    def is_abs(path) -> bool:
        if fast:
            # POSIX-only check; `startswith` is also safe on empty strings.
            return str(path).startswith('/')
        return os.path.isabs(str(path))

    if not dora:
        return m

    if not is_abs(m.path):
        m.path = dora.git_save.to_absolute_path(m.path)
    if m.info_path is not None and not is_abs(m.info_path.zip_path):
        # Resolve the zip archive path itself, not the audio path.
        m.info_path.zip_path = dora.git_save.to_absolute_path(m.info_path.zip_path)
    return m


def find_audio_files(path: tp.Union[Path, str],
                     exts: tp.List[str] = DEFAULT_EXTS,
                     resolve: bool = True,
                     minimal: bool = True,
                     progress: bool = False,
                     workers: int = 0) -> tp.List[AudioMeta]:
    """Build a list of AudioMeta from a given path,
    collecting relevant audio files and fetching meta info.

    The folder is walked recursively (following symlinks). Files whose metadata
    cannot be retrieved (or resolved) are reported on stderr and skipped.

    Args:
        path (str or Path): Path to folder containing audio files.
        exts (list of str): List of file extensions to consider for audio files.
            Matching is case-insensitive.
        resolve (bool): Whether to resolve potentially relative paths of each
            AudioMeta with `_resolve_audio_meta`.
        minimal (bool): Whether to only load the minimal set of metadata (takes longer if not).
        progress (bool): Whether to log progress on audio files collection.
        workers (int): number of parallel workers, if 0, use only the current thread.
    Returns:
        list of AudioMeta: Sorted list of audio file path and its metadata.
    """
    # File suffixes are lower-cased before matching, so the extensions must be too.
    wanted_exts = {ext.lower() for ext in exts}
    audio_files: tp.List[Path] = []
    futures: tp.List[Future] = []
    pool: tp.Optional[ThreadPoolExecutor] = None
    with ExitStack() as stack:
        if workers > 0:
            pool = ThreadPoolExecutor(workers)
            # Registering the pool shuts it down when leaving the `with` block.
            stack.enter_context(pool)

        if progress:
            print("Finding audio files...")
        for root, _folders, files in os.walk(path, followlinks=True):
            for file in files:
                full_path = Path(root) / file
                if full_path.suffix.lower() in wanted_exts:
                    audio_files.append(full_path)
                    if pool is not None:
                        # Metadata is fetched in the background while the walk continues;
                        # futures[i] corresponds to audio_files[i].
                        futures.append(pool.submit(_get_audio_meta, str(full_path), minimal))
                    if progress:
                        print(format(len(audio_files), " 8d"), end='\r', file=sys.stderr)

        if progress:
            print("Getting audio metadata...")
        meta: tp.List[AudioMeta] = []
        for idx, file_path in enumerate(audio_files):
            try:
                if pool is None:
                    m = _get_audio_meta(str(file_path), minimal)
                else:
                    m = futures[idx].result()
                if resolve:
                    m = _resolve_audio_meta(m)
            except Exception as err:
                # Best effort: a single unreadable file should not abort the scan.
                print("Error with", str(file_path), err, file=sys.stderr)
                continue
            meta.append(m)
            if progress:
                print(format((1 + idx) / len(audio_files), " 3.1%"), end='\r', file=sys.stderr)
    meta.sort()
    return meta


def load_audio_meta(path: tp.Union[str, Path],
                    resolve: bool = True, fast: bool = True) -> tp.List[AudioMeta]:
    """Load list of AudioMeta from an optionally compressed json file.

    The file holds one JSON object per line and is opened with gzip when its
    name ends with `.gz`. Blank lines are ignored.

    Args:
        path (str or Path): Path to JSON file.
        resolve (bool): Whether to resolve the path from AudioMeta (default=True).
        fast (bool): activates some tricks to make things faster (forwarded to
            `_resolve_audio_meta`).
    Returns:
        list of AudioMeta: List of audio file path and its total duration.
    """
    open_fn = gzip.open if str(path).lower().endswith('.gz') else open
    with open_fn(path, 'rb') as fp:  # type: ignore
        lines = fp.readlines()
    meta = []
    for line in lines:
        # Skip empty / whitespace-only lines (e.g. a trailing blank line added by hand),
        # which json.loads would otherwise reject.
        if not line.strip():
            continue
        d = json.loads(line)
        m = AudioMeta.from_dict(d)
        if resolve:
            m = _resolve_audio_meta(m, fast=fast)
        meta.append(m)
    return meta


def save_audio_meta(path: tp.Union[str, Path], meta: tp.List[AudioMeta]):
    """Save the audio metadata as JSON Lines, one object per line.

    Parent directories are created if needed, and the output is gzip-compressed
    when the file name ends with `.gz`.

    Args:
        path (str or Path): Path to JSON file.
        meta (list of AudioMeta): List of audio meta to save (via their `to_dict`).
    """
    Path(path).parent.mkdir(exist_ok=True, parents=True)
    is_gzip = str(path).lower().endswith('.gz')
    opener = gzip.open if is_gzip else open
    with opener(path, 'wb') as out_file:  # type: ignore
        for entry in meta:
            serialized = json.dumps(entry.to_dict()) + '\n'
            out_file.write(serialized.encode('utf-8'))


class AudioDataset:
    """Base audio dataset.

    The dataset takes a list of AudioMeta and creates a dataset composed of segments of audio
    and potentially additional information, by creating random segments from the list of audio
    files referenced in the metadata and applying minimal data pre-processing such as resampling,
    mixing of channels, padding, etc.

    If no segment_duration value is provided, the AudioDataset will return the full wav for each
    audio file. Otherwise, it will randomly sample audio files and create a segment of the specified
    duration, applying padding if required.

    By default, only the torch Tensor corresponding to the waveform is returned. Setting return_info=True
    allows to return a tuple containing the torch Tensor and additional metadata on the segment and the
    original audio meta.

    Note that you can call `start_epoch(epoch)` in order to get
    a deterministic "randomization" for `shuffle=True`.
    For a given epoch and dataset index, this will always return the same extract.
    You can get back some diversity by setting the `shuffle_seed` param.

    All audio reads go through `_audio_read`, which subclasses may override and which
    honors `load_wav=False`.

    Args:
        meta (list of AudioMeta): List of audio files metadata.
        segment_duration (float, optional): Optional segment duration of audio to load.
            If not specified, the dataset will load the full audio segment from the file.
        shuffle (bool): Set to `True` to have the data reshuffled at every epoch.
        num_samples (int): Length of the dataset when `segment_duration` is provided.
            When `segment_duration` is None, the length is the number of (filtered) files instead.
        sample_rate (int): Target sample rate of the loaded audio samples.
        channels (int): Target number of channels of the loaded audio samples.
        pad (bool): Whether to zero-pad segments shorter than `segment_duration` to the full
            segment length, and to pad variable-length audio when batching in `collater`.
        sample_on_duration (bool): Set to `True` to sample segments with probability
            dependent on audio file duration. This is only used if `segment_duration` is provided.
        sample_on_weight (bool): Set to `True` to sample segments using the `weight` entry of
            `AudioMeta`. If `sample_on_duration` is also True, the actual weight will be the product
            of the file duration and file weight. This is only used if `segment_duration` is provided.
        min_segment_ratio (float): Minimum segment ratio to use when the audio file
            is shorter than the desired segment.
        max_read_retry (int): Maximum number of retries to sample an audio segment from the dataset.
        return_info (bool): Whether to return the wav only or return wav along with segment info and metadata.
        min_audio_duration (float, optional): Minimum audio file duration, in seconds, if provided
            audio shorter than this will be filtered out.
        max_audio_duration (float, optional): Maximal audio file duration in seconds, if provided
            audio longer than this will be filtered out.
        shuffle_seed (int): can be used to further randomize
        load_wav (bool): if False, skip loading the wav but returns a tensor of 0
            with the expected segment_duration (which must be provided if load_wav is False).
        permutation_on_files (bool): only if `sample_on_weight` and `sample_on_duration`
            are False. Will ensure a permutation on files when going through the dataset.
            In that case the epoch number must be provided in order for the model
            to continue the permutation across epochs. In that case, it is assumed
            that `num_samples = total_batch_size * num_updates_per_epoch`, with
            `total_batch_size` the overall batch size accounting for all gpus.
    """
    def __init__(self,
                 meta: tp.List[AudioMeta],
                 segment_duration: tp.Optional[float] = None,
                 shuffle: bool = True,
                 num_samples: int = 10_000,
                 sample_rate: int = 48_000,
                 channels: int = 2,
                 pad: bool = True,
                 sample_on_duration: bool = True,
                 sample_on_weight: bool = True,
                 min_segment_ratio: float = 0.5,
                 max_read_retry: int = 10,
                 return_info: bool = False,
                 min_audio_duration: tp.Optional[float] = None,
                 max_audio_duration: tp.Optional[float] = None,
                 shuffle_seed: int = 0,
                 load_wav: bool = True,
                 permutation_on_files: bool = False,
                 ):
        assert len(meta) > 0, "No audio meta provided to AudioDataset. Please check loading of audio meta."
        assert segment_duration is None or segment_duration > 0
        assert segment_duration is None or min_segment_ratio >= 0
        self.segment_duration = segment_duration
        self.min_segment_ratio = min_segment_ratio
        self.max_audio_duration = max_audio_duration
        self.min_audio_duration = min_audio_duration
        if self.min_audio_duration is not None and self.max_audio_duration is not None:
            assert self.min_audio_duration <= self.max_audio_duration
        self.meta: tp.List[AudioMeta] = self._filter_duration(meta)
        assert len(self.meta)  # Fail fast if all data has been filtered.
        self.total_duration = sum(d.duration for d in self.meta)

        if segment_duration is None:
            # Full-file mode: one item per file.
            num_samples = len(self.meta)
        self.num_samples = num_samples
        self.shuffle = shuffle
        self.sample_rate = sample_rate
        self.channels = channels
        self.pad = pad
        self.sample_on_weight = sample_on_weight
        self.sample_on_duration = sample_on_duration
        self.sampling_probabilities = self._get_sampling_probabilities()
        self.max_read_retry = max_read_retry
        self.return_info = return_info
        self.shuffle_seed = shuffle_seed
        self.current_epoch: tp.Optional[int] = None
        self.load_wav = load_wav
        if not load_wav:
            # The dummy zero tensor needs a known length.
            assert segment_duration is not None
        self.permutation_on_files = permutation_on_files
        if permutation_on_files:
            assert not self.sample_on_duration
            assert not self.sample_on_weight
            assert self.shuffle

    def start_epoch(self, epoch: int):
        """Set the current epoch, used to seed the per-index random generators."""
        self.current_epoch = epoch

    def __len__(self):
        return self.num_samples

    def _get_sampling_probabilities(self, normalized: bool = True):
        """Return the sampling probabilities for each file inside `self.meta`.

        Each file scores 1, multiplied by its `weight` (if `sample_on_weight` and a
        weight is set) and by its duration (if `sample_on_duration`). When `normalized`
        is True, scores are divided by their sum.
        """
        scores: tp.List[float] = []
        for file_meta in self.meta:
            score = 1.
            if self.sample_on_weight and file_meta.weight is not None:
                score *= file_meta.weight
            if self.sample_on_duration:
                score *= file_meta.duration
            scores.append(score)
        probabilities = torch.tensor(scores)
        if normalized:
            probabilities /= probabilities.sum()
        return probabilities

    @staticmethod
    @lru_cache(16)
    def _get_file_permutation(num_files: int, permutation_index: int, base_seed: int):
        # Used to keep the most recent files permutation in memory implicitly.
        # will work unless someone is using a lot of Datasets in parallel.
        rng = torch.Generator()
        rng.manual_seed(base_seed + permutation_index)
        return torch.randperm(num_files, generator=rng)

    def sample_file(self, index: int, rng: torch.Generator) -> AudioMeta:
        """Sample a given file from `self.meta`. Can be overridden in subclasses.
        This is only called if `segment_duration` is not None.

        You must use the provided random number generator `rng` for reproducibility.
        You can further make use of the index accessed.
        """
        if self.permutation_on_files:
            assert self.current_epoch is not None
            # Position of this sample in the infinite stream of per-epoch samples;
            # each consecutive run of len(self.meta) samples uses one permutation.
            total_index = self.current_epoch * len(self) + index
            permutation_index = total_index // len(self.meta)
            relative_index = total_index % len(self.meta)
            permutation = AudioDataset._get_file_permutation(
                len(self.meta), permutation_index, self.shuffle_seed)
            file_index = int(permutation[relative_index])
            return self.meta[file_index]

        if not self.sample_on_weight and not self.sample_on_duration:
            file_index = int(torch.randint(len(self.sampling_probabilities), (1,), generator=rng).item())
        else:
            file_index = int(torch.multinomial(self.sampling_probabilities, 1, generator=rng).item())

        return self.meta[file_index]

    def _audio_read(self, path: str, seek_time: float = 0, duration: float = -1):
        """Read audio from `path`, or return zeros of `segment_duration` if `load_wav` is False.

        Override this method in subclass if needed.
        Returns a `(wav, sample_rate)` tuple.
        """
        if self.load_wav:
            return audio_read(path, seek_time, duration, pad=False)
        else:
            assert self.segment_duration is not None
            n_frames = int(self.sample_rate * self.segment_duration)
            return torch.zeros(self.channels, n_frames), self.sample_rate

    def __getitem__(self, index: int) -> tp.Union[torch.Tensor, tp.Tuple[torch.Tensor, SegmentInfo]]:
        """Return the audio for `index`, plus its SegmentInfo if `return_info` is set.

        Without `segment_duration`, the whole file `self.meta[index]` is returned, converted
        to the target sample rate and channels. Otherwise a file and a random seek position
        are sampled (deterministically from `index` and, when known, the epoch), a segment of
        `segment_duration` is read, converted, and zero-padded to full length if `pad` is set.
        Failed reads are retried with a new sample up to `max_read_retry` times; the last
        failure is re-raised.
        """
        if self.segment_duration is None:
            file_meta = self.meta[index]
            # Go through `_audio_read` so subclass overrides are honored.
            out, sr = self._audio_read(file_meta.path)
            out = convert_audio(out, sr, self.sample_rate, self.channels)
            n_frames = out.shape[-1]
            segment_info = SegmentInfo(file_meta, seek_time=0., n_frames=n_frames, total_frames=n_frames,
                                       sample_rate=self.sample_rate, channels=out.shape[0])
        else:
            rng = torch.Generator()
            if self.shuffle:
                # We use index, plus extra randomness, either totally random if we don't know the epoch.
                # otherwise we make use of the epoch number and optional shuffle_seed.
                if self.current_epoch is None:
                    rng.manual_seed(index + self.num_samples * random.randint(0, 2**24))
                else:
                    rng.manual_seed(index + self.num_samples * (self.current_epoch + self.shuffle_seed))
            else:
                # We only use index
                rng.manual_seed(index)

            for retry in range(self.max_read_retry):
                file_meta = self.sample_file(index, rng)
                # We add some variance in the file position even if audio file is smaller than segment
                # without ending up with empty segments
                max_seek = max(0, file_meta.duration - self.segment_duration * self.min_segment_ratio)
                seek_time = torch.rand(1, generator=rng).item() * max_seek
                try:
                    # `_audio_read` (rather than `audio_read` directly) is what makes
                    # `load_wav=False` effective and lets subclasses customize reading.
                    out, sr = self._audio_read(file_meta.path, seek_time, self.segment_duration)
                    out = convert_audio(out, sr, self.sample_rate, self.channels)
                    n_frames = out.shape[-1]
                    target_frames = int(self.segment_duration * self.sample_rate)
                    if self.pad:
                        out = F.pad(out, (0, target_frames - n_frames))
                    segment_info = SegmentInfo(file_meta, seek_time, n_frames=n_frames, total_frames=target_frames,
                                               sample_rate=self.sample_rate, channels=out.shape[0])
                except Exception as exc:
                    logger.warning("Error opening file %s: %r", file_meta.path, exc)
                    if retry == self.max_read_retry - 1:
                        raise
                else:
                    break

        if self.return_info:
            # Returns the wav and additional information on the wave segment
            return out, segment_info
        else:
            return out

    def collater(self, samples):
        """The collater function has to be provided to the dataloader
        if AudioDataset has return_info=True in order to properly collate
        the samples of a batch.

        Returns a stacked wav tensor, plus the list of (copied) SegmentInfo if
        `return_info` is set. Without `segment_duration`, wavs are right-padded with
        zeros to the longest one when `pad` is set.
        """
        if self.segment_duration is None and len(samples) > 1:
            assert self.pad, "Must allow padding when batching examples of different durations."

        # In this case the audio reaching the collater is of variable length as segment_duration=None.
        to_pad = self.segment_duration is None and self.pad
        if to_pad:
            max_len = max([wav.shape[-1] for wav, _ in samples])

            def _pad_wav(wav):
                return F.pad(wav, (0, max_len - wav.shape[-1]))

        if self.return_info:
            if len(samples) > 0:
                assert len(samples[0]) == 2
                assert isinstance(samples[0][0], torch.Tensor)
                assert isinstance(samples[0][1], SegmentInfo)

            wavs = [wav for wav, _ in samples]
            # Copies, so that updating total_frames below does not mutate the samples.
            segment_infos = [copy.deepcopy(info) for _, info in samples]

            if to_pad:
                # Each wav could be of a different duration as they are not segmented.
                for i in range(len(samples)):
                    # Determines the total length of the signal with padding, so we update here as we pad.
                    segment_infos[i].total_frames = max_len
                    wavs[i] = _pad_wav(wavs[i])

            wav = torch.stack(wavs)
            return wav, segment_infos
        else:
            assert isinstance(samples[0], torch.Tensor)
            if to_pad:
                samples = [_pad_wav(s) for s in samples]
            return torch.stack(samples)

    def _filter_duration(self, meta: tp.List[AudioMeta]) -> tp.List[AudioMeta]:
        """Filters out audio files with audio durations that will not allow to sample examples from them."""
        orig_len = len(meta)

        # Filter data that is too short.
        if self.min_audio_duration is not None:
            meta = [m for m in meta if m.duration >= self.min_audio_duration]

        # Filter data that is too long.
        if self.max_audio_duration is not None:
            meta = [m for m in meta if m.duration <= self.max_audio_duration]

        filtered_len = len(meta)
        # Guard against an empty input list (no division by zero).
        removed_percentage = 100 * (1 - float(filtered_len) / orig_len) if orig_len else 0.
        msg = 'Removed %.2f percent of the data because it was too short or too long.' % removed_percentage
        # Use the module logger, consistently with the rest of this class.
        if removed_percentage < 10:
            logger.debug(msg)
        else:
            logger.warning(msg)
        return meta

    @classmethod
    def from_meta(cls, root: tp.Union[str, Path], **kwargs):
        """Instantiate AudioDataset from a path to a directory containing a manifest as a jsonl file.

        Args:
            root (str or Path): Either a directory containing `data.jsonl` or `data.jsonl.gz`,
                or the path to the manifest file itself.
            kwargs: Additional keyword arguments for the AudioDataset.
        Raises:
            ValueError: If `root` is a directory without any of the expected manifest files.
        """
        root = Path(root)
        if root.is_dir():
            if (root / 'data.jsonl').exists():
                root = root / 'data.jsonl'
            elif (root / 'data.jsonl.gz').exists():
                root = root / 'data.jsonl.gz'
            else:
                raise ValueError("Don't know where to read metadata from in the dir. "
                                 "Expecting either a data.jsonl or data.jsonl.gz file but none found.")
        meta = load_audio_meta(root)
        return cls(meta, **kwargs)

    @classmethod
    def from_path(cls, root: tp.Union[str, Path], minimal_meta: bool = True,
                  exts: tp.List[str] = DEFAULT_EXTS, **kwargs):
        """Instantiate AudioDataset from a path containing (possibly nested) audio files.

        If `root` is a file, it is loaded as a metadata manifest instead.

        Args:
            root (str or Path): Path to root folder containing audio files.
            minimal_meta (bool): Whether to only load minimal metadata or not.
            exts (list of str): Extensions for audio files.
            kwargs: Additional keyword arguments for the AudioDataset.
        """
        root = Path(root)
        if root.is_file():
            meta = load_audio_meta(root, resolve=True)
        else:
            meta = find_audio_files(root, exts, minimal=minimal_meta, resolve=True)
        return cls(meta, **kwargs)


def main():
    """Command-line entry point: scan a folder for audio files and save their metadata.

    The metadata is written with `save_audio_meta` (JSON Lines, gzip-compressed when
    the output path ends with `.gz`).
    """
    logging.basicConfig(stream=sys.stderr, level=logging.INFO)
    parser = argparse.ArgumentParser(
        prog='audio_dataset',
        description='Generate .jsonl files by scanning a folder.')
    parser.add_argument('root', help='Root folder with all the audio files')
    parser.add_argument('output_meta_file',
                        help='Output file to store the metadata.')
    # `--complete` flips `minimal` to False, i.e. also compute the expensive metadata.
    parser.add_argument('--complete',
                        action='store_false', dest='minimal', default=True,
                        help='Retrieve all metadata, even the one that are expensive '
                             'to compute (e.g. normalization).')
    parser.add_argument('--resolve',
                        action='store_true', default=False,
                        help='Resolve the paths to be absolute and with no symlinks.')
    parser.add_argument('--workers',
                        default=10, type=int,
                        help='Number of workers.')
    args = parser.parse_args()
    meta = find_audio_files(args.root, DEFAULT_EXTS, progress=True,
                            resolve=args.resolve, minimal=args.minimal, workers=args.workers)
    save_audio_meta(args.output_meta_file, meta)


# Allow running this module directly as a script to build a metadata manifest.
if __name__ == '__main__':
    main()


================================================
FILE: audiocraft/data/audio_utils.py
================================================
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
"""Various utilities for audio conversion (pcm format, sample rate and channels),
and volume normalization."""
import sys
import typing as tp

import julius
import torch
import torchaudio


def convert_audio_channels(wav: torch.Tensor, channels: int = 2) -> torch.Tensor:
    """Convert audio to the given number of channels.

    The channel axis is the second-to-last one; any number of leading
    dimensions is accepted (e.g. [C, T] or [B, C, T]).

    Args:
        wav (torch.Tensor): Audio wave of shape [..., C, T].
        channels (int): Expected number of channels as output.
    Returns:
        torch.Tensor: Audio wave of shape [..., channels, T]. The input is returned
        unchanged when it already has the requested number of channels, downmixed when
        mono is requested, replicated when the input is mono, and truncated to its
        first channels when it has more than requested.
    Raises:
        ValueError: If the input has fewer channels than requested and is not mono.
    """
    *leading, src_channels, length = wav.shape
    if src_channels == channels:
        return wav
    if channels == 1:
        # Multi-channel input, mono requested: average over the channel axis.
        return wav.mean(dim=-2, keepdim=True)
    if src_channels == 1:
        # Mono input, several channels requested: broadcast the single channel.
        return wav.expand(*leading, channels, length)
    if src_channels > channels:
        # More channels than requested: keep only the first ones.
        return wav[..., :channels, :]
    # Fewer channels than requested but not mono: no sensible mapping exists.
    raise ValueError('The audio file has less channels than requested but is not mono.')


def convert_audio(wav: torch.Tensor, from_rate: float,
                  to_rate: float, to_channels: int) -> torch.Tensor:
    """Resample `wav` from `from_rate` to `to_rate`, then convert it to `to_channels` channels.

    Both sample rates are truncated to integers before resampling with julius.
    """
    resampled = julius.resample_frac(wav, int(from_rate), int(to_rate))
    return convert_audio_channels(resampled, to_channels)


def normalize_loudness(wav: torch.Tensor, sample_rate: int, loudness_headroom_db: float = 14,
      
Download .txt
gitextract_7_a5iyu7/

├── .github/
│   └── actions/
│       └── audiocraft_build/
│           └── action.yml
├── .gitignore
├── CHANGELOG.md
├── CODE_OF_CONDUCT.md
├── CONTRIBUTING.md
├── Dockerfile
├── LICENSE
├── LICENSE_weights
├── MANIFEST.in
├── Makefile
├── README.md
├── app.py
├── audiocraft/
│   ├── __init__.py
│   ├── adversarial/
│   │   ├── __init__.py
│   │   ├── discriminators/
│   │   │   ├── __init__.py
│   │   │   ├── base.py
│   │   │   ├── mpd.py
│   │   │   ├── msd.py
│   │   │   └── msstftd.py
│   │   └── losses.py
│   ├── data/
│   │   ├── __init__.py
│   │   ├── audio.py
│   │   ├── audio_dataset.py
│   │   ├── audio_utils.py
│   │   ├── info_audio_dataset.py
│   │   ├── music_dataset.py
│   │   ├── sound_dataset.py
│   │   └── zip.py
│   ├── environment.py
│   ├── grids/
│   │   ├── __init__.py
│   │   ├── _base_explorers.py
│   │   ├── audiogen/
│   │   │   ├── __init__.py
│   │   │   ├── audiogen_base_16khz.py
│   │   │   └── audiogen_pretrained_16khz_eval.py
│   │   ├── compression/
│   │   │   ├── __init__.py
│   │   │   ├── _explorers.py
│   │   │   ├── debug.py
│   │   │   ├── encodec_audiogen_16khz.py
│   │   │   ├── encodec_base_24khz.py
│   │   │   └── encodec_musicgen_32khz.py
│   │   ├── diffusion/
│   │   │   ├── 4_bands_base_32khz.py
│   │   │   ├── __init__.py
│   │   │   └── _explorers.py
│   │   └── musicgen/
│   │       ├── __init__.py
│   │       ├── _explorers.py
│   │       ├── musicgen_base_32khz.py
│   │       ├── musicgen_base_cached_32khz.py
│   │       ├── musicgen_clapemb_32khz.py
│   │       ├── musicgen_melody_32khz.py
│   │       └── musicgen_pretrained_32khz_eval.py
│   ├── losses/
│   │   ├── __init__.py
│   │   ├── balancer.py
│   │   ├── sisnr.py
│   │   ├── specloss.py
│   │   └── stftloss.py
│   ├── metrics/
│   │   ├── __init__.py
│   │   ├── chroma_cosinesim.py
│   │   ├── clap_consistency.py
│   │   ├── fad.py
│   │   ├── kld.py
│   │   ├── rvm.py
│   │   └── visqol.py
│   ├── models/
│   │   ├── __init__.py
│   │   ├── audiogen.py
│   │   ├── builders.py
│   │   ├── encodec.py
│   │   ├── lm.py
│   │   ├── loaders.py
│   │   ├── multibanddiffusion.py
│   │   ├── musicgen.py
│   │   └── unet.py
│   ├── modules/
│   │   ├── __init__.py
│   │   ├── activations.py
│   │   ├── chroma.py
│   │   ├── codebooks_patterns.py
│   │   ├── conditioners.py
│   │   ├── conv.py
│   │   ├── diffusion_schedule.py
│   │   ├── lstm.py
│   │   ├── rope.py
│   │   ├── seanet.py
│   │   ├── streaming.py
│   │   └── transformer.py
│   ├── optim/
│   │   ├── __init__.py
│   │   ├── cosine_lr_scheduler.py
│   │   ├── dadam.py
│   │   ├── ema.py
│   │   ├── fsdp.py
│   │   ├── inverse_sqrt_lr_scheduler.py
│   │   ├── linear_warmup_lr_scheduler.py
│   │   └── polynomial_decay_lr_scheduler.py
│   ├── py.typed
│   ├── quantization/
│   │   ├── __init__.py
│   │   ├── base.py
│   │   ├── core_vq.py
│   │   └── vq.py
│   ├── solvers/
│   │   ├── __init__.py
│   │   ├── audiogen.py
│   │   ├── base.py
│   │   ├── builders.py
│   │   ├── compression.py
│   │   ├── diffusion.py
│   │   └── musicgen.py
│   ├── train.py
│   └── utils/
│       ├── __init__.py
│       ├── autocast.py
│       ├── best_state.py
│       ├── cache.py
│       ├── checkpoint.py
│       ├── cluster.py
│       ├── deadlock.py
│       ├── export.py
│       ├── export_legacy.py
│       ├── notebook.py
│       ├── profiler.py
│       ├── samples/
│       │   ├── __init__.py
│       │   └── manager.py
│       ├── ui.py
│       └── utils.py
├── config/
│   ├── conditioner/
│   │   ├── chroma2music.yaml
│   │   ├── clapemb2music.yaml
│   │   ├── none.yaml
│   │   ├── text2music.yaml
│   │   └── text2sound.yaml
│   ├── config.yaml
│   ├── dset/
│   │   ├── audio/
│   │   │   ├── audiocaps_16khz.yaml
│   │   │   ├── default.yaml
│   │   │   ├── example.yaml
│   │   │   └── musiccaps_32khz.yaml
│   │   ├── default.yaml
│   │   └── internal/
│   │       ├── music_10k_32khz.yaml
│   │       ├── music_400k_32khz.yaml
│   │       └── sounds_16khz.yaml
│   ├── model/
│   │   ├── encodec/
│   │   │   ├── default.yaml
│   │   │   ├── encodec_base_causal.yaml
│   │   │   ├── encodec_large_nq4_s320.yaml
│   │   │   └── encodec_large_nq4_s640.yaml
│   │   ├── lm/
│   │   │   ├── audiogen_lm.yaml
│   │   │   ├── default.yaml
│   │   │   ├── model_scale/
│   │   │   │   ├── base.yaml
│   │   │   │   ├── large.yaml
│   │   │   │   ├── medium.yaml
│   │   │   │   ├── small.yaml
│   │   │   │   └── xsmall.yaml
│   │   │   └── musicgen_lm.yaml
│   │   ├── none.yaml
│   │   └── score/
│   │       └── basic.yaml
│   ├── solver/
│   │   ├── audiogen/
│   │   │   ├── audiogen_base_16khz.yaml
│   │   │   ├── debug.yaml
│   │   │   ├── default.yaml
│   │   │   └── evaluation/
│   │   │       ├── none.yaml
│   │   │       └── objective_eval.yaml
│   │   ├── compression/
│   │   │   ├── debug.yaml
│   │   │   ├── default.yaml
│   │   │   ├── encodec_audiogen_16khz.yaml
│   │   │   ├── encodec_base_24khz.yaml
│   │   │   └── encodec_musicgen_32khz.yaml
│   │   ├── default.yaml
│   │   ├── diffusion/
│   │   │   ├── debug.yaml
│   │   │   ├── default.yaml
│   │   │   └── encodec_24khz.yaml
│   │   └── musicgen/
│   │       ├── debug.yaml
│   │       ├── default.yaml
│   │       ├── evaluation/
│   │       │   ├── none.yaml
│   │       │   └── objective_eval.yaml
│   │       ├── musicgen_base_32khz.yaml
│   │       └── musicgen_melody_32khz.yaml
│   └── teams/
│       ├── default.yaml
│       └── labs.yaml
├── dataset/
│   └── example/
│       ├── electro_1.json
│       └── electro_2.json
├── demos/
│   ├── audiogen_demo.ipynb
│   ├── musicgen_app.py
│   └── musicgen_demo.ipynb
├── dockerignore
├── docs/
│   ├── AUDIOGEN.md
│   ├── CONDITIONING.md
│   ├── DATASETS.md
│   ├── ENCODEC.md
│   ├── MBD.md
│   ├── METRICS.md
│   ├── MUSICGEN.md
│   └── TRAINING.md
├── egs/
│   └── example/
│       └── data.jsonl
├── model_cards/
│   ├── AUDIOGEN_MODEL_CARD.md
│   └── MUSICGEN_MODEL_CARD.md
├── models/
│   └── Put your models here.txt
├── mypy.ini
├── requirements.txt
├── scripts/
│   ├── __init__.py
│   ├── mos.py
│   ├── resample_dataset.py
│   ├── static/
│   │   └── style.css
│   └── templates/
│       ├── base.html
│       ├── index.html
│       ├── login.html
│       ├── results.html
│       └── survey.html
├── setup.cfg
├── setup.py
└── tests/
    ├── __init__.py
    ├── adversarial/
    │   ├── __init__.py
    │   ├── test_discriminators.py
    │   └── test_losses.py
    ├── common_utils/
    │   ├── __init__.py
    │   ├── temp_utils.py
    │   └── wav_utils.py
    ├── data/
    │   ├── __init__.py
    │   ├── test_audio.py
    │   ├── test_audio_dataset.py
    │   └── test_audio_utils.py
    ├── losses/
    │   ├── __init__.py
    │   └── test_losses.py
    ├── models/
    │   ├── test_audiogen.py
    │   ├── test_encodec_model.py
    │   ├── test_multibanddiffusion.py
    │   └── test_musicgen.py
    ├── modules/
    │   ├── __init__.py
    │   ├── test_activations.py
    │   ├── test_codebooks_patterns.py
    │   ├── test_conv.py
    │   ├── test_lstm.py
    │   ├── test_rope.py
    │   ├── test_seanet.py
    │   └── test_transformer.py
    ├── quantization/
    │   └── test_vq.py
    └── utils/
        └── __init__.py
Download .txt
SYMBOL INDEX (1239 symbols across 112 files)

FILE: app.py
  function generate_random_string (line 66) | def generate_random_string(length):
  function resize_video (line 71) | def resize_video(input_path, output_path, target_width, target_height):
  function _call_nostderr (line 83) | def _call_nostderr(*args, **kwargs):
  function interrupt (line 96) | def interrupt():
  class FileCleaner (line 101) | class FileCleaner:
    method __init__ (line 102) | def __init__(self, file_lifetime: float = 3600):
    method add (line 106) | def add(self, path: tp.Union[str, Path]):
    method _cleanup (line 110) | def _cleanup(self):
  function make_waveform (line 124) | def make_waveform(*args, **kwargs):
  function load_model (line 146) | def load_model(version='GrandaddyShmax/musicgen-melody', custom_model=No...
  function get_audio_info (line 180) | def get_audio_info(audio_path):
  function info_to_params (line 250) | def info_to_params(audio_path):
  function info_to_params_a (line 348) | def info_to_params_a(audio_path):
  function make_pseudo_stereo (line 435) | def make_pseudo_stereo (filename, sr_select, pan, delay):
  function normalize_audio (line 453) | def normalize_audio(audio_data):
  function load_diffusion (line 460) | def load_diffusion():
  function unload_diffusion (line 467) | def unload_diffusion():
  function _do_predictions (line 474) | def _do_predictions(gen_type, texts, melodies, sample, trim_start, trim_...
  function predict_batched (line 638) | def predict_batched(texts, melodies):
  function add_tags (line 646) | def add_tags(filename, tags):
  function save_outputs (line 686) | def save_outputs(mp4, wav_tmp, tags, gen_type):
  function clear_cash (line 735) | def clear_cash():
  function s2t (line 766) | def s2t(seconds, seconds2):
  function calc_time (line 777) | def calc_time(gen_type, s, duration, overlap, d0, d1, d2, d3, d4, d5, d6...
  function predict_full (line 816) | def predict_full(gen_type, model, decoder, custom_model, prompt_amount, ...
  function get_available_folders (line 935) | def get_available_folders():
  function toggle_audio_src (line 941) | def toggle_audio_src(choice):
  function ui_full (line 948) | def ui_full(launch_kwargs):
  function ui_batched (line 1695) | def ui_batched(launch_kwargs):

FILE: audiocraft/adversarial/discriminators/base.py
  class MultiDiscriminator (line 19) | class MultiDiscriminator(ABC, nn.Module):
    method __init__ (line 22) | def __init__(self):
    method forward (line 26) | def forward(self, x: torch.Tensor) -> MultiDiscriminatorOutputType:
    method num_discriminators (line 31) | def num_discriminators(self) -> int:

FILE: audiocraft/adversarial/discriminators/mpd.py
  function get_padding (line 17) | def get_padding(kernel_size: int, dilation: int = 1) -> int:
  class PeriodDiscriminator (line 21) | class PeriodDiscriminator(nn.Module):
    method __init__ (line 38) | def __init__(self, period: int, in_channels: int = 1, out_channels: in...
    method forward (line 58) | def forward(self, x: torch.Tensor):
  class MultiPeriodDiscriminator (line 79) | class MultiPeriodDiscriminator(MultiDiscriminator):
    method __init__ (line 88) | def __init__(self, in_channels: int = 1, out_channels: int = 1,
    method num_discriminators (line 96) | def num_discriminators(self):
    method forward (line 99) | def forward(self, x: torch.Tensor) -> MultiDiscriminatorOutputType:

FILE: audiocraft/adversarial/discriminators/msd.py
  class ScaleDiscriminator (line 17) | class ScaleDiscriminator(nn.Module):
    method __init__ (line 37) | def __init__(self, in_channels=1, out_channels=1, kernel_sizes: tp.Seq...
    method forward (line 83) | def forward(self, x: torch.Tensor):
  class MultiScaleDiscriminator (line 95) | class MultiScaleDiscriminator(MultiDiscriminator):
    method __init__ (line 105) | def __init__(self, in_channels: int = 1, out_channels: int = 1, downsa...
    method num_discriminators (line 114) | def num_discriminators(self):
    method forward (line 117) | def forward(self, x: torch.Tensor) -> MultiDiscriminatorOutputType:

FILE: audiocraft/adversarial/discriminators/msstftd.py
  function get_2d_padding (line 18) | def get_2d_padding(kernel_size: tp.Tuple[int, int], dilation: tp.Tuple[i...
  class DiscriminatorSTFT (line 22) | class DiscriminatorSTFT(nn.Module):
    method __init__ (line 41) | def __init__(self, filters: int, in_channels: int = 1, out_channels: i...
    method forward (line 81) | def forward(self, x: torch.Tensor):
  class MultiScaleSTFTDiscriminator (line 94) | class MultiScaleSTFTDiscriminator(MultiDiscriminator):
    method __init__ (line 107) | def __init__(self, filters: int, in_channels: int = 1, out_channels: i...
    method num_discriminators (line 120) | def num_discriminators(self):
    method _separate_channels (line 123) | def _separate_channels(self, x: torch.Tensor) -> torch.Tensor:
    method forward (line 127) | def forward(self, x: torch.Tensor) -> MultiDiscriminatorOutputType:

FILE: audiocraft/adversarial/losses.py
  class AdversarialLoss (line 26) | class AdversarialLoss(nn.Module):
    method __init__ (line 49) | def __init__(self,
    method _save_to_state_dict (line 67) | def _save_to_state_dict(self, destination, prefix, keep_vars):
    method _load_from_state_dict (line 73) | def _load_from_state_dict(self, state_dict, prefix, *args, **kwargs):
    method get_adversary_pred (line 78) | def get_adversary_pred(self, x):
    method train_adv (line 89) | def train_adv(self, fake: torch.Tensor, real: torch.Tensor) -> torch.T...
    method forward (line 115) | def forward(self, fake: torch.Tensor, real: torch.Tensor) -> tp.Tuple[...
  function get_adv_criterion (line 138) | def get_adv_criterion(loss_type: str) -> tp.Callable:
  function get_fake_criterion (line 149) | def get_fake_criterion(loss_type: str) -> tp.Callable:
  function get_real_criterion (line 158) | def get_real_criterion(loss_type: str) -> tp.Callable:
  function mse_real_loss (line 167) | def mse_real_loss(x: torch.Tensor) -> torch.Tensor:
  function mse_fake_loss (line 171) | def mse_fake_loss(x: torch.Tensor) -> torch.Tensor:
  function hinge_real_loss (line 175) | def hinge_real_loss(x: torch.Tensor) -> torch.Tensor:
  function hinge_fake_loss (line 179) | def hinge_fake_loss(x: torch.Tensor) -> torch.Tensor:
  function mse_loss (line 183) | def mse_loss(x: torch.Tensor) -> torch.Tensor:
  function hinge_loss (line 189) | def hinge_loss(x: torch.Tensor) -> torch.Tensor:
  function hinge2_loss (line 195) | def hinge2_loss(x: torch.Tensor) -> torch.Tensor:
  class FeatureMatchingLoss (line 201) | class FeatureMatchingLoss(nn.Module):
    method __init__ (line 209) | def __init__(self, loss: nn.Module = torch.nn.L1Loss(), normalize: boo...
    method forward (line 214) | def forward(self, fmap_fake: tp.List[torch.Tensor], fmap_real: tp.List...

FILE: audiocraft/data/audio.py
  function _init_av (line 31) | def _init_av():
  class AudioFileInfo (line 41) | class AudioFileInfo:
  function _av_info (line 47) | def _av_info(filepath: tp.Union[str, Path]) -> AudioFileInfo:
  function _soundfile_info (line 57) | def _soundfile_info(filepath: tp.Union[str, Path]) -> AudioFileInfo:
  function audio_info (line 62) | def audio_info(filepath: tp.Union[str, Path]) -> AudioFileInfo:
  function _av_read (line 72) | def _av_read(filepath: tp.Union[str, Path], seek_time: float = 0, durati...
  function audio_read (line 116) | def audio_read(filepath: tp.Union[str, Path], seek_time: float = 0.,
  function audio_write (line 153) | def audio_write(stem_name: tp.Union[str, Path],

FILE: audiocraft/data/audio_dataset.py
  class BaseInfo (line 39) | class BaseInfo:
    method _dict2fields (line 42) | def _dict2fields(cls, dictionary: dict):
    method from_dict (line 49) | def from_dict(cls, dictionary: dict):
    method to_dict (line 53) | def to_dict(self):
  class AudioMeta (line 61) | class AudioMeta(BaseInfo):
    method from_dict (line 71) | def from_dict(cls, dictionary: dict):
    method to_dict (line 77) | def to_dict(self):
  class SegmentInfo (line 85) | class SegmentInfo(BaseInfo):
  function _get_audio_meta (line 101) | def _get_audio_meta(file_path: str, minimal: bool = True) -> AudioMeta:
  function _resolve_audio_meta (line 118) | def _resolve_audio_meta(m: AudioMeta, fast: bool = True) -> AudioMeta:
  function find_audio_files (line 145) | def find_audio_files(path: tp.Union[Path, str],
  function load_audio_meta (line 204) | def load_audio_meta(path: tp.Union[str, Path],
  function save_audio_meta (line 228) | def save_audio_meta(path: tp.Union[str, Path], meta: tp.List[AudioMeta]):
  class AudioDataset (line 244) | class AudioDataset:
    method __init__ (line 295) | def __init__(self,
    method start_epoch (line 350) | def start_epoch(self, epoch: int):
    method __len__ (line 353) | def __len__(self):
    method _get_sampling_probabilities (line 356) | def _get_sampling_probabilities(self, normalized: bool = True):
    method _get_file_permutation (line 373) | def _get_file_permutation(num_files: int, permutation_index: int, base...
    method sample_file (line 380) | def sample_file(self, index: int, rng: torch.Generator) -> AudioMeta:
    method _audio_read (line 404) | def _audio_read(self, path: str, seek_time: float = 0, duration: float...
    method __getitem__ (line 413) | def __getitem__(self, index: int) -> tp.Union[torch.Tensor, tp.Tuple[t...
    method collater (line 462) | def collater(self, samples):
    method _filter_duration (line 502) | def _filter_duration(self, meta: tp.List[AudioMeta]) -> tp.List[AudioM...
    method from_meta (line 524) | def from_meta(cls, root: tp.Union[str, Path], **kwargs):
    method from_path (line 544) | def from_path(cls, root: tp.Union[str, Path], minimal_meta: bool = True,
  function main (line 562) | def main():

FILE: audiocraft/data/audio_utils.py
  function convert_audio_channels (line 16) | def convert_audio_channels(wav: torch.Tensor, channels: int = 2) -> torc...
  function convert_audio (line 49) | def convert_audio(wav: torch.Tensor, from_rate: float,
  function normalize_loudness (line 57) | def normalize_loudness(wav: torch.Tensor, sample_rate: int, loudness_hea...
  function _clip_wav (line 86) | def _clip_wav(wav: torch.Tensor, log_clipping: bool = False, stem_name: ...
  function normalize_audio (line 97) | def normalize_audio(wav: torch.Tensor, normalize: bool = True,
  function f32_pcm (line 149) | def f32_pcm(wav: torch.Tensor) -> torch.Tensor:
  function i16_pcm (line 161) | def i16_pcm(wav: torch.Tensor) -> torch.Tensor:

FILE: audiocraft/data/info_audio_dataset.py
  function _clusterify_meta (line 25) | def _clusterify_meta(meta: AudioMeta) -> AudioMeta:
  function clusterify_all_meta (line 33) | def clusterify_all_meta(meta: tp.List[AudioMeta]) -> tp.List[AudioMeta]:
  class AudioInfo (line 39) | class AudioInfo(SegmentWithAttributes):
    method to_condition_attributes (line 50) | def to_condition_attributes(self) -> ConditioningAttributes:
  class InfoAudioDataset (line 54) | class InfoAudioDataset(AudioDataset):
    method __init__ (line 59) | def __init__(self, meta: tp.List[AudioMeta], **kwargs):
    method __getitem__ (line 62) | def __getitem__(self, index: int) -> tp.Union[torch.Tensor, tp.Tuple[t...
  function get_keyword_or_keyword_list (line 71) | def get_keyword_or_keyword_list(value: tp.Optional[str]) -> tp.Union[tp....
  function get_string (line 79) | def get_string(value: tp.Optional[str]) -> tp.Optional[str]:
  function get_keyword (line 87) | def get_keyword(value: tp.Optional[str]) -> tp.Optional[str]:
  function get_keyword_list (line 95) | def get_keyword_list(values: tp.Union[str, tp.List[str]]) -> tp.Optional...

FILE: audiocraft/data/music_dataset.py
  class MusicInfo (line 37) | class MusicInfo(AudioInfo):
    method has_music_meta (line 57) | def has_music_meta(self) -> bool:
    method to_condition_attributes (line 60) | def to_condition_attributes(self) -> ConditioningAttributes:
    method attribute_getter (line 76) | def attribute_getter(attribute):
    method from_dict (line 92) | def from_dict(cls, dictionary: dict, fields_required: bool = False):
  function augment_music_info_description (line 115) | def augment_music_info_description(music_info: MusicInfo, merge_text_p: ...
  class Paraphraser (line 167) | class Paraphraser:
    method __init__ (line 168) | def __init__(self, paraphrase_source: tp.Union[str, Path], paraphrase_...
    method sample_paraphrase (line 175) | def sample_paraphrase(self, audio_path: str, description: str):
  class MusicDataset (line 187) | class MusicDataset(InfoAudioDataset):
    method __init__ (line 204) | def __init__(self, *args, info_fields_required: bool = True,
    method __getitem__ (line 220) | def __getitem__(self, index):
  function get_musical_key (line 252) | def get_musical_key(value: tp.Optional[str]) -> tp.Optional[str]:
  function get_bpm (line 263) | def get_bpm(value: tp.Optional[str]) -> tp.Optional[float]:

FILE: audiocraft/data/sound_dataset.py
  class SoundInfo (line 35) | class SoundInfo(SegmentWithAttributes):
    method has_sound_meta (line 42) | def has_sound_meta(self) -> bool:
    method to_condition_attributes (line 45) | def to_condition_attributes(self) -> ConditioningAttributes:
    method attribute_getter (line 57) | def attribute_getter(attribute):
    method from_dict (line 65) | def from_dict(cls, dictionary: dict, fields_required: bool = False):
  class SoundDataset (line 87) | class SoundDataset(InfoAudioDataset):
    method __init__ (line 104) | def __init__(
    method _get_info_path (line 129) | def _get_info_path(self, path: tp.Union[str, Path]) -> Path:
    method __getitem__ (line 142) | def __getitem__(self, index):
    method collater (line 163) | def collater(self, samples):
  function rms_f (line 173) | def rms_f(x: torch.Tensor) -> torch.Tensor:
  function normalize (line 177) | def normalize(audio: torch.Tensor, target_level: int = -25) -> torch.Ten...
  function is_clipped (line 185) | def is_clipped(audio: torch.Tensor, clipping_threshold: float = 0.99) ->...
  function mix_pair (line 189) | def mix_pair(src: torch.Tensor, dst: torch.Tensor, min_overlap: float) -...
  function snr_mixer (line 199) | def snr_mixer(clean: torch.Tensor, noise: torch.Tensor, snr: int, min_ov...
  function snr_mix (line 252) | def snr_mix(src: torch.Tensor, dst: torch.Tensor, snr_low: int, snr_high...
  function mix_text (line 261) | def mix_text(src_text: str, dst_text: str):
  function mix_samples (line 268) | def mix_samples(wavs: torch.Tensor, infos: tp.List[SoundInfo], aug_p: fl...

FILE: audiocraft/data/zip.py
  class PathInZip (line 22) | class PathInZip:
    method __init__ (line 36) | def __init__(self, path: str) -> None:
    method from_paths (line 42) | def from_paths(cls, zip_path: str, file_path: str):
    method __str__ (line 45) | def __str__(self) -> str:
  function _open_zip (line 49) | def _open_zip(path: str, mode: MODE = 'r'):
  function set_zip_cache_size (line 56) | def set_zip_cache_size(max_size: int):
  function open_file_in_zip (line 66) | def open_file_in_zip(path_in_zip: PathInZip, mode: str = 'r') -> typing.IO:

FILE: audiocraft/environment.py
  class AudioCraftEnvironment (line 25) | class AudioCraftEnvironment:
    method __init__ (line 49) | def __init__(self) -> None:
    method _get_cluster_config (line 74) | def _get_cluster_config(self) -> omegaconf.DictConfig:
    method instance (line 79) | def instance(cls):
    method reset (line 85) | def reset(cls):
    method get_team (line 90) | def get_team(cls) -> str:
    method get_cluster (line 97) | def get_cluster(cls) -> str:
    method get_dora_dir (line 104) | def get_dora_dir(cls) -> Path:
    method get_reference_dir (line 114) | def get_reference_dir(cls) -> Path:
    method get_slurm_exclude (line 122) | def get_slurm_exclude(cls) -> tp.Optional[str]:
    method get_slurm_partitions (line 128) | def get_slurm_partitions(cls, partition_types: tp.Optional[tp.List[str...
    method resolve_reference_path (line 146) | def resolve_reference_path(cls, path: tp.Union[str, Path]) -> Path:
    method apply_dataset_mappers (line 167) | def apply_dataset_mappers(cls, path: str) -> str:

FILE: audiocraft/grids/_base_explorers.py
  function get_sheep_ping (line 14) | def get_sheep_ping(sheep) -> tp.Optional[str]:
  class BaseExplorer (line 31) | class BaseExplorer(ABC, Explorer):
    method stages (line 40) | def stages(self):
    method get_grid_meta (line 43) | def get_grid_meta(self):
    method get_grid_metrics (line 55) | def get_grid_metrics(self):
    method process_sheep (line 60) | def process_sheep(self, sheep, history):

FILE: audiocraft/grids/audiogen/audiogen_base_16khz.py
  function explorer (line 12) | def explorer(launcher):

FILE: audiocraft/grids/audiogen/audiogen_pretrained_16khz_eval.py
  function eval (line 26) | def eval(launcher, batch_size: int = 32):
  function explorer (line 49) | def explorer(launcher):

FILE: audiocraft/grids/compression/_explorers.py
  class CompressionExplorer (line 12) | class CompressionExplorer(BaseExplorer):
    method stages (line 15) | def stages(self):
    method get_grid_meta (line 18) | def get_grid_meta(self):
    method get_grid_metrics (line 28) | def get_grid_metrics(self):

FILE: audiocraft/grids/compression/debug.py
  function explorer (line 22) | def explorer(launcher):

FILE: audiocraft/grids/compression/encodec_audiogen_16khz.py
  function explorer (line 20) | def explorer(launcher):

FILE: audiocraft/grids/compression/encodec_base_24khz.py
  function explorer (line 20) | def explorer(launcher):

FILE: audiocraft/grids/compression/encodec_musicgen_32khz.py
  function explorer (line 20) | def explorer(launcher):

FILE: audiocraft/grids/diffusion/4_bands_base_32khz.py
  function explorer (line 17) | def explorer(launcher):

FILE: audiocraft/grids/diffusion/_explorers.py
  class DiffusionExplorer (line 12) | class DiffusionExplorer(BaseExplorer):
    method stages (line 15) | def stages(self):
    method get_grid_meta (line 18) | def get_grid_meta(self):
    method get_grid_metrics (line 28) | def get_grid_metrics(self):

FILE: audiocraft/grids/musicgen/_explorers.py
  class LMExplorer (line 14) | class LMExplorer(BaseExplorer):
    method stages (line 17) | def stages(self) -> tp.List[str]:
    method get_grid_metrics (line 20) | def get_grid_metrics(self):
    method process_sheep (line 45) | def process_sheep(self, sheep, history):
  class GenerationEvalExplorer (line 69) | class GenerationEvalExplorer(BaseExplorer):
    method stages (line 72) | def stages(self) -> tp.List[str]:
    method get_grid_metrics (line 75) | def get_grid_metrics(self):

FILE: audiocraft/grids/musicgen/musicgen_base_32khz.py
  function explorer (line 12) | def explorer(launcher):

FILE: audiocraft/grids/musicgen/musicgen_base_cached_32khz.py
  function explorer (line 12) | def explorer(launcher):

FILE: audiocraft/grids/musicgen/musicgen_clapemb_32khz.py
  function explorer (line 12) | def explorer(launcher):

FILE: audiocraft/grids/musicgen/musicgen_melody_32khz.py
  function explorer (line 12) | def explorer(launcher):

FILE: audiocraft/grids/musicgen/musicgen_pretrained_32khz_eval.py
  function eval (line 26) | def eval(launcher, batch_size: int = 32, eval_melody: bool = False):
  function explorer (line 63) | def explorer(launcher):

FILE: audiocraft/losses/balancer.py
  class Balancer (line 14) | class Balancer:
    method __init__ (line 61) | def __init__(self, weights: tp.Dict[str, float], balance_grads: bool =...
    method metrics (line 74) | def metrics(self):
    method backward (line 77) | def backward(self, losses: tp.Dict[str, torch.Tensor], input: torch.Te...

FILE: audiocraft/losses/sisnr.py
  function _unfold (line 15) | def _unfold(a: torch.Tensor, kernel_size: int, stride: int) -> torch.Ten...
  function _center (line 31) | def _center(x: torch.Tensor) -> torch.Tensor:
  function _norm2 (line 35) | def _norm2(x: torch.Tensor) -> torch.Tensor:
  class SISNR (line 39) | class SISNR(nn.Module):
    method __init__ (line 51) | def __init__(
    method forward (line 64) | def forward(self, out_sig: torch.Tensor, ref_sig: torch.Tensor) -> tor...

FILE: audiocraft/losses/specloss.py
  class MelSpectrogramWrapper (line 18) | class MelSpectrogramWrapper(nn.Module):
    method __init__ (line 35) | def __init__(self, n_fft: int = 1024, hop_length: int = 256, win_lengt...
    method forward (line 48) | def forward(self, x):
  class MelSpectrogramL1Loss (line 65) | class MelSpectrogramL1Loss(torch.nn.Module):
    method __init__ (line 80) | def __init__(self, sample_rate: int, n_fft: int = 1024, hop_length: in...
    method forward (line 89) | def forward(self, x, y):
  class MultiScaleMelSpectrogramLoss (line 96) | class MultiScaleMelSpectrogramLoss(nn.Module):
    method __init__ (line 110) | def __init__(self, sample_rate: int, range_start: int = 6, range_end: ...
    method forward (line 137) | def forward(self, x, y):

FILE: audiocraft/losses/stftloss.py
  function _stft (line 17) | def _stft(x: torch.Tensor, fft_size: int, hop_length: int, win_length: int,
  class SpectralConvergenceLoss (line 45) | class SpectralConvergenceLoss(nn.Module):
    method __init__ (line 48) | def __init__(self, epsilon: float = torch.finfo(torch.float32).eps):
    method forward (line 52) | def forward(self, x_mag: torch.Tensor, y_mag: torch.Tensor):
  class LogSTFTMagnitudeLoss (line 64) | class LogSTFTMagnitudeLoss(nn.Module):
    method __init__ (line 70) | def __init__(self, epsilon: float = torch.finfo(torch.float32).eps):
    method forward (line 74) | def forward(self, x_mag: torch.Tensor, y_mag: torch.Tensor):
  class STFTLosses (line 86) | class STFTLosses(nn.Module):
    method __init__ (line 97) | def __init__(self, n_fft: int = 1024, hop_length: int = 120, win_lengt...
    method forward (line 109) | def forward(self, x: torch.Tensor, y: torch.Tensor) -> tp.Tuple[torch....
  class STFTLoss (line 129) | class STFTLoss(nn.Module):
    method __init__ (line 142) | def __init__(self, n_fft: int = 1024, hop_length: int = 120, win_lengt...
    method forward (line 151) | def forward(self, x: torch.Tensor, y: torch.Tensor) -> tp.Tuple[torch....
  class MRSTFTLoss (line 164) | class MRSTFTLoss(nn.Module):
    method __init__ (line 177) | def __init__(self, n_ffts: tp.Sequence[int] = [1024, 2048, 512], hop_l...
    method forward (line 189) | def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:

FILE: audiocraft/metrics/chroma_cosinesim.py
  class ChromaCosineSimilarityMetric (line 14) | class ChromaCosineSimilarityMetric(torchmetrics.Metric):
    method __init__ (line 28) | def __init__(self, sample_rate: int, n_chroma: int, radix2_exp: int, a...
    method update (line 38) | def update(self, preds: torch.Tensor, targets: torch.Tensor,
    method compute (line 69) | def compute(self) -> float:

FILE: audiocraft/metrics/clap_consistency.py
  class TextConsistencyMetric (line 24) | class TextConsistencyMetric(torchmetrics.Metric):
    method update (line 27) | def update(self, audio: torch.Tensor, text: tp.List[str], sizes: torch...
    method compute (line 30) | def compute(self):
  class CLAPTextConsistencyMetric (line 34) | class CLAPTextConsistencyMetric(TextConsistencyMetric):
    method __init__ (line 47) | def __init__(self, model_path: tp.Union[str, Path], model_arch: str = ...
    method _initialize_model (line 55) | def _initialize_model(self, model_path: tp.Union[str, Path], model_arc...
    method _tokenizer (line 63) | def _tokenizer(self, texts: tp.Union[str, tp.List[str]]) -> dict:
    method update (line 67) | def update(self, audio: torch.Tensor, text: tp.List[str], sizes: torch...
    method compute (line 81) | def compute(self):

FILE: audiocraft/metrics/fad.py
  class FrechetAudioDistanceMetric (line 29) | class FrechetAudioDistanceMetric(torchmetrics.Metric):
    method __init__ (line 145) | def __init__(self, bin: tp.Union[Path, str], model_path: tp.Union[Path...
    method reset (line 167) | def reset(self, log_folder: tp.Optional[tp.Union[Path, str]] = None):
    method update (line 182) | def update(self, preds: torch.Tensor, targets: torch.Tensor,
    method _get_samples_name (line 222) | def _get_samples_name(self, is_background: bool):
    method _create_embedding_beams (line 225) | def _create_embedding_beams(self, is_background: bool, gpu_index: tp.O...
    method _compute_fad_score (line 259) | def _compute_fad_score(self, gpu_index: tp.Optional[int] = None):
    method _log_process_result (line 283) | def _log_process_result(self, returncode: int, log_file: tp.Union[Path...
    method _parallel_create_embedding_beams (line 293) | def _parallel_create_embedding_beams(self, num_of_gpus: int):
    method _sequential_create_embedding_beams (line 303) | def _sequential_create_embedding_beams(self):
    method _local_compute_frechet_audio_distance (line 313) | def _local_compute_frechet_audio_distance(self):
    method compute (line 323) | def compute(self) -> float:

FILE: audiocraft/metrics/kld.py
  class _patch_passt_stft (line 22) | class _patch_passt_stft:
    method __init__ (line 24) | def __init__(self):
    method __enter__ (line 27) | def __enter__(self):
    method __exit__ (line 32) | def __exit__(self, *exc):
  function kl_divergence (line 36) | def kl_divergence(pred_probs: torch.Tensor, target_probs: torch.Tensor, ...
  class KLDivergenceMetric (line 53) | class KLDivergenceMetric(torchmetrics.Metric):
    method __init__ (line 62) | def __init__(self):
    method _get_label_distribution (line 69) | def _get_label_distribution(self, x: torch.Tensor, sizes: torch.Tensor,
    method update (line 82) | def update(self, preds: torch.Tensor, targets: torch.Tensor,
    method compute (line 105) | def compute(self) -> dict:
  class PasstKLDivergenceMetric (line 116) | class PasstKLDivergenceMetric(KLDivergenceMetric):
    method __init__ (line 131) | def __init__(self, pretrained_length: tp.Optional[float] = None):
    method _initialize_model (line 135) | def _initialize_model(self, pretrained_length: tp.Optional[float] = No...
    method _load_base_model (line 145) | def _load_base_model(self, pretrained_length: tp.Optional[float]):
    method _process_audio (line 172) | def _process_audio(self, wav: torch.Tensor, sample_rate: int, wav_len:...
    method _get_model_preds (line 187) | def _get_model_preds(self, wav: torch.Tensor) -> torch.Tensor:
    method _get_label_distribution (line 198) | def _get_label_distribution(self, x: torch.Tensor, sizes: torch.Tensor,

FILE: audiocraft/metrics/rvm.py
  function db_to_scale (line 13) | def db_to_scale(volume: tp.Union[float, torch.Tensor]):
  function scale_to_db (line 17) | def scale_to_db(scale: torch.Tensor, min_volume: float = -120):
  class RelativeVolumeMel (line 22) | class RelativeVolumeMel(nn.Module):
    method __init__ (line 69) | def __init__(self, sample_rate: int = 24000, n_mels: int = 80, n_fft: ...
    method forward (line 84) | def forward(self, estimate: torch.Tensor, ground_truth: torch.Tensor) ...

FILE: audiocraft/metrics/visqol.py
  class ViSQOL (line 22) | class ViSQOL:
    method __init__ (line 56) | def __init__(self, bin: tp.Union[Path, str], mode: str = "audio",
    method _get_target_sr (line 67) | def _get_target_sr(self, mode: str) -> int:
    method _prepare_files (line 75) | def _prepare_files(
    method _flush_files (line 132) | def _flush_files(self, tmp_dir: tp.Union[Path, str]):
    method _collect_moslqo_score (line 136) | def _collect_moslqo_score(self, results_csv_path: tp.Union[Path, str])...
    method _collect_debug_data (line 146) | def _collect_debug_data(self, debug_json_path: tp.Union[Path, str]) ->...
    method visqol_model (line 153) | def visqol_model(self):
    method _run_visqol (line 156) | def _run_visqol(
    method __call__ (line 181) | def __call__(

FILE: audiocraft/models/audiogen.py
  class AudioGen (line 25) | class AudioGen:
    method __init__ (line 36) | def __init__(self, name: str, compression_model: CompressionModel, lm:...
    method frame_rate (line 59) | def frame_rate(self) -> float:
    method sample_rate (line 64) | def sample_rate(self) -> int:
    method audio_channels (line 69) | def audio_channels(self) -> int:
    method get_pretrained (line 74) | def get_pretrained(name: str = 'facebook/audiogen-medium', device=None):
    method set_generation_params (line 97) | def set_generation_params(self, use_sampling: bool = True, top_k: int ...
    method set_custom_progress_callback (line 129) | def set_custom_progress_callback(self, progress_callback: tp.Optional[...
    method generate (line 133) | def generate(self, descriptions: tp.List[str], progress: bool = False)...
    method generate_continuation (line 144) | def generate_continuation(self, prompt: torch.Tensor, prompt_sample_ra...
    method _prepare_tokens_and_attributes (line 168) | def _prepare_tokens_and_attributes(
    method _generate_tokens (line 193) | def _generate_tokens(self, attributes: tp.List[ConditioningAttributes],
    method to (line 273) | def to(self, device: str):

FILE: audiocraft/models/builders.py
  function get_quantizer (line 43) | def get_quantizer(quantizer: str, cfg: omegaconf.DictConfig, dimension: ...
  function get_encodec_autoencoder (line 54) | def get_encodec_autoencoder(encoder_name: str, cfg: omegaconf.DictConfig):
  function get_compression_model (line 68) | def get_compression_model(cfg: omegaconf.DictConfig) -> CompressionModel:
  function get_lm_model (line 86) | def get_lm_model(cfg: omegaconf.DictConfig) -> LMModel:
  function get_conditioner_provider (line 122) | def get_conditioner_provider(output_dim: int, cfg: omegaconf.DictConfig)...
  function get_condition_fuser (line 159) | def get_condition_fuser(cfg: omegaconf.DictConfig) -> ConditionFuser:
  function get_codebooks_pattern_provider (line 169) | def get_codebooks_pattern_provider(n_q: int, cfg: omegaconf.DictConfig) ...
  function get_debug_compression_model (line 184) | def get_debug_compression_model(device='cpu', sample_rate: int = 32000):
  function get_diffusion_model (line 211) | def get_diffusion_model(cfg: omegaconf.DictConfig):
  function get_processor (line 219) | def get_processor(cfg, sample_rate: int = 24000):
  function get_debug_lm_model (line 230) | def get_debug_lm_model(device='cpu'):
  function get_wrapped_compression_model (line 248) | def get_wrapped_compression_model(

FILE: audiocraft/models/encodec.py
  class CompressionModel (line 27) | class CompressionModel(ABC, nn.Module):
    method forward (line 33) | def forward(self, x: torch.Tensor) -> qt.QuantizedResult:
    method encode (line 37) | def encode(self, x: torch.Tensor) -> tp.Tuple[torch.Tensor, tp.Optiona...
    method decode (line 42) | def decode(self, codes: torch.Tensor, scale: tp.Optional[torch.Tensor]...
    method decode_latent (line 47) | def decode_latent(self, codes: torch.Tensor):
    method channels (line 53) | def channels(self) -> int:
    method frame_rate (line 58) | def frame_rate(self) -> float:
    method sample_rate (line 63) | def sample_rate(self) -> int:
    method cardinality (line 68) | def cardinality(self) -> int:
    method num_codebooks (line 73) | def num_codebooks(self) -> int:
    method total_codebooks (line 78) | def total_codebooks(self) -> int:
    method set_num_codebooks (line 82) | def set_num_codebooks(self, n: int):
    method get_pretrained (line 87) | def get_pretrained(
  class EncodecModel (line 124) | class EncodecModel(CompressionModel):
    method __init__ (line 143) | def __init__(self,
    method total_codebooks (line 167) | def total_codebooks(self):
    method num_codebooks (line 172) | def num_codebooks(self):
    method set_num_codebooks (line 176) | def set_num_codebooks(self, n: int):
    method cardinality (line 181) | def cardinality(self):
    method preprocess (line 185) | def preprocess(self, x: torch.Tensor) -> tp.Tuple[torch.Tensor, tp.Opt...
    method postprocess (line 197) | def postprocess(self,
    method forward (line 205) | def forward(self, x: torch.Tensor) -> qt.QuantizedResult:
    method encode (line 222) | def encode(self, x: torch.Tensor) -> tp.Tuple[torch.Tensor, tp.Optiona...
    method decode (line 239) | def decode(self, codes: torch.Tensor, scale: tp.Optional[torch.Tensor]...
    method decode_latent (line 256) | def decode_latent(self, codes: torch.Tensor):
  class DAC (line 261) | class DAC(CompressionModel):
    method __init__ (line 262) | def __init__(self, model_type: str = "44khz"):
    method forward (line 273) | def forward(self, x: torch.Tensor) -> qt.QuantizedResult:
    method encode (line 277) | def encode(self, x: torch.Tensor) -> tp.Tuple[torch.Tensor, tp.Optiona...
    method decode (line 281) | def decode(self, codes: torch.Tensor, scale: tp.Optional[torch.Tensor]...
    method decode_latent (line 286) | def decode_latent(self, codes: torch.Tensor):
    method channels (line 291) | def channels(self) -> int:
    method frame_rate (line 295) | def frame_rate(self) -> float:
    method sample_rate (line 299) | def sample_rate(self) -> int:
    method cardinality (line 303) | def cardinality(self) -> int:
    method num_codebooks (line 307) | def num_codebooks(self) -> int:
    method total_codebooks (line 311) | def total_codebooks(self) -> int:
    method set_num_codebooks (line 314) | def set_num_codebooks(self, n: int):
  class HFEncodecCompressionModel (line 322) | class HFEncodecCompressionModel(CompressionModel):
    method __init__ (line 325) | def __init__(self, model: HFEncodecModel):
    method forward (line 339) | def forward(self, x: torch.Tensor) -> qt.QuantizedResult:
    method encode (line 343) | def encode(self, x: torch.Tensor) -> tp.Tuple[torch.Tensor, tp.Optiona...
    method decode (line 351) | def decode(self, codes: torch.Tensor, scale: tp.Optional[torch.Tensor]...
    method decode_latent (line 359) | def decode_latent(self, codes: torch.Tensor):
    method channels (line 364) | def channels(self) -> int:
    method frame_rate (line 368) | def frame_rate(self) -> float:
    method sample_rate (line 373) | def sample_rate(self) -> int:
    method cardinality (line 377) | def cardinality(self) -> int:
    method num_codebooks (line 381) | def num_codebooks(self) -> int:
    method total_codebooks (line 385) | def total_codebooks(self) -> int:
    method set_num_codebooks (line 388) | def set_num_codebooks(self, n: int):

FILE: audiocraft/models/lm.py
  function get_init_fn (line 36) | def get_init_fn(method: str, input_dim: int, init_depth: tp.Optional[int...
  function init_layer (line 64) | def init_layer(m: nn.Module,
  class ScaledEmbedding (line 97) | class ScaledEmbedding(nn.Embedding):
    method __init__ (line 100) | def __init__(self, *args, lr=None, **kwargs):
    method make_optim_group (line 104) | def make_optim_group(self):
  class LMOutput (line 112) | class LMOutput:
  class LMModel (line 119) | class LMModel(StreamingModule):
    method __init__ (line 144) | def __init__(self, pattern_provider: CodebooksPatternProvider, conditi...
    method _init_weights (line 178) | def _init_weights(self, weight_init: tp.Optional[str], depthwise_init:...
    method special_token_id (line 213) | def special_token_id(self) -> int:
    method num_codebooks (line 217) | def num_codebooks(self) -> int:
    method forward (line 220) | def forward(self, sequence: torch.Tensor,
    method compute_predictions (line 264) | def compute_predictions(
    method _sample_next_token (line 309) | def _sample_next_token(self,
    method generate (line 381) | def generate(self,

FILE: audiocraft/models/loaders.py
  function get_audiocraft_cache_dir (line 34) | def get_audiocraft_cache_dir() -> tp.Optional[str]:
  function _get_state_dict (line 38) | def _get_state_dict(
  function load_compression_model_ckpt (line 67) | def load_compression_model_ckpt(file_or_url_or_id: tp.Union[Path, str], ...
  function load_compression_model (line 71) | def load_compression_model(file_or_url_or_id: tp.Union[Path, str], devic...
  function load_lm_model_ckpt (line 83) | def load_lm_model_ckpt(file_or_url_or_id: tp.Union[Path, str], cache_dir...
  function _delete_param (line 87) | def _delete_param(cfg: DictConfig, full_name: str):
  function load_lm_model (line 100) | def load_lm_model(file_or_url_or_id: tp.Union[Path, str], device='cpu', ...
  function load_mbd_ckpt (line 118) | def load_mbd_ckpt(file_or_url_or_id: tp.Union[Path, str], cache_dir: tp....
  function load_diffusion_models (line 122) | def load_diffusion_models(file_or_url_or_id: tp.Union[Path, str], device...

FILE: audiocraft/models/multibanddiffusion.py
  class DiffusionProcess (line 25) | class DiffusionProcess:
    method __init__ (line 32) | def __init__(self, model: DiffusionUnet, noise_schedule: NoiseSchedule...
    method generate (line 38) | def generate(self, condition: torch.Tensor, initial_noise: torch.Tensor,
  class MultiBandDiffusion (line 50) | class MultiBandDiffusion:
    method __init__ (line 57) | def __init__(self, DPs: tp.List[DiffusionProcess], codec_model: Compre...
    method sample_rate (line 63) | def sample_rate(self) -> int:
    method get_mbd_musicgen (line 67) | def get_mbd_musicgen(device=None):
    method get_mbd_24khz (line 82) | def get_mbd_24khz(bw: float = 3.0, pretrained: bool = True,
    method get_condition (line 116) | def get_condition(self, wav: torch.Tensor, sample_rate: int) -> torch....
    method get_emb (line 129) | def get_emb(self, codes: torch.Tensor):
    method generate (line 136) | def generate(self, emb: torch.Tensor, size: tp.Optional[torch.Size] = ...
    method re_eq (line 154) | def re_eq(self, wav: torch.Tensor, ref: torch.Tensor, n_bands: int = 3...
    method regenerate (line 170) | def regenerate(self, wav: torch.Tensor, sample_rate: int):
    method tokens_to_wav (line 185) | def tokens_to_wav(self, tokens: torch.Tensor, n_bands: int = 32):

FILE: audiocraft/models/musicgen.py
  class MusicGen (line 39) | class MusicGen:
    method __init__ (line 50) | def __init__(self, name: str, compression_model: CompressionModel, lm:...
    method frame_rate (line 73) | def frame_rate(self) -> float:
    method sample_rate (line 78) | def sample_rate(self) -> int:
    method audio_channels (line 83) | def audio_channels(self) -> int:
    method get_pretrained (line 88) | def get_pretrained(name: str = 'GrandaddyShmax/musicgen-melody', devic...
    method set_generation_params (line 118) | def set_generation_params(self, use_sampling: bool = True, top_k: int ...
    method set_custom_progress_callback (line 150) | def set_custom_progress_callback(self, progress_callback: tp.Optional[...
    method generate_unconditional (line 154) | def generate_unconditional(self, num_samples: int, progress: bool = Fa...
    method generate (line 168) | def generate(self, descriptions: tp.List[str], progress: bool = False,...
    method generate_with_chroma (line 183) | def generate_with_chroma(self, descriptions: tp.List[str], melody_wavs...
    method generate_continuation (line 218) | def generate_continuation(self, prompt: torch.Tensor, prompt_sample_ra...
    method _prepare_tokens_and_attributes (line 246) | def _prepare_tokens_and_attributes(
    method _generate_tokens (line 303) | def _generate_tokens(self, attributes: tp.List[ConditioningAttributes],
    method generate_audio (line 399) | def generate_audio(self, gen_tokens: torch.Tensor):
    method to (line 406) | def to(self, device: str):

FILE: audiocraft/models/unet.py
  class Output (line 21) | class Output:
  function get_model (line 25) | def get_model(cfg, channels: int, side: int, num_steps: int):
  class ResBlock (line 33) | class ResBlock(nn.Module):
    method __init__ (line 34) | def __init__(self, channels: int, kernel: int = 3, norm_groups: int = 4,
    method forward (line 52) | def forward(self, x):
  class DecoderLayer (line 58) | class DecoderLayer(nn.Module):
    method __init__ (line 59) | def __init__(self, chin: int, chout: int, kernel: int = 4, stride: int...
    method forward (line 72) | def forward(self, x: torch.Tensor) -> torch.Tensor:
  class EncoderLayer (line 80) | class EncoderLayer(nn.Module):
    method __init__ (line 81) | def __init__(self, chin: int, chout: int, kernel: int = 4, stride: int...
    method forward (line 94) | def forward(self, x: torch.Tensor) -> torch.Tensor:
  class BLSTM (line 107) | class BLSTM(nn.Module):
    method __init__ (line 110) | def __init__(self, dim, layers=2):
    method forward (line 115) | def forward(self, x):
  class DiffusionUnet (line 123) | class DiffusionUnet(nn.Module):
    method __init__ (line 124) | def __init__(self, chin: int = 3, hidden: int = 24, depth: int = 3, gr...
    method forward (line 163) | def forward(self, x: torch.Tensor, step: tp.Union[int, torch.Tensor], ...

FILE: audiocraft/modules/activations.py
  class CustomGLU (line 13) | class CustomGLU(nn.Module):
    method __init__ (line 33) | def __init__(self, activation: nn.Module, dim: int = -1):
    method forward (line 38) | def forward(self, x: Tensor):
  class SwiGLU (line 44) | class SwiGLU(CustomGLU):
    method __init__ (line 52) | def __init__(self, dim: int = -1):
  class GeGLU (line 56) | class GeGLU(CustomGLU):
    method __init__ (line 64) | def __init__(self, dim: int = -1):
  class ReGLU (line 68) | class ReGLU(CustomGLU):
    method __init__ (line 76) | def __init__(self, dim: int = -1):
  function get_activation_fn (line 80) | def get_activation_fn(

FILE: audiocraft/modules/chroma.py
  class ChromaExtractor (line 16) | class ChromaExtractor(nn.Module):
    method __init__ (line 29) | def __init__(self, sample_rate: int, n_chroma: int = 12, radix2_exp: i...
    method forward (line 46) | def forward(self, wav: torch.Tensor) -> torch.Tensor:

FILE: audiocraft/modules/codebooks_patterns.py
  class Pattern (line 22) | class Pattern:
    method __post_init__ (line 50) | def __post_init__(self):
    method _validate_layout (line 58) | def _validate_layout(self):
    method num_sequence_steps (line 80) | def num_sequence_steps(self):
    method max_delay (line 84) | def max_delay(self):
    method valid_layout (line 92) | def valid_layout(self):
    method get_sequence_coords_with_timestep (line 96) | def get_sequence_coords_with_timestep(self, t: int, q: tp.Optional[int...
    method get_steps_with_timestep (line 111) | def get_steps_with_timestep(self, t: int, q: tp.Optional[int] = None) ...
    method get_first_step_with_timesteps (line 114) | def get_first_step_with_timesteps(self, t: int, q: tp.Optional[int] = ...
    method _build_pattern_sequence_scatter_indexes (line 118) | def _build_pattern_sequence_scatter_indexes(self, timesteps: int, n_q:...
    method build_pattern_sequence (line 152) | def build_pattern_sequence(self, z: torch.Tensor, special_token: int, ...
    method _build_reverted_sequence_scatter_indexes (line 179) | def _build_reverted_sequence_scatter_indexes(self, sequence_steps: int...
    method revert_pattern_sequence (line 223) | def revert_pattern_sequence(self, s: torch.Tensor, special_token: int,...
    method revert_pattern_logits (line 248) | def revert_pattern_logits(self, logits: torch.Tensor, special_token: f...
  class CodebooksPatternProvider (line 270) | class CodebooksPatternProvider(ABC):
    method __init__ (line 288) | def __init__(self, n_q: int, cached: bool = True):
    method get_pattern (line 294) | def get_pattern(self, timesteps: int) -> Pattern:
  class DelayedPatternProvider (line 303) | class DelayedPatternProvider(CodebooksPatternProvider):
    method __init__ (line 326) | def __init__(self, n_q: int, delays: tp.Optional[tp.List[int]] = None,
    method get_pattern (line 337) | def get_pattern(self, timesteps: int) -> Pattern:
  class ParallelPatternProvider (line 356) | class ParallelPatternProvider(DelayedPatternProvider):
    method __init__ (line 364) | def __init__(self, n_q: int):
  class UnrolledPatternProvider (line 368) | class UnrolledPatternProvider(CodebooksPatternProvider):
    method __init__ (line 419) | def __init__(self, n_q: int, flattening: tp.Optional[tp.List[int]] = N...
    method _build_flattened_codebooks (line 433) | def _build_flattened_codebooks(self, delays: tp.List[int], flattening:...
    method _num_inner_steps (line 453) | def _num_inner_steps(self):
    method num_virtual_steps (line 458) | def num_virtual_steps(self, timesteps: int) -> int:
    method get_pattern (line 461) | def get_pattern(self, timesteps: int) -> Pattern:
  class VALLEPattern (line 489) | class VALLEPattern(CodebooksPatternProvider):
    method __init__ (line 498) | def __init__(self, n_q: int, delays: tp.Optional[tp.List[int]] = None):
    method get_pattern (line 506) | def get_pattern(self, timesteps: int) -> Pattern:
  class MusicLMPattern (line 521) | class MusicLMPattern(CodebooksPatternProvider):
    method __init__ (line 529) | def __init__(self, n_q: int, group_by: int = 2):
    method get_pattern (line 533) | def get_pattern(self, timesteps: int) -> Pattern:

FILE: audiocraft/modules/conditioners.py
  class WavCondition (line 46) | class WavCondition(tp.NamedTuple):
  class JointEmbedCondition (line 54) | class JointEmbedCondition(tp.NamedTuple):
  class ConditioningAttributes (line 64) | class ConditioningAttributes:
    method __getitem__ (line 69) | def __getitem__(self, item):
    method text_attributes (line 73) | def text_attributes(self):
    method wav_attributes (line 77) | def wav_attributes(self):
    method joint_embed_attributes (line 81) | def joint_embed_attributes(self):
    method attributes (line 85) | def attributes(self):
    method to_flat_dict (line 92) | def to_flat_dict(self):
    method from_flat_dict (line 100) | def from_flat_dict(cls, x):
  class SegmentWithAttributes (line 108) | class SegmentWithAttributes(SegmentInfo):
    method to_condition_attributes (line 113) | def to_condition_attributes(self) -> ConditioningAttributes:
  function nullify_condition (line 117) | def nullify_condition(condition: ConditionType, dim: int = 1):
  function nullify_wav (line 144) | def nullify_wav(cond: WavCondition) -> WavCondition:
  function nullify_joint_embed (line 163) | def nullify_joint_embed(embed: JointEmbedCondition) -> JointEmbedCondition:
  class Tokenizer (line 180) | class Tokenizer:
    method __call__ (line 184) | def __call__(self, texts: tp.List[tp.Optional[str]]) -> tp.Tuple[torch...
  class WhiteSpaceTokenizer (line 188) | class WhiteSpaceTokenizer(Tokenizer):
    method __init__ (line 197) | def __init__(self, n_bins: int, pad_idx: int = 0, language: str = "en_...
    method __call__ (line 210) | def __call__(self, texts: tp.List[tp.Optional[str]],
  class NoopTokenizer (line 256) | class NoopTokenizer(Tokenizer):
    method __init__ (line 266) | def __init__(self, n_bins: int, pad_idx: int = 0):
    method __call__ (line 270) | def __call__(self, texts: tp.List[tp.Optional[str]]) -> tp.Tuple[torch...
  class BaseConditioner (line 286) | class BaseConditioner(nn.Module):
    method __init__ (line 296) | def __init__(self, dim: int, output_dim: int):
    method tokenize (line 302) | def tokenize(self, *args, **kwargs) -> tp.Any:
    method forward (line 310) | def forward(self, inputs: tp.Any) -> ConditionType:
  class TextConditioner (line 323) | class TextConditioner(BaseConditioner):
  class LUTConditioner (line 327) | class LUTConditioner(TextConditioner):
    method __init__ (line 337) | def __init__(self, n_bins: int, dim: int, output_dim: int, tokenizer: ...
    method tokenize (line 348) | def tokenize(self, x: tp.List[tp.Optional[str]]) -> tp.Tuple[torch.Ten...
    method forward (line 354) | def forward(self, inputs: tp.Tuple[torch.Tensor, torch.Tensor]) -> Con...
  class T5Conditioner (line 362) | class T5Conditioner(TextConditioner):
    method __init__ (line 390) | def __init__(self, name: str, output_dim: int, finetune: bool, device:...
    method tokenize (line 430) | def tokenize(self, x: tp.List[tp.Optional[str]]) -> tp.Dict[str, torch...
    method forward (line 449) | def forward(self, inputs: tp.Dict[str, torch.Tensor]) -> ConditionType:
  class WaveformConditioner (line 458) | class WaveformConditioner(BaseConditioner):
    method __init__ (line 469) | def __init__(self, dim: int, output_dim: int, device: tp.Union[torch.d...
    method tokenize (line 473) | def tokenize(self, x: WavCondition) -> WavCondition:
    method _get_wav_embedding (line 478) | def _get_wav_embedding(self, x: WavCondition) -> torch.Tensor:
    method _downsampling_factor (line 482) | def _downsampling_factor(self):
    method forward (line 486) | def forward(self, x: WavCondition) -> ConditionType:
  class ChromaStemConditioner (line 509) | class ChromaStemConditioner(WaveformConditioner):
    method __init__ (line 531) | def __init__(self, output_dim: int, sample_rate: int, n_chroma: int, r...
    method _downsampling_factor (line 554) | def _downsampling_factor(self) -> int:
    method _load_eval_wavs (line 557) | def _load_eval_wavs(self, path: tp.Optional[str], num_samples: int) ->...
    method reset_eval_wavs (line 578) | def reset_eval_wavs(self, eval_wavs: tp.Optional[torch.Tensor]) -> None:
    method has_eval_wavs (line 581) | def has_eval_wavs(self) -> bool:
    method _sample_eval_wavs (line 584) | def _sample_eval_wavs(self, num_samples: int) -> torch.Tensor:
    method _get_chroma_len (line 593) | def _get_chroma_len(self) -> int:
    method _get_stemmed_wav (line 600) | def _get_stemmed_wav(self, wav: torch.Tensor, sample_rate: int) -> tor...
    method _extract_chroma (line 614) | def _extract_chroma(self, wav: torch.Tensor) -> torch.Tensor:
    method _compute_wav_embedding (line 620) | def _compute_wav_embedding(self, wav: torch.Tensor, sample_rate: int) ...
    method _get_full_chroma_for_cache (line 630) | def _get_full_chroma_for_cache(self, path: tp.Union[str, Path], x: Wav...
    method _extract_chroma_chunk (line 638) | def _extract_chroma_chunk(self, full_chroma: torch.Tensor, x: WavCondi...
    method _get_wav_embedding (line 654) | def _get_wav_embedding(self, x: WavCondition) -> torch.Tensor:
    method tokenize (line 688) | def tokenize(self, x: WavCondition) -> WavCondition:
  class JointEmbeddingConditioner (line 698) | class JointEmbeddingConditioner(BaseConditioner):
    method __init__ (line 712) | def __init__(self, dim: int, output_dim: int, device: str, attribute: ...
    method _get_embed (line 731) | def _get_embed(self, x: JointEmbedCondition) -> tp.Tuple[torch.Tensor,...
    method forward (line 740) | def forward(self, x: JointEmbedCondition) -> ConditionType:
    method tokenize (line 755) | def tokenize(self, x: JointEmbedCondition) -> JointEmbedCondition:
  class CLAPEmbeddingConditioner (line 759) | class CLAPEmbeddingConditioner(JointEmbeddingConditioner):
    method __init__ (line 786) | def __init__(self, dim: int, output_dim: int, device: str, attribute: ...
    method _tokenizer (line 825) | def _tokenizer(self, texts: tp.Union[str, tp.List[str]]) -> dict:
    method _compute_text_embedding (line 829) | def _compute_text_embedding(self, text: tp.List[str]) -> torch.Tensor:
    method _get_text_embedding_for_cache (line 841) | def _get_text_embedding_for_cache(self, path: tp.Union[Path, str],
    method _preprocess_wav (line 848) | def _preprocess_wav(self, wav: torch.Tensor, length: torch.Tensor, sam...
    method _compute_wav_embedding (line 869) | def _compute_wav_embedding(self, wav: torch.Tensor, length: torch.Tensor,
    method _get_wav_embedding_for_cache (line 904) | def _get_wav_embedding_for_cache(self, path: tp.Union[str, Path],
    method _extract_wav_embedding_chunk (line 920) | def _extract_wav_embedding_chunk(self, full_embed: torch.Tensor, x: Jo...
    method _get_text_embedding (line 941) | def _get_text_embedding(self, x: JointEmbedCondition) -> torch.Tensor:
    method _get_wav_embedding (line 955) | def _get_wav_embedding(self, x: JointEmbedCondition) -> torch.Tensor:
    method tokenize (line 968) | def tokenize(self, x: JointEmbedCondition) -> JointEmbedCondition:
    method _get_embed (line 981) | def _get_embed(self, x: JointEmbedCondition) -> tp.Tuple[torch.Tensor,...
  function dropout_condition (line 994) | def dropout_condition(sample: ConditioningAttributes, condition_type: st...
  class DropoutModule (line 1025) | class DropoutModule(nn.Module):
    method __init__ (line 1027) | def __init__(self, seed: int = 1234):
  class AttributeDropout (line 1033) | class AttributeDropout(DropoutModule):
    method __init__ (line 1050) | def __init__(self, p: tp.Dict[str, tp.Dict[str, float]], active_on_eva...
    method forward (line 1058) | def forward(self, samples: tp.List[ConditioningAttributes]) -> tp.List...
    method __repr__ (line 1076) | def __repr__(self):
  class ClassifierFreeGuidanceDropout (line 1080) | class ClassifierFreeGuidanceDropout(DropoutModule):
    method __init__ (line 1088) | def __init__(self, p: float, seed: int = 1234):
    method forward (line 1092) | def forward(self, samples: tp.List[ConditioningAttributes]) -> tp.List...
    method __repr__ (line 1115) | def __repr__(self):
  class ConditioningProvider (line 1119) | class ConditioningProvider(nn.Module):
    method __init__ (line 1126) | def __init__(self, conditioners: tp.Dict[str, BaseConditioner], device...
    method joint_embed_conditions (line 1132) | def joint_embed_conditions(self):
    method has_joint_embed_conditions (line 1136) | def has_joint_embed_conditions(self):
    method text_conditions (line 1140) | def text_conditions(self):
    method wav_conditions (line 1144) | def wav_conditions(self):
    method has_wav_condition (line 1148) | def has_wav_condition(self):
    method tokenize (line 1151) | def tokenize(self, inputs: tp.List[ConditioningAttributes]) -> tp.Dict...
    method forward (line 1179) | def forward(self, tokenized: tp.Dict[str, tp.Any]) -> tp.Dict[str, Con...
    method _collate_text (line 1197) | def _collate_text(self, samples: tp.List[ConditioningAttributes]) -> t...
    method _collate_wavs (line 1224) | def _collate_wavs(self, samples: tp.List[ConditioningAttributes]) -> t...
    method _collate_joint_embeds (line 1268) | def _collate_joint_embeds(self, samples: tp.List[ConditioningAttribute...
  class ConditionFuser (line 1322) | class ConditionFuser(StreamingModule):
    method __init__ (line 1339) | def __init__(self, fuse2cond: tp.Dict[str, tp.List[str]], cross_attent...
    method forward (line 1353) | def forward(

FILE: audiocraft/modules/conv.py
  function apply_parametrization_norm (line 21) | def apply_parametrization_norm(module: nn.Module, norm: str = 'none'):
  function get_norm_module (line 33) | def get_norm_module(module: nn.Module, causal: bool = False, norm: str =...
  function get_extra_padding_for_conv1d (line 47) | def get_extra_padding_for_conv1d(x: torch.Tensor, kernel_size: int, stri...
  function pad_for_conv1d (line 56) | def pad_for_conv1d(x: torch.Tensor, kernel_size: int, stride: int, paddi...
  function pad1d (line 71) | def pad1d(x: torch.Tensor, paddings: tp.Tuple[int, int], mode: str = 'co...
  function unpad1d (line 91) | def unpad1d(x: torch.Tensor, paddings: tp.Tuple[int, int]):
  class NormConv1d (line 100) | class NormConv1d(nn.Module):
    method __init__ (line 104) | def __init__(self, *args, causal: bool = False, norm: str = 'none',
    method forward (line 111) | def forward(self, x):
  class NormConv2d (line 117) | class NormConv2d(nn.Module):
    method __init__ (line 121) | def __init__(self, *args, norm: str = 'none', norm_kwargs: tp.Dict[str...
    method forward (line 127) | def forward(self, x):
  class NormConvTranspose1d (line 133) | class NormConvTranspose1d(nn.Module):
    method __init__ (line 137) | def __init__(self, *args, causal: bool = False, norm: str = 'none',
    method forward (line 144) | def forward(self, x):
  class NormConvTranspose2d (line 150) | class NormConvTranspose2d(nn.Module):
    method __init__ (line 154) | def __init__(self, *args, norm: str = 'none', norm_kwargs: tp.Dict[str...
    method forward (line 159) | def forward(self, x):
  class StreamableConv1d (line 165) | class StreamableConv1d(nn.Module):
    method __init__ (line 169) | def __init__(self, in_channels: int, out_channels: int,
    method forward (line 185) | def forward(self, x):
  class StreamableConvTranspose1d (line 204) | class StreamableConvTranspose1d(nn.Module):
    method __init__ (line 208) | def __init__(self, in_channels: int, out_channels: int,
    method forward (line 221) | def forward(self, x):

FILE: audiocraft/modules/diffusion_schedule.py
  function betas_from_alpha_bar (line 20) | def betas_from_alpha_bar(alpha_bar):
  class SampleProcessor (line 25) | class SampleProcessor(torch.nn.Module):
    method project_sample (line 26) | def project_sample(self, x: torch.Tensor):
    method return_sample (line 30) | def return_sample(self, z: torch.Tensor):
  class MultiBandProcessor (line 35) | class MultiBandProcessor(SampleProcessor):
    method __init__ (line 57) | def __init__(self, n_bands: int = 8, sample_rate: float = 24_000,
    method mean (line 77) | def mean(self):
    method std (line 82) | def std(self):
    method target_std (line 87) | def target_std(self):
    method project_sample (line 91) | def project_sample(self, x: torch.Tensor):
    method return_sample (line 104) | def return_sample(self, x: torch.Tensor):
  class NoiseSchedule (line 112) | class NoiseSchedule:
    method __init__ (line 127) | def __init__(self, beta_t0: float = 1e-4, beta_t1: float = 0.02, num_s...
    method get_beta (line 149) | def get_beta(self, step: tp.Union[int, torch.Tensor]):
    method get_initial_noise (line 155) | def get_initial_noise(self, x: torch.Tensor):
    method get_alpha_bar (line 160) | def get_alpha_bar(self, step: tp.Optional[tp.Union[int, torch.Tensor]]...
    method get_training_item (line 169) | def get_training_item(self, x: torch.Tensor, tensor_step: bool = False...
    method generate (line 192) | def generate(self, model: torch.nn.Module, initial: tp.Optional[torch....
    method generate_subsampled (line 238) | def generate_subsampled(self, model: torch.nn.Module, initial: torch.T...

FILE: audiocraft/modules/lstm.py
  class StreamableLSTM (line 10) | class StreamableLSTM(nn.Module):
    method __init__ (line 14) | def __init__(self, dimension: int, num_layers: int = 2, skip: bool = T...
    method forward (line 19) | def forward(self, x):

FILE: audiocraft/modules/rope.py
  class XPos (line 13) | class XPos(nn.Module):
    method __init__ (line 24) | def __init__(self, dim: int, smoothing: float = 0.4, base_scale: int =...
    method get_decay (line 38) | def get_decay(self, start: int, end: int):
  class RotaryEmbedding (line 49) | class RotaryEmbedding(nn.Module):
    method __init__ (line 60) | def __init__(self, dim: int, max_period: float = 10000.0, xpos: bool =...
    method get_rotation (line 75) | def get_rotation(self, start: int, end: int):
    method rotate (line 84) | def rotate(self, x: torch.Tensor, start: int = 0, invert_decay: bool =...
    method rotate_qk (line 103) | def rotate_qk(self, query: torch.Tensor, key: torch.Tensor, start: int...

FILE: audiocraft/modules/seanet.py
  class SEANetResnetBlock (line 16) | class SEANetResnetBlock(nn.Module):
    method __init__ (line 33) | def __init__(self, dim: int, kernel_sizes: tp.List[int] = [3, 1], dila...
    method forward (line 59) | def forward(self, x):
  class SEANetEncoder (line 63) | class SEANetEncoder(nn.Module):
    method __init__ (line 91) | def __init__(self, channels: int = 1, dimension: int = 128, n_filters:...
    method forward (line 152) | def forward(self, x):
  class SEANetDecoder (line 156) | class SEANetDecoder(nn.Module):
    method __init__ (line 186) | def __init__(self, channels: int = 1, dimension: int = 128, n_filters:...
    method forward (line 256) | def forward(self, z):

FILE: audiocraft/modules/streaming.py
  class StreamingModule (line 20) | class StreamingModule(nn.Module):
    method __init__ (line 43) | def __init__(self) -> None:
    method _apply_named_streaming (line 48) | def _apply_named_streaming(self, fn: tp.Any):
    method _set_streaming (line 53) | def _set_streaming(self, streaming: bool):
    method streaming (line 59) | def streaming(self):
    method reset_streaming (line 68) | def reset_streaming(self):
    method get_streaming_state (line 75) | def get_streaming_state(self) -> State:
    method set_streaming_state (line 88) | def set_streaming_state(self, state: State):
    method flush (line 107) | def flush(self, x: tp.Optional[torch.Tensor] = None):
  class StreamingSequential (line 122) | class StreamingSequential(StreamingModule, nn.Sequential):
    method flush (line 125) | def flush(self, x: tp.Optional[torch.Tensor] = None):

FILE: audiocraft/modules/transformer.py
  function set_efficient_attention_backend (line 31) | def set_efficient_attention_backend(backend: str = 'torch'):
  function _get_attention_time_dimension (line 38) | def _get_attention_time_dimension() -> int:
  function _is_profiled (line 45) | def _is_profiled() -> bool:
  function create_norm_fn (line 54) | def create_norm_fn(norm_type: str, dim: int, **kwargs) -> nn.Module:
  function create_sin_embedding (line 70) | def create_sin_embedding(positions: torch.Tensor, dim: int, max_period: ...
  function expand_repeated_kv (line 92) | def expand_repeated_kv(x: torch.Tensor, n_rep: int) -> torch.Tensor:
  class LayerScale (line 112) | class LayerScale(nn.Module):
    method __init__ (line 123) | def __init__(self, channels: int, init: float = 1e-4, channel_last: bo...
    method forward (line 131) | def forward(self, x: torch.Tensor):
  class StreamingMultiheadAttention (line 138) | class StreamingMultiheadAttention(StreamingModule):
    method __init__ (line 164) | def __init__(self, embed_dim: int, num_heads: int, dropout: float = 0....
    method _load_from_state_dict (line 224) | def _load_from_state_dict(self, state_dict, prefix, *args, **kwargs):
    method _get_mask (line 233) | def _get_mask(self, current_steps: int, device: torch.device, dtype: t...
    method _complete_kv (line 266) | def _complete_kv(self, k, v):
    method _apply_rope (line 300) | def _apply_rope(self, query: torch.Tensor, key: torch.Tensor):
    method forward (line 316) | def forward(self, query: torch.Tensor, key: torch.Tensor, value: torch...
  class StreamingTransformerLayer (line 445) | class StreamingTransformerLayer(nn.TransformerEncoderLayer):
    method __init__ (line 479) | def __init__(self, d_model: int, num_heads: int, dim_feedforward: int ...
    method _cross_attention_block (line 533) | def _cross_attention_block(self, src: torch.Tensor,
    method forward (line 541) | def forward(self, src: torch.Tensor, src_mask: tp.Optional[torch.Tenso...
  class StreamingTransformer (line 568) | class StreamingTransformer(StreamingModule):
    method __init__ (line 605) | def __init__(self, d_model: int, num_heads: int, num_layers: int, dim_...
    method _apply_layer (line 654) | def _apply_layer(self, layer, *args, **kwargs):
    method forward (line 685) | def forward(self, x: torch.Tensor, *args, **kwargs):
    method make_optim_group (line 707) | def make_optim_group(self):
  function _verify_xformers_memory_efficient_compat (line 718) | def _verify_xformers_memory_efficient_compat():
  function _verify_xformers_internal_compat (line 732) | def _verify_xformers_internal_compat():
  function _is_custom (line 746) | def _is_custom(custom: bool, memory_efficient: bool):

FILE: audiocraft/optim/cosine_lr_scheduler.py
  class CosineLRScheduler (line 13) | class CosineLRScheduler(_LRScheduler):
    method __init__ (line 23) | def __init__(self, optimizer: Optimizer, total_steps: int, warmup_step...
    method _get_sched_lr (line 33) | def _get_sched_lr(self, lr: float, step: int):
    method get_lr (line 47) | def get_lr(self):

FILE: audiocraft/optim/dadam.py
  function to_real (line 23) | def to_real(x):
  class DAdaptAdam (line 30) | class DAdaptAdam(torch.optim.Optimizer):
    method __init__ (line 62) | def __init__(self, params, lr=1.0,
    method supports_memory_efficient_fp16 (line 99) | def supports_memory_efficient_fp16(self):
    method supports_flat_params (line 103) | def supports_flat_params(self):
    method step (line 106) | def step(self, closure=None):

FILE: audiocraft/optim/ema.py
  function _get_all_non_persistent_buffers_set (line 17) | def _get_all_non_persistent_buffers_set(module: nn.Module, root: str = "...
  function _get_named_tensors (line 32) | def _get_named_tensors(module: nn.Module):
  class ModuleDictEMA (line 40) | class ModuleDictEMA:
    method __init__ (line 45) | def __init__(self, module_dict: nn.ModuleDict, decay: float = 0.999,
    method _init (line 55) | def _init(self):
    method step (line 64) | def step(self):
    method state_dict (line 78) | def state_dict(self):
    method load_state_dict (line 81) | def load_state_dict(self, state):

FILE: audiocraft/optim/fsdp.py
  function is_fsdp_used (line 22) | def is_fsdp_used() -> bool:
  function is_sharded_tensor (line 32) | def is_sharded_tensor(x: tp.Any) -> bool:
  function switch_to_full_state_dict (line 37) | def switch_to_full_state_dict(models: tp.List[FSDP]):
  function wrap_with_fsdp (line 51) | def wrap_with_fsdp(cfg, model: torch.nn.Module,
  function purge_fsdp (line 120) | def purge_fsdp(model: FSDP):
  class _FSDPFixStateDict (line 138) | class _FSDPFixStateDict(FSDP):
    method _name_without_fsdp_prefix (line 140) | def _name_without_fsdp_prefix(name: str) -> str:
    method state_dict (line 146) | def state_dict(self) -> tp.Dict[str, tp.Any]:  # type: ignore
    method load_state_dict (line 153) | def load_state_dict(self, state: tp.Dict[str, tp.Any]):  # type: ignore
  function _fix_post_backward_hook (line 175) | def _fix_post_backward_hook():

FILE: audiocraft/optim/inverse_sqrt_lr_scheduler.py
  class InverseSquareRootLRScheduler (line 13) | class InverseSquareRootLRScheduler(_LRScheduler):
    method __init__ (line 22) | def __init__(self, optimizer: Optimizer, warmup_steps: int, warmup_ini...
    method _get_sched_lr (line 27) | def _get_sched_lr(self, lr: float, step: int):
    method get_lr (line 37) | def get_lr(self):

FILE: audiocraft/optim/linear_warmup_lr_scheduler.py
  class LinearWarmupLRScheduler (line 13) | class LinearWarmupLRScheduler(_LRScheduler):
    method __init__ (line 22) | def __init__(self, optimizer: Optimizer, warmup_steps: int, warmup_ini...
    method _get_sched_lr (line 27) | def _get_sched_lr(self, lr: float, step: int):
    method get_lr (line 34) | def get_lr(self):

FILE: audiocraft/optim/polynomial_decay_lr_scheduler.py
  class PolynomialDecayLRScheduler (line 11) | class PolynomialDecayLRScheduler(_LRScheduler):
    method __init__ (line 22) | def __init__(self, optimizer: Optimizer, warmup_steps: int, total_step...
    method _get_sched_lr (line 31) | def _get_sched_lr(self, lr: float, step: int):
    method get_lr (line 46) | def get_lr(self):

FILE: audiocraft/quantization/base.py
  class QuantizedResult (line 19) | class QuantizedResult:
  class BaseQuantizer (line 27) | class BaseQuantizer(nn.Module):
    method forward (line 31) | def forward(self, x: torch.Tensor, frame_rate: int) -> QuantizedResult:
    method encode (line 40) | def encode(self, x: torch.Tensor) -> torch.Tensor:
    method decode (line 44) | def decode(self, codes: torch.Tensor) -> torch.Tensor:
    method total_codebooks (line 49) | def total_codebooks(self):
    method num_codebooks (line 54) | def num_codebooks(self):
    method set_num_codebooks (line 58) | def set_num_codebooks(self, n: int):
  class DummyQuantizer (line 63) | class DummyQuantizer(BaseQuantizer):
    method __init__ (line 66) | def __init__(self):
    method forward (line 69) | def forward(self, x: torch.Tensor, frame_rate: int):
    method encode (line 73) | def encode(self, x: torch.Tensor) -> torch.Tensor:
    method decode (line 80) | def decode(self, codes: torch.Tensor) -> torch.Tensor:
    method total_codebooks (line 88) | def total_codebooks(self):
    method num_codebooks (line 93) | def num_codebooks(self):
    method set_num_codebooks (line 97) | def set_num_codebooks(self, n: int):

FILE: audiocraft/quantization/core_vq.py
  function exists (line 16) | def exists(val: tp.Optional[tp.Any]) -> bool:
  function default (line 20) | def default(val: tp.Any, d: tp.Any) -> tp.Any:
  function l2norm (line 24) | def l2norm(t):
  function ema_inplace (line 28) | def ema_inplace(moving_avg, new, decay: float):
  function laplace_smoothing (line 32) | def laplace_smoothing(x, n_categories: int, epsilon: float = 1e-5):
  function uniform_init (line 36) | def uniform_init(*shape: int):
  function sample_vectors (line 42) | def sample_vectors(samples, num: int):
  function kmeans (line 53) | def kmeans(samples, num_clusters: int, num_iters: int = 10):
  function orthogonal_loss_fn (line 78) | def orthogonal_loss_fn(t):
  class EuclideanCodebook (line 87) | class EuclideanCodebook(nn.Module):
    method __init__ (line 103) | def __init__(
    method init_embed_ (line 130) | def init_embed_(self, data):
    method replace_ (line 142) | def replace_(self, samples, mask):
    method expire_codes_ (line 148) | def expire_codes_(self, batch_samples):
    method preprocess (line 160) | def preprocess(self, x):
    method quantize (line 164) | def quantize(self, x):
    method postprocess_emb (line 174) | def postprocess_emb(self, embed_ind, shape):
    method dequantize (line 177) | def dequantize(self, embed_ind):
    method encode (line 181) | def encode(self, x):
    method decode (line 191) | def decode(self, embed_ind):
    method forward (line 195) | def forward(self, x):
  class VectorQuantization (line 222) | class VectorQuantization(nn.Module):
    method __init__ (line 245) | def __init__(
    method codebook (line 284) | def codebook(self):
    method inited (line 288) | def inited(self):
    method _preprocess (line 291) | def _preprocess(self, x):
    method _postprocess (line 296) | def _postprocess(self, quantize):
    method encode (line 301) | def encode(self, x):
    method decode (line 307) | def decode(self, embed_ind):
    method forward (line 313) | def forward(self, x):
  class ResidualVectorQuantization (line 352) | class ResidualVectorQuantization(nn.Module):
    method __init__ (line 357) | def __init__(self, *, num_quantizers, **kwargs):
    method forward (line 363) | def forward(self, x, n_q: tp.Optional[int] = None):
    method encode (line 382) | def encode(self, x: torch.Tensor, n_q: tp.Optional[int] = None) -> tor...
    method decode (line 394) | def decode(self, q_indices: torch.Tensor) -> torch.Tensor:

FILE: audiocraft/quantization/vq.py
  class ResidualVectorQuantizer (line 16) | class ResidualVectorQuantizer(BaseQuantizer):
    method __init__ (line 35) | def __init__(
    method forward (line 76) | def forward(self, x: torch.Tensor, frame_rate: int):
    method encode (line 87) | def encode(self, x: torch.Tensor) -> torch.Tensor:
    method decode (line 98) | def decode(self, codes: torch.Tensor) -> torch.Tensor:
    method total_codebooks (line 106) | def total_codebooks(self):
    method num_codebooks (line 110) | def num_codebooks(self):
    method set_num_codebooks (line 113) | def set_num_codebooks(self, n: int):

FILE: audiocraft/solvers/audiogen.py
  class AudioGenSolver (line 10) | class AudioGenSolver(musicgen.MusicGenSolver):

FILE: audiocraft/solvers/base.py
  class StandardSolver (line 27) | class StandardSolver(ABC, flashy.BaseSolver):
    method __init__ (line 38) | def __init__(self, cfg: omegaconf.DictConfig):
    method autocast (line 98) | def autocast(self):
    method _get_state_source (line 102) | def _get_state_source(self, name) -> flashy.state.StateDictSource:
    method best_metric_name (line 107) | def best_metric_name(self) -> tp.Optional[str]:
    method register_best_state (line 114) | def register_best_state(self, *args: str):
    method register_ema (line 127) | def register_ema(self, *args: str):
    method wrap_with_fsdp (line 141) | def wrap_with_fsdp(self, model: torch.nn.Module, *args, **kwargs):
    method update_best_state_from_stage (line 147) | def update_best_state_from_stage(self, stage_name: str = 'valid'):
    method _load_new_state_dict (line 189) | def _load_new_state_dict(self, state_dict: dict) -> dict:
    method swap_best_state (line 198) | def swap_best_state(self):
    method swap_ema_state (line 210) | def swap_ema_state(self):
    method is_training (line 226) | def is_training(self):
    method log_model_summary (line 229) | def log_model_summary(self, model: nn.Module):
    method build_model (line 236) | def build_model(self):
    method initialize_ema (line 240) | def initialize_ema(self):
    method build_dataloaders (line 256) | def build_dataloaders(self):
    method show (line 261) | def show(self):
    method log_updates (line 266) | def log_updates(self):
    method checkpoint_path (line 270) | def checkpoint_path(self, **kwargs):
    method epoch_checkpoint_path (line 274) | def epoch_checkpoint_path(self, epoch: int, **kwargs):
    method checkpoint_path_with_name (line 278) | def checkpoint_path_with_name(self, name: str, **kwargs):
    method save_checkpoints (line 282) | def save_checkpoints(self):
    method load_from_pretrained (line 311) | def load_from_pretrained(self, name: str) -> dict:
    method load_checkpoints (line 314) | def load_checkpoints(self, load_best: bool = False, ignore_state_keys:...
    method restore (line 432) | def restore(self, load_best: bool = False, replay_metrics: bool = False,
    method commit (line 456) | def commit(self, save_checkpoints: bool = True):
    method run_epoch (line 466) | def run_epoch(self):
    method run (line 489) | def run(self):
    method should_stop_training (line 501) | def should_stop_training(self) -> bool:
    method should_run_stage (line 505) | def should_run_stage(self, stage_name) -> bool:
    method run_step (line 513) | def run_step(self, idx: int, batch: tp.Any, metrics: dict):
    method common_train_valid (line 517) | def common_train_valid(self, dataset_split: str, **kwargs: tp.Any):
    method train (line 559) | def train(self):
    method valid (line 563) | def valid(self):
    method evaluate (line 568) | def evaluate(self):
    method generate (line 573) | def generate(self):
    method run_one_stage (line 577) | def run_one_stage(self, stage_name: str):
    method get_eval_solver_from_sig (line 597) | def get_eval_solver_from_sig(sig: str, dtype: tp.Optional[str] = None,

FILE: audiocraft/solvers/builders.py
  class DatasetType (line 36) | class DatasetType(Enum):
  function get_solver (line 42) | def get_solver(cfg: omegaconf.DictConfig) -> StandardSolver:
  function get_optim_parameter_groups (line 59) | def get_optim_parameter_groups(model: nn.Module):
  function get_optimizer (line 86) | def get_optimizer(params: tp.Union[nn.Module, tp.Iterable[torch.Tensor]]...
  function get_lr_scheduler (line 115) | def get_lr_scheduler(optimizer: torch.optim.Optimizer,
  function get_ema (line 159) | def get_ema(module_dict: nn.ModuleDict, cfg: omegaconf.DictConfig) -> tp...
  function get_loss (line 180) | def get_loss(loss_name: str, cfg: omegaconf.DictConfig):
  function get_balancer (line 194) | def get_balancer(loss_weights: tp.Dict[str, float], cfg: omegaconf.DictC...
  function get_adversary (line 200) | def get_adversary(name: str, cfg: omegaconf.DictConfig) -> nn.Module:
  function get_adversarial_losses (line 211) | def get_adversarial_losses(cfg) -> nn.ModuleDict:
  function get_visqol (line 244) | def get_visqol(cfg: omegaconf.DictConfig) -> metrics.ViSQOL:
  function get_fad (line 250) | def get_fad(cfg: omegaconf.DictConfig) -> metrics.FrechetAudioDistanceMe...
  function get_kldiv (line 258) | def get_kldiv(cfg: omegaconf.DictConfig) -> metrics.KLDivergenceMetric:
  function get_text_consistency (line 268) | def get_text_consistency(cfg: omegaconf.DictConfig) -> metrics.TextConsi...
  function get_chroma_cosine_similarity (line 278) | def get_chroma_cosine_similarity(cfg: omegaconf.DictConfig) -> metrics.C...
  function get_audio_datasets (line 285) | def get_audio_datasets(cfg: omegaconf.DictConfig,

FILE: audiocraft/solvers/compression.py
  class CompressionSolver (line 27) | class CompressionSolver(base.StandardSolver):
    method __init__ (line 34) | def __init__(self, cfg: omegaconf.DictConfig):
    method best_metric_name (line 55) | def best_metric_name(self) -> tp.Optional[str]:
    method build_model (line 59) | def build_model(self):
    method build_dataloaders (line 68) | def build_dataloaders(self):
    method show (line 72) | def show(self):
    method run_step (line 83) | def run_step(self, idx: int, batch: torch.Tensor, metrics: dict):
    method run_epoch (line 176) | def run_epoch(self):
    method evaluate (line 183) | def evaluate(self):
    method generate (line 213) | def generate(self):
    method load_from_pretrained (line 236) | def load_from_pretrained(self, name: str) -> dict:
    method model_from_checkpoint (line 269) | def model_from_checkpoint(checkpoint_path: tp.Union[Path, str],
    method wrapped_model_from_checkpoint (line 304) | def wrapped_model_from_checkpoint(cfg: omegaconf.DictConfig,
  function evaluate_audio_reconstruction (line 320) | def evaluate_audio_reconstruction(y_pred: torch.Tensor, y: torch.Tensor,...

FILE: audiocraft/solvers/diffusion.py
  class PerStageMetrics (line 25) | class PerStageMetrics:
    method __init__ (line 30) | def __init__(self, num_steps: int, num_stages: int = 4):
    method __call__ (line 34) | def __call__(self, losses: dict, step: tp.Union[int, torch.Tensor]):
  class DataProcess (line 53) | class DataProcess:
    method __init__ (line 67) | def __init__(self, initial_sr: int = 24000, target_sr: int = 16000, us...
    method process_data (line 95) | def process_data(self, x, metric=False):
    method inverse_process (line 107) | def inverse_process(self, x):
  class DiffusionSolver (line 114) | class DiffusionSolver(base.StandardSolver):
    method __init__ (line 122) | def __init__(self, cfg: omegaconf.DictConfig):
    method best_metric_name (line 155) | def best_metric_name(self) -> tp.Optional[str]:
    method get_condition (line 162) | def get_condition(self, wav: torch.Tensor) -> torch.Tensor:
    method build_model (line 168) | def build_model(self):
    method build_dataloaders (line 178) | def build_dataloaders(self):
    method show (line 182) | def show(self):
    method run_step (line 186) | def run_step(self, idx: int, batch: torch.Tensor, metrics: dict):
    method run_epoch (line 215) | def run_epoch(self):
    method evaluate (line 223) | def evaluate(self):
    method regenerate (line 253) | def regenerate(self, wav: torch.Tensor, step_list: tp.Optional[list] =...
    method generate (line 262) | def generate(self):

FILE: audiocraft/solvers/musicgen.py
  class MusicGenSolver (line 30) | class MusicGenSolver(base.StandardSolver):
    method __init__ (line 37) | def __init__(self, cfg: omegaconf.DictConfig):
    method get_eval_solver_from_sig (line 64) | def get_eval_solver_from_sig(sig: str, dtype: tp.Optional[str] = None,
    method get_formatter (line 100) | def get_formatter(self, stage_name: str) -> flashy.Formatter:
    method best_metric_name (line 109) | def best_metric_name(self) -> tp.Optional[str]:
    method build_model (line 112) | def build_model(self) -> None:
    method build_dataloaders (line 163) | def build_dataloaders(self) -> None:
    method show (line 167) | def show(self) -> None:
    method load_state_dict (line 174) | def load_state_dict(self, state: dict) -> None:
    method load_from_pretrained (line 185) | def load_from_pretrained(self, name: str):
    method _compute_cross_entropy (line 195) | def _compute_cross_entropy(
    method _prepare_tokens_and_attributes (line 230) | def _prepare_tokens_and_attributes(
    method run_step (line 330) | def run_step(self, idx: int, batch: tp.Tuple[torch.Tensor, tp.List[Seg...
    method run_generate_step (line 404) | def run_generate_step(self, batch: tp.Tuple[torch.Tensor, tp.List[Segm...
    method generate_audio (line 471) | def generate_audio(self) -> dict:
    method generate (line 561) | def generate(self) -> dict:
    method run_epoch (line 567) | def run_epoch(self):
    method train (line 573) | def train(self):
    method evaluate_audio_generation (line 586) | def evaluate_audio_generation(self) -> dict:
    method evaluate (line 691) | def evaluate(self) -> dict:

FILE: audiocraft/train.py
  function resolve_config_dset_paths (line 29) | def resolve_config_dset_paths(cfg):
  function get_solver (line 37) | def get_solver(cfg):
  function get_solver_from_xp (line 51) | def get_solver_from_xp(xp: XP, override_cfg: tp.Optional[tp.Union[dict, ...
  function get_solver_from_sig (line 96) | def get_solver_from_sig(sig: str, *args, **kwargs):
  function init_seed_and_system (line 104) | def init_seed_and_system(cfg):
  function main (line 125) | def main(cfg):

FILE: audiocraft/utils/autocast.py
  class TorchAutocast (line 10) | class TorchAutocast:
    method __init__ (line 21) | def __init__(self, enabled: bool, *args, **kwargs):
    method __enter__ (line 24) | def __enter__(self):
    method __exit__ (line 37) | def __exit__(self, *args, **kwargs):

FILE: audiocraft/utils/best_state.py
  class BestStateDictManager (line 21) | class BestStateDictManager(flashy.state.StateDictSource):
    method __init__ (line 36) | def __init__(self, device: tp.Union[torch.device, str] = 'cpu',
    method _get_parameter_ids (line 43) | def _get_parameter_ids(self, state_dict):
    method _validate_no_parameter_ids_overlap (line 46) | def _validate_no_parameter_ids_overlap(self, name: str, param_ids: dict):
    method update (line 53) | def update(self, name: str, source: flashy.state.StateDictSource):
    method register (line 58) | def register(self, name: str, source: flashy.state.StateDictSource):
    method state_dict (line 75) | def state_dict(self) -> flashy.state.StateDict:
    method load_state_dict (line 78) | def load_state_dict(self, state: flashy.state.StateDict):

FILE: audiocraft/utils/cache.py
  function get_full_embed (line 24) | def get_full_embed(full_embed: torch.Tensor, x: tp.Any, idx: int, device...
  class EmbeddingCache (line 39) | class EmbeddingCache:
    method __init__ (line 60) | def __init__(self, cache_path: tp.Union[Path], device: tp.Union[str, t...
    method _get_cache_path (line 79) | def _get_cache_path(self, path: tp.Union[Path, str]):
    method _get_full_embed_from_cache (line 85) | def _get_full_embed_from_cache(cache: Path):
    method get_embed_from_cache (line 94) | def get_embed_from_cache(self, paths: tp.List[Path], x: tp.Any) -> tor...
    method populate_embed_cache (line 124) | def populate_embed_cache(self, paths: tp.List[Path], x: tp.Any) -> None:
  class CachedBatchWriter (line 161) | class CachedBatchWriter:
    method __init__ (line 180) | def __init__(self, cache_folder: Path):
    method start_epoch (line 185) | def start_epoch(self, epoch: int):
    method _get_zip_path (line 193) | def _get_zip_path(cache_folder: Path, epoch: int, index: int):
    method _zip_path (line 197) | def _zip_path(self):
    method save (line 201) | def save(self, *content):
  class CachedBatchLoader (line 224) | class CachedBatchLoader:
    method __init__ (line 237) | def __init__(self, cache_folder: Path, batch_size: int,
    method __len__ (line 246) | def __len__(self):
    method start_epoch (line 250) | def start_epoch(self, epoch: int):
    method _zip_path (line 255) | def _zip_path(self, index: int):
    method _load_one (line 259) | def _load_one(self, index: int):
    method __iter__ (line 296) | def __iter__(self):

FILE: audiocraft/utils/checkpoint.py
  class CheckpointSource (line 22) | class CheckpointSource(Enum):
  function checkpoint_name (line 28) | def checkpoint_name(name: tp.Optional[str] = None, rank: tp.Optional[int...
  function is_sharded_checkpoint (line 51) | def is_sharded_checkpoint(path: Path) -> bool:
  function resolve_checkpoint_path (line 56) | def resolve_checkpoint_path(sig_or_path: tp.Union[Path, str], name: tp.O...
  function load_checkpoint (line 87) | def load_checkpoint(checkpoint_path: Path, is_sharded: bool = False) -> ...
  function save_checkpoint (line 98) | def save_checkpoint(state: tp.Any, checkpoint_path: Path, is_sharded: bo...
  function flush_stale_checkpoints (line 104) | def flush_stale_checkpoints(checkpoint_path: Path, keep_last: tp.Optiona...
  function check_sharded_checkpoint (line 125) | def check_sharded_checkpoint(checkpoint_path: Path, rank0_checkpoint_pat...
  function _safe_save_checkpoint (line 142) | def _safe_save_checkpoint(state: tp.Any, checkpoint_path: Path, is_shard...

FILE: audiocraft/utils/cluster.py
  class ClusterType (line 19) | class ClusterType(Enum):
  function _guess_cluster_type (line 27) | def _guess_cluster_type() -> ClusterType:
  function get_cluster_type (line 45) | def get_cluster_type(
  function get_slurm_parameters (line 54) | def get_slurm_parameters(

FILE: audiocraft/utils/deadlock.py
  class DeadlockDetect (line 18) | class DeadlockDetect:
    method __init__ (line 19) | def __init__(self, use: bool = False, timeout: float = 120.):
    method update (line 24) | def update(self, stage: str):
    method __enter__ (line 28) | def __enter__(self):
    method __exit__ (line 33) | def __exit__(self, exc_type, exc_val, exc_tb):
    method _detector_thread (line 38) | def _detector_thread(self):

FILE: audiocraft/utils/export.py
  function export_encodec (line 20) | def export_encodec(checkpoint_path: tp.Union[Path, str], out_file: tp.Un...
  function export_pretrained_compression_model (line 36) | def export_pretrained_compression_model(pretrained_encodec: str, out_fil...
  function export_lm (line 61) | def export_lm(checkpoint_path: tp.Union[Path, str], out_file: tp.Union[P...

FILE: audiocraft/utils/export_legacy.py
  function _clean_lm_cfg (line 18) | def _clean_lm_cfg(cfg: DictConfig):
  function export_encodec (line 33) | def export_encodec(checkpoint_path: tp.Union[Path, str], out_folder: tp....
  function export_lm (line 46) | def export_lm(checkpoint_path: tp.Union[Path, str], out_folder: tp.Union...

FILE: audiocraft/utils/notebook.py
  function display_audio (line 17) | def display_audio(samples: torch.Tensor, sample_rate: int):

FILE: audiocraft/utils/profiler.py
  class Profiler (line 17) | class Profiler:
    method __init__ (line 20) | def __init__(self, module: torch.nn.Module, enabled: bool = False):
    method step (line 28) | def step(self):
    method __enter__ (line 32) | def __enter__(self):
    method __exit__ (line 36) | def __exit__(self, exc_type, exc_value, exc_tb):

FILE: audiocraft/utils/samples/manager.py
  class ReferenceSample (line 42) | class ReferenceSample:
  class Sample (line 49) | class Sample:
    method __hash__ (line 59) | def __hash__(self):
    method audio (line 62) | def audio(self) -> tp.Tuple[torch.Tensor, int]:
    method audio_prompt (line 65) | def audio_prompt(self) -> tp.Optional[tp.Tuple[torch.Tensor, int]]:
    method audio_reference (line 68) | def audio_reference(self) -> tp.Optional[tp.Tuple[torch.Tensor, int]]:
  class SampleManager (line 72) | class SampleManager:
    method __init__ (line 89) | def __init__(self, xp: dora.XP, map_reference_to_sample_id: bool = Fal...
    method latest_epoch (line 98) | def latest_epoch(self):
    method _load_samples (line 102) | def _load_samples(self):
    method _load_sample (line 110) | def _load_sample(json_file: Path) -> Sample:
    method _init_hash (line 126) | def _init_hash(self):
    method _get_tensor_id (line 129) | def _get_tensor_id(self, tensor: torch.Tensor) -> str:
    method _get_sample_id (line 134) | def _get_sample_id(self, index: int, prompt_wav: tp.Optional[torch.Ten...
    method _store_audio (line 173) | def _store_audio(self, wav: torch.Tensor, stem_path: Path, overwrite: ...
    method add_sample (line 196) | def add_sample(self, sample_wav: torch.Tensor, epoch: int, index: int ...
    method add_samples (line 238) | def add_samples(self, samples_wavs: torch.Tensor, epoch: int,
    method get_samples (line 269) | def get_samples(self, epoch: int = -1, max_epoch: int = -1, exclude_pr...
  function slugify (line 305) | def slugify(value: tp.Any, allow_unicode: bool = False):
  function _match_stable_samples (line 328) | def _match_stable_samples(samples_per_xp: tp.List[tp.Set[Sample]]) -> tp...
  function _match_unstable_samples (line 343) | def _match_unstable_samples(samples_per_xp: tp.List[tp.Set[Sample]]) -> ...
  function get_samples_for_xps (line 358) | def get_samples_for_xps(xps: tp.List[dora.XP], **kwargs) -> tp.Dict[str,...

FILE: audiocraft/utils/ui.py
  class ToolButton (line 8) | class ToolButton(gr.Button, gr.components.IOComponent):
    method __init__ (line 11) | def __init__(self, **kwargs):
    method get_block_name (line 14) | def get_block_name(self):
  function create_refresh_button (line 18) | def create_refresh_button(refresh_component, refresh_method, refreshed_a...

FILE: audiocraft/utils/utils.py
  function model_hash (line 26) | def model_hash(model: torch.nn.Module) -> str:
  function dict_from_config (line 36) | def dict_from_config(cfg: omegaconf.DictConfig) -> dict:
  function random_subset (line 49) | def random_subset(dataset, max_samples: int, seed: int = 42) -> torch.ut...
  function get_loader (line 58) | def get_loader(dataset, num_samples: tp.Optional[int], batch_size: int,
  function get_dataset_from_loader (line 81) | def get_dataset_from_loader(dataloader):
  function multinomial (line 89) | def multinomial(input: torch.Tensor, num_samples: int, replacement=False...
  function sample_top_k (line 109) | def sample_top_k(probs: torch.Tensor, k: int) -> torch.Tensor:
  function sample_top_p (line 126) | def sample_top_p(probs: torch.Tensor, p: float) -> torch.Tensor:
  class DummyPoolExecutor (line 145) | class DummyPoolExecutor:
    class DummyResult (line 149) | class DummyResult:
      method __init__ (line 150) | def __init__(self, func, *args, **kwargs):
      method result (line 155) | def result(self):
    method __init__ (line 158) | def __init__(self, workers, mp_context=None):
    method submit (line 161) | def submit(self, func, *args, **kwargs):
    method __enter__ (line 164) | def __enter__(self):
    method __exit__ (line 167) | def __exit__(self, exc_type, exc_value, exc_tb):
  function get_pool_executor (line 171) | def get_pool_executor(num_workers: int, mp_context=None):
  function length_to_mask (line 175) | def length_to_mask(lengths: torch.Tensor, max_len: tp.Optional[int] = No...
  function hash_trick (line 191) | def hash_trick(word: str, vocab_size: int) -> int:
  function with_rank_rng (line 204) | def with_rank_rng(base_seed: int = 1234):
  function collate (line 227) | def collate(tensors: tp.List[torch.Tensor], dim: int = 0) -> tp.Tuple[to...
  function copy_state (line 251) | def copy_state(state: tp.Any, device: tp.Union[torch.device, str] = 'cpu',
  function swap_state (line 265) | def swap_state(model, state, **kwargs):
  function warn_once (line 275) | def warn_once(logger, msg):
  function is_jsonable (line 280) | def is_jsonable(x: tp.Any):
  function load_clap_state_dict (line 289) | def load_clap_state_dict(clap_model, path: tp.Union[str, Path]):

FILE: demos/musicgen_app.py
  function _call_nostderr (line 39) | def _call_nostderr(*args, **kwargs):
  function interrupt (line 52) | def interrupt():
  class FileCleaner (line 57) | class FileCleaner:
    method __init__ (line 58) | def __init__(self, file_lifetime: float = 3600):
    method add (line 62) | def add(self, path: tp.Union[str, Path]):
    method _cleanup (line 66) | def _cleanup(self):
  function make_waveform (line 80) | def make_waveform(*args, **kwargs):
  function load_model (line 90) | def load_model(version='facebook/musicgen-melody'):
  function load_diffusion (line 97) | def load_diffusion():
  function _do_predictions (line 104) | def _do_predictions(texts, melodies, duration, progress=False, **gen_kwa...
  function predict_batched (line 154) | def predict_batched(texts, melodies):
  function predict_full (line 162) | def predict_full(model, decoder, text, melody, duration, topk, topp, tem...
  function toggle_audio_src (line 195) | def toggle_audio_src(choice):
  function toggle_diffusion (line 202) | def toggle_diffusion(choice):
  function ui_full (line 209) | def ui_full(launch_kwargs):
  function ui_batched (line 338) | def ui_batched(launch_kwargs):

FILE: scripts/mos.py
  function normalize_path (line 43) | def normalize_path(path: Path):
  function get_full_path (line 51) | def get_full_path(normalized_path: Path):
  function get_signature (line 57) | def get_signature(xps: tp.List[str]):
  function ensure_logged (line 63) | def ensure_logged(func):
  function login (line 76) | def login():
  function index (line 98) | def index():
  function survey (line 135) | def survey(signature):
  function audio (line 236) | def audio(path: str):
  function mean (line 242) | def mean(x):
  function std (line 246) | def std(x):
  function results (line 253) | def results(signature):

FILE: scripts/resample_dataset.py
  function read_txt_files (line 22) | def read_txt_files(path: tp.Union[str, Path]):
  function read_egs_files (line 31) | def read_egs_files(path: tp.Union[str, Path]):
  function process_dataset (line 45) | def process_dataset(args, n_shards: int, node_index: int, task_index: tp...

FILE: tests/adversarial/test_discriminators.py
  class TestMultiPeriodDiscriminator (line 18) | class TestMultiPeriodDiscriminator:
    method test_mpd_discriminator (line 20) | def test_mpd_discriminator(self):
  class TestMultiScaleDiscriminator (line 33) | class TestMultiScaleDiscriminator:
    method test_msd_discriminator (line 35) | def test_msd_discriminator(self):
  class TestMultiScaleStftDiscriminator (line 49) | class TestMultiScaleStftDiscriminator:
    method test_msstftd_discriminator (line 51) | def test_msstftd_discriminator(self):

FILE: tests/adversarial/test_losses.py
  class TestAdversarialLoss (line 22) | class TestAdversarialLoss:
    method test_adversarial_single_multidiscriminator (line 24) | def test_adversarial_single_multidiscriminator(self):
    method test_adversarial_feat_loss (line 45) | def test_adversarial_feat_loss(self):
  class TestGeneratorAdversarialLoss (line 65) | class TestGeneratorAdversarialLoss:
    method test_hinge_generator_adv_loss (line 67) | def test_hinge_generator_adv_loss(self):
    method test_mse_generator_adv_loss (line 76) | def test_mse_generator_adv_loss(self):
  class TestDiscriminatorAdversarialLoss (line 88) | class TestDiscriminatorAdversarialLoss:
    method _disc_loss (line 90) | def _disc_loss(self, loss_type: str, fake: torch.Tensor, real: torch.T...
    method test_hinge_discriminator_adv_loss (line 97) | def test_hinge_discriminator_adv_loss(self):
    method test_mse_discriminator_adv_loss (line 105) | def test_mse_discriminator_adv_loss(self):
  class TestFeatureMatchingLoss (line 115) | class TestFeatureMatchingLoss:
    method test_features_matching_loss_base (line 117) | def test_features_matching_loss_base(self):
    method test_features_matching_loss_raises_exception (line 126) | def test_features_matching_loss_raises_exception(self):
    method test_features_matching_loss_output (line 141) | def test_features_matching_loss_output(self):

FILE: tests/common_utils/temp_utils.py
  class TempDirMixin (line 11) | class TempDirMixin:
    method get_base_temp_dir (line 18) | def get_base_temp_dir(cls):
    method tearDownClass (line 29) | def tearDownClass(cls):
    method id (line 43) | def id(self):
    method get_temp_path (line 46) | def get_temp_path(self, *paths):
    method get_temp_dir (line 52) | def get_temp_dir(self, *paths):

FILE: tests/common_utils/wav_utils.py
  function get_white_noise (line 14) | def get_white_noise(chs: int = 1, num_frames: int = 1):
  function get_batch_white_noise (line 19) | def get_batch_white_noise(bs: int = 1, chs: int = 1, num_frames: int = 1):
  function save_wav (line 24) | def save_wav(path: str, wav: torch.Tensor, sample_rate: int):

FILE: tests/data/test_audio.py
  class TestInfo (line 19) | class TestInfo(TempDirMixin):
    method test_info_mp3 (line 21) | def test_info_mp3(self):
    method _test_info_format (line 34) | def _test_info_format(self, ext: str):
    method test_info_wav (line 48) | def test_info_wav(self):
    method test_info_flac (line 51) | def test_info_flac(self):
    method test_info_ogg (line 54) | def test_info_ogg(self):
    method test_info_m4a (line 57) | def test_info_m4a(self):
  class TestRead (line 63) | class TestRead(TempDirMixin):
    method test_read_full_wav (line 65) | def test_read_full_wav(self):
    method test_read_partial_wav (line 80) | def test_read_partial_wav(self):
    method test_read_seek_time_wav (line 97) | def test_read_seek_time_wav(self):
    method test_read_seek_time_wav_padded (line 116) | def test_read_seek_time_wav_padded(self):
  class TestAvRead (line 139) | class TestAvRead(TempDirMixin):
    method test_avread_seek_base (line 141) | def test_avread_seek_base(self):
    method test_avread_seek_partial (line 159) | def test_avread_seek_partial(self):
    method test_avread_seek_outofbound (line 178) | def test_avread_seek_outofbound(self):
    method test_avread_seek_edge (line 193) | def test_avread_seek_edge(self):
  class TestAudioWrite (line 212) | class TestAudioWrite(TempDirMixin):
    method test_audio_write_wav (line 214) | def test_audio_write_wav(self):

FILE: tests/data/test_audio_dataset.py
  class TestAudioMeta (line 31) | class TestAudioMeta(TempDirMixin):
    method test_get_audio_meta (line 33) | def test_get_audio_meta(self):
    method test_save_audio_meta (line 49) | def test_save_audio_meta(self):
    method test_load_audio_meta (line 65) | def test_load_audio_meta(self):
  class TestAudioDataset (line 90) | class TestAudioDataset(TempDirMixin):
    method _create_audio_files (line 92) | def _create_audio_files(self,
    method _create_audio_dataset (line 114) | def _create_audio_dataset(self,
    method test_dataset_full (line 135) | def test_dataset_full(self):
    method test_dataset_segment (line 152) | def test_dataset_segment(self):
    method test_dataset_equal_audio_and_segment_durations (line 170) | def test_dataset_equal_audio_and_segment_durations(self):
    method test_dataset_samples (line 192) | def test_dataset_samples(self):
    method test_dataset_return_info (line 218) | def test_dataset_return_info(self):
    method test_dataset_return_info_no_segment_duration (line 240) | def test_dataset_return_info_no_segment_duration(self):
    method test_dataset_collate_fn (line 260) | def test_dataset_collate_fn(self):
    method test_dataset_with_meta_collate_fn (line 280) | def test_dataset_with_meta_collate_fn(self, segment_duration):
    method test_sample_with_weight (line 308) | def test_sample_with_weight(self, segment_duration, sample_on_weight, ...
    method test_meta_duration_filter_all (line 333) | def test_meta_duration_filter_all(self):
    method test_meta_duration_filter_long (line 345) | def test_meta_duration_filter_long(self):

FILE: tests/data/test_audio_utils.py
  class TestConvertAudioChannels (line 20) | class TestConvertAudioChannels:
    method test_convert_audio_channels_downmix (line 22) | def test_convert_audio_channels_downmix(self):
    method test_convert_audio_channels_nochange (line 28) | def test_convert_audio_channels_nochange(self):
    method test_convert_audio_channels_upmix (line 34) | def test_convert_audio_channels_upmix(self):
    method test_convert_audio_channels_upmix_error (line 40) | def test_convert_audio_channels_upmix_error(self):
  class TestConvertAudio (line 47) | class TestConvertAudio:
    method test_convert_audio_channels_downmix (line 49) | def test_convert_audio_channels_downmix(self):
    method test_convert_audio_channels_upmix (line 56) | def test_convert_audio_channels_upmix(self):
    method test_convert_audio_upsample (line 63) | def test_convert_audio_upsample(self):
    method test_convert_audio_resample (line 72) | def test_convert_audio_resample(self):
  class TestNormalizeAudio (line 82) | class TestNormalizeAudio:
    method test_clip_wav (line 84) | def test_clip_wav(self):
    method test_normalize_audio_clip (line 91) | def test_normalize_audio_clip(self):
    method test_normalize_audio_rms (line 98) | def test_normalize_audio_rms(self):
    method test_normalize_audio_peak (line 105) | def test_normalize_audio_peak(self):

FILE: tests/losses/test_losses.py
  function test_mel_l1_loss (line 20) | def test_mel_l1_loss():
  function test_msspec_loss (line 34) | def test_msspec_loss():
  function test_mrstft_loss (line 48) | def test_mrstft_loss():
  function test_sisnr_loss (line 59) | def test_sisnr_loss():
  function test_stft_loss (line 70) | def test_stft_loss():

FILE: tests/models/test_audiogen.py
  class TestAudioGenModel (line 13) | class TestAudioGenModel:
    method get_audiogen (line 14) | def get_audiogen(self):
    method test_base (line 19) | def test_base(self):
    method test_generate_continuation (line 25) | def test_generate_continuation(self):
    method test_generate (line 41) | def test_generate(self):
    method test_generate_long (line 47) | def test_generate_long(self):

FILE: tests/models/test_encodec_model.py
  class TestEncodecModel (line 17) | class TestEncodecModel:
    method _create_encodec_model (line 19) | def _create_encodec_model(self,
    method test_model (line 37) | def test_model(self):
    method test_model_renorm (line 48) | def test_model_renorm(self):

FILE: tests/models/test_multibanddiffusion.py
  class TestMBD (line 18) | class TestMBD:
    method _create_mbd (line 20) | def _create_mbd(self,
    method test_model (line 43) | def test_model(self):

FILE: tests/models/test_musicgen.py
  class TestMusicGenModel (line 13) | class TestMusicGenModel:
    method get_musicgen (line 14) | def get_musicgen(self):
    method test_base (line 19) | def test_base(self):
    method test_generate_unconditional (line 25) | def test_generate_unconditional(self):
    method test_generate_continuation (line 30) | def test_generate_continuation(self):
    method test_generate (line 46) | def test_generate(self):
    method test_generate_long (line 52) | def test_generate_long(self):

FILE: tests/modules/test_activations.py
  class TestActivations (line 13) | class TestActivations:
    method test_custom_glu_calculation (line 14) | def test_custom_glu_calculation(self):

FILE: tests/modules/test_codebooks_patterns.py
  class TestParallelPatternProvider (line 18) | class TestParallelPatternProvider:
    method test_get_pattern (line 22) | def test_get_pattern(self, n_q: int, timesteps: int):
    method test_pattern_content (line 30) | def test_pattern_content(self, n_q: int, timesteps: int):
    method test_pattern_max_delay (line 40) | def test_pattern_max_delay(self, n_q: int, timesteps: int):
  class TestDelayedPatternProvider (line 47) | class TestDelayedPatternProvider:
    method test_get_pattern (line 51) | def test_get_pattern(self, n_q: int, timesteps: int):
    method test_pattern_content (line 65) | def test_pattern_content(self, n_q: int, timesteps: int):
    method test_pattern_max_delay (line 75) | def test_pattern_max_delay(self, timesteps: int, delay: list):
  class TestUnrolledPatternProvider (line 82) | class TestUnrolledPatternProvider:
    method test_get_pattern (line 87) | def test_get_pattern(self, timesteps: int, flattening: list, delays: l...
    method test_pattern_max_delay (line 97) | def test_pattern_max_delay(self, timesteps: int, flattening: list, del...
  class TestPattern (line 105) | class TestPattern:
    method ref_build_pattern_sequence (line 107) | def ref_build_pattern_sequence(self, z: torch.Tensor, pattern: Pattern...
    method ref_revert_pattern_sequence (line 121) | def ref_revert_pattern_sequence(self, z: torch.Tensor, pattern: Patter...
    method ref_revert_pattern_logits (line 134) | def ref_revert_pattern_logits(self, z: torch.Tensor, pattern: Pattern,...
    method _get_pattern_providers (line 149) | def _get_pattern_providers(self, n_q: int):
    method test_build_pattern_sequence (line 173) | def test_build_pattern_sequence(self, n_q: int, timesteps: int):
    method test_revert_pattern_sequence (line 205) | def test_revert_pattern_sequence(self, n_q: int, timesteps: int):
    method test_revert_pattern_logits (line 228) | def test_revert_pattern_logits(self, n_q: int, timesteps: int, card: i...

FILE: tests/modules/test_conv.py
  function test_get_extra_padding_for_conv1d (line 25) | def test_get_extra_padding_for_conv1d():
  function test_pad1d_zeros (line 30) | def test_pad1d_zeros():
  function test_pad1d_reflect (line 52) | def test_pad1d_reflect():
  function test_unpad1d (line 74) | def test_unpad1d():
  class TestNormConv1d (line 96) | class TestNormConv1d:
    method test_norm_conv1d_modules (line 98) | def test_norm_conv1d_modules(self):
  class TestNormConvTranspose1d (line 123) | class TestNormConvTranspose1d:
    method test_normalizations (line 125) | def test_normalizations(self):
  class TestStreamableConv1d (line 151) | class TestStreamableConv1d:
    method get_streamable_conv1d_output_length (line 153) | def get_streamable_conv1d_output_length(self, length, kernel_size, str...
    method test_streamable_conv1d (line 160) | def test_streamable_conv1d(self):
  class TestStreamableConvTranspose1d (line 176) | class TestStreamableConvTranspose1d:
    method get_streamable_convtr1d_output_length (line 178) | def get_streamable_convtr1d_output_length(self, length, kernel_size, s...
    method test_streamable_convtr1d (line 182) | def test_streamable_convtr1d(self):

FILE: tests/modules/test_lstm.py
  class TestStreamableLSTM (line 13) | class TestStreamableLSTM:
    method test_lstm (line 15) | def test_lstm(self):
    method test_lstm_skip (line 25) | def test_lstm_skip(self):

FILE: tests/modules/test_rope.py
  function test_rope (line 13) | def test_rope():
  function test_rope_io_dtypes (line 26) | def test_rope_io_dtypes():
  function test_transformer_with_rope (line 50) | def test_transformer_with_rope():
  function test_rope_streaming (line 66) | def test_rope_streaming():
  function test_rope_streaming_past_context (line 94) | def test_rope_streaming_past_context():
  function test_rope_memory_efficient (line 124) | def test_rope_memory_efficient():
  function test_rope_with_xpos (line 145) | def test_rope_with_xpos():
  function test_positional_scale (line 158) | def test_positional_scale():

FILE: tests/modules/test_seanet.py
  class TestSEANetModel (line 16) | class TestSEANetModel:
    method test_base (line 18) | def test_base(self):
    method test_causal (line 28) | def test_causal(self):
    method test_conv_skip_connection (line 38) | def test_conv_skip_connection(self):
    method test_seanet_encoder_decoder_final_act (line 48) | def test_seanet_encoder_decoder_final_act(self):
    method _check_encoder_blocks_norm (line 58) | def _check_encoder_blocks_norm(self, encoder: SEANetEncoder, n_disable...
    method test_encoder_disable_norm (line 70) | def test_encoder_disable_norm(self):
    method _check_decoder_blocks_norm (line 79) | def _check_decoder_blocks_norm(self, decoder: SEANetDecoder, n_disable...
    method test_decoder_disable_norm (line 94) | def test_decoder_disable_norm(self):
    method test_disable_norm_raises_exception (line 103) | def test_disable_norm_raises_exception(self):

FILE: tests/modules/test_transformer.py
  function test_transformer_causal_streaming (line 16) | def test_transformer_causal_streaming():
  function test_transformer_vs_pytorch (line 52) | def test_transformer_vs_pytorch():
  function test_streaming_api (line 71) | def test_streaming_api():
  function test_memory_efficient (line 88) | def test_memory_efficient():
  function test_attention_as_float32 (line 108) | def test_attention_as_float32():
  function test_streaming_memory_efficient (line 134) | def test_streaming_memory_efficient():
  function test_cross_attention (line 164) | def test_cross_attention():
  function test_cross_attention_compat (line 192) | def test_cross_attention_compat():
  function test_repeat_kv (line 224) | def test_repeat_kv():
  function test_qk_layer_norm (line 241) | def test_qk_layer_norm():

FILE: tests/quantization/test_vq.py
  class TestResidualVectorQuantizer (line 12) | class TestResidualVectorQuantizer:
    method test_rvq (line 14) | def test_rvq(self):
Condensed preview — 227 files, each showing path, character count, and a content snippet. Download the .json file or copy for the full structured content (1,182K chars).
[
  {
    "path": ".github/actions/audiocraft_build/action.yml",
    "chars": 746,
    "preview": "name: audiocraft_build\ndescription: 'Build audiocraft env.'\nruns:\n  using: \"composite\"\n  steps:\n  - uses: actions/setup-"
  },
  {
    "path": ".gitignore",
    "chars": 629,
    "preview": "# Byte-compiled / optimized / DLL files\n__pycache__\n*.py[cod]\n*$py.class\n\n# C extensions\n*.so\n\n# macOS dir files\n.DS_Sto"
  },
  {
    "path": "CHANGELOG.md",
    "chars": 920,
    "preview": "# Changelog\n\nAll notable changes to this project will be documented in this file.\n\nThe format is based on [Keep a Change"
  },
  {
    "path": "CODE_OF_CONDUCT.md",
    "chars": 3535,
    "preview": "# Code of Conduct\n\n## Our Pledge\n\nIn the interest of fostering an open and welcoming environment, we as\ncontributors and"
  },
  {
    "path": "CONTRIBUTING.md",
    "chars": 1377,
    "preview": "# Contributing to AudioCraft\n\nWe want to make contributing to this project as easy and transparent as\npossible.\n\n## Pull"
  },
  {
    "path": "Dockerfile",
    "chars": 1000,
    "preview": "FROM nvidia/cuda:11.8.0-base-ubuntu22.04\n\nENV DEBIAN_FRONTEND=noninteractive \\\n    PYTHONUNBUFFERED=1 \\\n    PYTHONIOENCO"
  },
  {
    "path": "LICENSE",
    "chars": 1088,
    "preview": "MIT License\n\nCopyright (c) Meta Platforms, Inc. and affiliates.\n\nPermission is hereby granted, free of charge, to any pe"
  },
  {
    "path": "LICENSE_weights",
    "chars": 19329,
    "preview": "Attribution-NonCommercial 4.0 International\n\n=======================================================================\n\nCr"
  },
  {
    "path": "MANIFEST.in",
    "chars": 188,
    "preview": "include Makefile\ninclude LICENSE\ninclude LICENSE_weights\ninclude *.md\ninclude *.ini\ninclude requirements.txt\ninclude aud"
  },
  {
    "path": "Makefile",
    "chars": 1464,
    "preview": "INTEG=AUDIOCRAFT_DORA_DIR=\"/tmp/magma_$(USER)\" python3 -m dora -v run --clear device=cpu dataset.num_workers=0 optim.epo"
  },
  {
    "path": "README.md",
    "chars": 5157,
    "preview": "# AudioCraft Plus\n![docs badge](https://github.com/facebookresearch/audiocraft/workflows/audiocraft_docs/badge.svg)\n![li"
  },
  {
    "path": "app.py",
    "chars": 95768,
    "preview": "# Copyright (c) Meta Platforms, Inc. and affiliates.\n# All rights reserved.\n\n# This source code is licensed under the li"
  },
  {
    "path": "audiocraft/__init__.py",
    "chars": 1247,
    "preview": "# Copyright (c) Meta Platforms, Inc. and affiliates.\n# All rights reserved.\n#\n# This source code is licensed under the l"
  },
  {
    "path": "audiocraft/adversarial/__init__.py",
    "chars": 570,
    "preview": "# Copyright (c) Meta Platforms, Inc. and affiliates.\n# All rights reserved.\n#\n# This source code is licensed under the l"
  },
  {
    "path": "audiocraft/adversarial/discriminators/__init__.py",
    "chars": 346,
    "preview": "# Copyright (c) Meta Platforms, Inc. and affiliates.\n# All rights reserved.\n#\n# This source code is licensed under the l"
  },
  {
    "path": "audiocraft/adversarial/discriminators/base.py",
    "chars": 894,
    "preview": "# Copyright (c) Meta Platforms, Inc. and affiliates.\n# All rights reserved.\n#\n# This source code is licensed under the l"
  },
  {
    "path": "audiocraft/adversarial/discriminators/mpd.py",
    "chars": 4176,
    "preview": "# Copyright (c) Meta Platforms, Inc. and affiliates.\n# All rights reserved.\n#\n# This source code is licensed under the l"
  },
  {
    "path": "audiocraft/adversarial/discriminators/msd.py",
    "chars": 5926,
    "preview": "# Copyright (c) Meta Platforms, Inc. and affiliates.\n# All rights reserved.\n#\n# This source code is licensed under the l"
  },
  {
    "path": "audiocraft/adversarial/discriminators/msstftd.py",
    "chars": 6331,
    "preview": "# Copyright (c) Meta Platforms, Inc. and affiliates.\n# All rights reserved.\n#\n# This source code is licensed under the l"
  },
  {
    "path": "audiocraft/adversarial/losses.py",
    "chars": 9126,
    "preview": "# Copyright (c) Meta Platforms, Inc. and affiliates.\n# All rights reserved.\n#\n# This source code is licensed under the l"
  },
  {
    "path": "audiocraft/data/__init__.py",
    "chars": 396,
    "preview": "# Copyright (c) Meta Platforms, Inc. and affiliates.\n# All rights reserved.\n#\n# This source code is licensed under the l"
  },
  {
    "path": "audiocraft/data/audio.py",
    "chars": 9042,
    "preview": "# Copyright (c) Meta Platforms, Inc. and affiliates.\n# All rights reserved.\n#\n# This source code is licensed under the l"
  },
  {
    "path": "audiocraft/data/audio_dataset.py",
    "chars": 25464,
    "preview": "# Copyright (c) Meta Platforms, Inc. and affiliates.\n# All rights reserved.\n#\n# This source code is licensed under the l"
  },
  {
    "path": "audiocraft/data/audio_utils.py",
    "chars": 7789,
    "preview": "# Copyright (c) Meta Platforms, Inc. and affiliates.\n# All rights reserved.\n#\n# This source code is licensed under the l"
  },
  {
    "path": "audiocraft/data/info_audio_dataset.py",
    "chars": 3902,
    "preview": "# Copyright (c) Meta Platforms, Inc. and affiliates.\n# All rights reserved.\n#\n# This source code is licensed under the l"
  },
  {
    "path": "audiocraft/data/music_dataset.py",
    "chars": 11575,
    "preview": "# Copyright (c) Meta Platforms, Inc. and affiliates.\n# All rights reserved.\n#\n# This source code is licensed under the l"
  },
  {
    "path": "audiocraft/data/sound_dataset.py",
    "chars": 13381,
    "preview": "# Copyright (c) Meta Platforms, Inc. and affiliates.\n# All rights reserved.\n#\n# This source code is licensed under the l"
  },
  {
    "path": "audiocraft/data/zip.py",
    "chars": 2202,
    "preview": "# Copyright (c) Meta Platforms, Inc. and affiliates.\n# All rights reserved.\n#\n# This source code is licensed under the l"
  },
  {
    "path": "audiocraft/environment.py",
    "chars": 6741,
    "preview": "# Copyright (c) Meta Platforms, Inc. and affiliates.\n# All rights reserved.\n#\n# This source code is licensed under the l"
  },
  {
    "path": "audiocraft/grids/__init__.py",
    "chars": 216,
    "preview": "# Copyright (c) Meta Platforms, Inc. and affiliates.\n# All rights reserved.\n#\n# This source code is licensed under the l"
  },
  {
    "path": "audiocraft/grids/_base_explorers.py",
    "chars": 2639,
    "preview": "# Copyright (c) Meta Platforms, Inc. and affiliates.\n# All rights reserved.\n#\n# This source code is licensed under the l"
  },
  {
    "path": "audiocraft/grids/audiogen/__init__.py",
    "chars": 220,
    "preview": "# Copyright (c) Meta Platforms, Inc. and affiliates.\n# All rights reserved.\n#\n# This source code is licensed under the l"
  },
  {
    "path": "audiocraft/grids/audiogen/audiogen_base_16khz.py",
    "chars": 776,
    "preview": "# Copyright (c) Meta Platforms, Inc. and affiliates.\n# All rights reserved.\n#\n# This source code is licensed under the l"
  },
  {
    "path": "audiocraft/grids/audiogen/audiogen_pretrained_16khz_eval.py",
    "chars": 2483,
    "preview": "# Copyright (c) Meta Platforms, Inc. and affiliates.\n# All rights reserved.\n#\n# This source code is licensed under the l"
  },
  {
    "path": "audiocraft/grids/compression/__init__.py",
    "chars": 219,
    "preview": "# Copyright (c) Meta Platforms, Inc. and affiliates.\n# All rights reserved.\n#\n# This source code is licensed under the l"
  },
  {
    "path": "audiocraft/grids/compression/_explorers.py",
    "chars": 1601,
    "preview": "# Copyright (c) Meta Platforms, Inc. and affiliates.\n# All rights reserved.\n#\n# This source code is licensed under the l"
  },
  {
    "path": "audiocraft/grids/compression/debug.py",
    "chars": 1117,
    "preview": "# Copyright (c) Meta Platforms, Inc. and affiliates.\n# All rights reserved.\n#\n# This source code is licensed under the l"
  },
  {
    "path": "audiocraft/grids/compression/encodec_audiogen_16khz.py",
    "chars": 1100,
    "preview": "# Copyright (c) Meta Platforms, Inc. and affiliates.\n# All rights reserved.\n#\n# This source code is licensed under the l"
  },
  {
    "path": "audiocraft/grids/compression/encodec_base_24khz.py",
    "chars": 956,
    "preview": "# Copyright (c) Meta Platforms, Inc. and affiliates.\n# All rights reserved.\n#\n# This source code is licensed under the l"
  },
  {
    "path": "audiocraft/grids/compression/encodec_musicgen_32khz.py",
    "chars": 1262,
    "preview": "# Copyright (c) Meta Platforms, Inc. and affiliates.\n# All rights reserved.\n#\n# This source code is licensed under the l"
  },
  {
    "path": "audiocraft/grids/diffusion/4_bands_base_32khz.py",
    "chars": 1073,
    "preview": "# Copyright (c) Meta Platforms, Inc. and affiliates.\n# All rights reserved.\n#\n# This source code is licensed under the l"
  },
  {
    "path": "audiocraft/grids/diffusion/__init__.py",
    "chars": 221,
    "preview": "# Copyright (c) Meta Platforms, Inc. and affiliates.\n# All rights reserved.\n#\n# This source code is licensed under the l"
  },
  {
    "path": "audiocraft/grids/diffusion/_explorers.py",
    "chars": 2066,
    "preview": "# Copyright (c) Meta Platforms, Inc. and affiliates.\n# All rights reserved.\n#\n# This source code is licensed under the l"
  },
  {
    "path": "audiocraft/grids/musicgen/__init__.py",
    "chars": 220,
    "preview": "# Copyright (c) Meta Platforms, Inc. and affiliates.\n# All rights reserved.\n#\n# This source code is licensed under the l"
  },
  {
    "path": "audiocraft/grids/musicgen/_explorers.py",
    "chars": 3092,
    "preview": "# Copyright (c) Meta Platforms, Inc. and affiliates.\n# All rights reserved.\n#\n# This source code is licensed under the l"
  },
  {
    "path": "audiocraft/grids/musicgen/musicgen_base_32khz.py",
    "chars": 1413,
    "preview": "# Copyright (c) Meta Platforms, Inc. and affiliates.\n# All rights reserved.\n#\n# This source code is licensed under the l"
  },
  {
    "path": "audiocraft/grids/musicgen/musicgen_base_cached_32khz.py",
    "chars": 2311,
    "preview": "# Copyright (c) Meta Platforms, Inc. and affiliates.\n# All rights reserved.\n#\n# This source code is licensed under the l"
  },
  {
    "path": "audiocraft/grids/musicgen/musicgen_clapemb_32khz.py",
    "chars": 1193,
    "preview": "# Copyright (c) Meta Platforms, Inc. and affiliates.\n# All rights reserved.\n#\n# This source code is licensed under the l"
  },
  {
    "path": "audiocraft/grids/musicgen/musicgen_melody_32khz.py",
    "chars": 2251,
    "preview": "# Copyright (c) Meta Platforms, Inc. and affiliates.\n# All rights reserved.\n#\n# This source code is licensed under the l"
  },
  {
    "path": "audiocraft/grids/musicgen/musicgen_pretrained_32khz_eval.py",
    "chars": 3880,
    "preview": "# Copyright (c) Meta Platforms, Inc. and affiliates.\n# All rights reserved.\n#\n# This source code is licensed under the l"
  },
  {
    "path": "audiocraft/losses/__init__.py",
    "chars": 585,
    "preview": "# Copyright (c) Meta Platforms, Inc. and affiliates.\n# All rights reserved.\n#\n# This source code is licensed under the l"
  },
  {
    "path": "audiocraft/losses/balancer.py",
    "chars": 6612,
    "preview": "# Copyright (c) Meta Platforms, Inc. and affiliates.\n# All rights reserved.\n#\n# This source code is licensed under the l"
  },
  {
    "path": "audiocraft/losses/sisnr.py",
    "chars": 2914,
    "preview": "# Copyright (c) Meta Platforms, Inc. and affiliates.\n# All rights reserved.\n#\n# This source code is licensed under the l"
  },
  {
    "path": "audiocraft/losses/specloss.py",
    "chars": 6531,
    "preview": "# Copyright (c) Meta Platforms, Inc. and affiliates.\n# All rights reserved.\n#\n# This source code is licensed under the l"
  },
  {
    "path": "audiocraft/losses/stftloss.py",
    "chars": 8202,
    "preview": "# Copyright (c) Meta Platforms, Inc. and affiliates.\n# All rights reserved.\n#\n# This source code is licensed under the l"
  },
  {
    "path": "audiocraft/metrics/__init__.py",
    "chars": 592,
    "preview": "# Copyright (c) Meta Platforms, Inc. and affiliates.\n# All rights reserved.\n#\n# This source code is licensed under the l"
  },
  {
    "path": "audiocraft/metrics/chroma_cosinesim.py",
    "chars": 3674,
    "preview": "# Copyright (c) Meta Platforms, Inc. and affiliates.\n# All rights reserved.\n#\n# This source code is licensed under the l"
  },
  {
    "path": "audiocraft/metrics/clap_consistency.py",
    "chars": 4525,
    "preview": "# Copyright (c) Meta Platforms, Inc. and affiliates.\n# All rights reserved.\n#\n# This source code is licensed under the l"
  },
  {
    "path": "audiocraft/metrics/fad.py",
    "chars": 17721,
    "preview": "# Copyright (c) Meta Platforms, Inc. and affiliates.\n# All rights reserved.\n#\n# This source code is licensed under the l"
  },
  {
    "path": "audiocraft/metrics/kld.py",
    "chars": 10211,
    "preview": "# Copyright (c) Meta Platforms, Inc. and affiliates.\n# All rights reserved.\n#\n# This source code is licensed under the l"
  },
  {
    "path": "audiocraft/metrics/rvm.py",
    "chars": 6107,
    "preview": "# Copyright (c) Meta Platforms, Inc. and affiliates.\n# All rights reserved.\n#\n# This source code is licensed under the l"
  },
  {
    "path": "audiocraft/metrics/visqol.py",
    "chars": 9694,
    "preview": "# Copyright (c) Meta Platforms, Inc. and affiliates.\n# All rights reserved.\n#\n# This source code is licensed under the l"
  },
  {
    "path": "audiocraft/models/__init__.py",
    "chars": 605,
    "preview": "# Copyright (c) Meta Platforms, Inc. and affiliates.\n# All rights reserved.\n#\n# This source code is licensed under the l"
  },
  {
    "path": "audiocraft/models/audiogen.py",
    "chars": 12774,
    "preview": "# Copyright (c) Meta Platforms, Inc. and affiliates.\n# All rights reserved.\n#\n# This source code is licensed under the l"
  },
  {
    "path": "audiocraft/models/builders.py",
    "chars": 9860,
    "preview": "# Copyright (c) Meta Platforms, Inc. and affiliates.\n# All rights reserved.\n#\n# This source code is licensed under the l"
  },
  {
    "path": "audiocraft/models/encodec.py",
    "chars": 13546,
    "preview": "# Copyright (c) Meta Platforms, Inc. and affiliates.\n# All rights reserved.\n#\n# This source code is licensed under the l"
  },
  {
    "path": "audiocraft/models/lm.py",
    "chars": 27027,
    "preview": "# Copyright (c) Meta Platforms, Inc. and affiliates.\n# All rights reserved.\n#\n# This source code is licensed under the l"
  },
  {
    "path": "audiocraft/models/loaders.py",
    "chars": 5275,
    "preview": "# Copyright (c) Meta Platforms, Inc. and affiliates.\n# All rights reserved.\n#\n# This source code is licensed under the l"
  },
  {
    "path": "audiocraft/models/multibanddiffusion.py",
    "chars": 8855,
    "preview": "# Copyright (c) Meta Platforms, Inc. and affiliates.\n# All rights reserved.\n#\n# This source code is licensed under the l"
  },
  {
    "path": "audiocraft/models/musicgen.py",
    "chars": 19960,
    "preview": "# Copyright (c) Meta Platforms, Inc. and affiliates.\n# All rights reserved.\n#\n# This source code is licensed under the l"
  },
  {
    "path": "audiocraft/models/unet.py",
    "chars": 8340,
    "preview": "# Copyright (c) Meta Platforms, Inc. and affiliates.\n# All rights reserved.\n#\n# This source code is licensed under the l"
  },
  {
    "path": "audiocraft/modules/__init__.py",
    "chars": 586,
    "preview": "# Copyright (c) Meta Platforms, Inc. and affiliates.\n# All rights reserved.\n#\n# This source code is licensed under the l"
  },
  {
    "path": "audiocraft/modules/activations.py",
    "chars": 3266,
    "preview": "# Copyright (c) Meta Platforms, Inc. and affiliates.\n# All rights reserved.\n#\n# This source code is licensed under the l"
  },
  {
    "path": "audiocraft/modules/chroma.py",
    "chars": 3023,
    "preview": "# Copyright (c) Meta Platforms, Inc. and affiliates.\n# All rights reserved.\n#\n# This source code is licensed under the l"
  },
  {
    "path": "audiocraft/modules/codebooks_patterns.py",
    "chars": 27624,
    "preview": "# Copyright (c) Meta Platforms, Inc. and affiliates.\n# All rights reserved.\n#\n# This source code is licensed under the l"
  },
  {
    "path": "audiocraft/modules/conditioners.py",
    "chars": 64621,
    "preview": "# Copyright (c) Meta Platforms, Inc. and affiliates.\n# All rights reserved.\n#\n# This source code is licensed under the l"
  },
  {
    "path": "audiocraft/modules/conv.py",
    "chars": 10496,
    "preview": "# Copyright (c) Meta Platforms, Inc. and affiliates.\n# All rights reserved.\n#\n# This source code is licensed under the l"
  },
  {
    "path": "audiocraft/modules/diffusion_schedule.py",
    "chars": 12018,
    "preview": "# Copyright (c) Meta Platforms, Inc. and affiliates.\n# All rights reserved.\n#\n# This source code is licensed under the l"
  },
  {
    "path": "audiocraft/modules/lstm.py",
    "chars": 759,
    "preview": "# Copyright (c) Meta Platforms, Inc. and affiliates.\n# All rights reserved.\n#\n# This source code is licensed under the l"
  },
  {
    "path": "audiocraft/modules/rope.py",
    "chars": 5413,
    "preview": "# Copyright (c) Meta Platforms, Inc. and affiliates.\n# All rights reserved.\n#\n# This source code is licensed under the l"
  },
  {
    "path": "audiocraft/modules/seanet.py",
    "chars": 13868,
    "preview": "# Copyright (c) Meta Platforms, Inc. and affiliates.\n# All rights reserved.\n#\n# This source code is licensed under the l"
  },
  {
    "path": "audiocraft/modules/streaming.py",
    "chars": 4494,
    "preview": "# Copyright (c) Meta Platforms, Inc. and affiliates.\n# All rights reserved.\n#\n# This source code is licensed under the l"
  },
  {
    "path": "audiocraft/modules/transformer.py",
    "chars": 36868,
    "preview": "# Copyright (c) Meta Platforms, Inc. and affiliates.\n# All rights reserved.\n#\n# This source code is licensed under the l"
  },
  {
    "path": "audiocraft/optim/__init__.py",
    "chars": 638,
    "preview": "# Copyright (c) Meta Platforms, Inc. and affiliates.\n# All rights reserved.\n#\n# This source code is licensed under the l"
  },
  {
    "path": "audiocraft/optim/cosine_lr_scheduler.py",
    "chars": 1730,
    "preview": "# Copyright (c) Meta Platforms, Inc. and affiliates.\n# All rights reserved.\n#\n# This source code is licensed under the l"
  },
  {
    "path": "audiocraft/optim/dadam.py",
    "chars": 9002,
    "preview": "# Copyright (c) Meta Platforms, Inc. and affiliates.\n# All rights reserved.\n#\n# This source code is licensed under the l"
  },
  {
    "path": "audiocraft/optim/ema.py",
    "chars": 3196,
    "preview": "# Copyright (c) Meta Platforms, Inc. and affiliates.\n# All rights reserved.\n#\n# This source code is licensed under the l"
  },
  {
    "path": "audiocraft/optim/fsdp.py",
    "chars": 7284,
    "preview": "# Copyright (c) Meta Platforms, Inc. and affiliates.\n# All rights reserved.\n#\n# This source code is licensed under the l"
  },
  {
    "path": "audiocraft/optim/inverse_sqrt_lr_scheduler.py",
    "chars": 1390,
    "preview": "# Copyright (c) Meta Platforms, Inc. and affiliates.\n# All rights reserved.\n#\n# This source code is licensed under the l"
  },
  {
    "path": "audiocraft/optim/linear_warmup_lr_scheduler.py",
    "chars": 1272,
    "preview": "# Copyright (c) Meta Platforms, Inc. and affiliates.\n# All rights reserved.\n#\n# This source code is licensed under the l"
  },
  {
    "path": "audiocraft/optim/polynomial_decay_lr_scheduler.py",
    "chars": 2012,
    "preview": "# Copyright (c) Meta Platforms, Inc. and affiliates.\n# All rights reserved.\n#\n# This source code is licensed under the l"
  },
  {
    "path": "audiocraft/py.typed",
    "chars": 0,
    "preview": ""
  },
  {
    "path": "audiocraft/quantization/__init__.py",
    "chars": 329,
    "preview": "# Copyright (c) Meta Platforms, Inc. and affiliates.\n# All rights reserved.\n#\n# This source code is licensed under the l"
  },
  {
    "path": "audiocraft/quantization/base.py",
    "chars": 3314,
    "preview": "# Copyright (c) Meta Platforms, Inc. and affiliates.\n# All rights reserved.\n#\n# This source code is licensed under the l"
  },
  {
    "path": "audiocraft/quantization/core_vq.py",
    "chars": 14357,
    "preview": "# Copyright (c) Meta Platforms, Inc. and affiliates.\n# All rights reserved.\n#\n# This source code is licensed under the l"
  },
  {
    "path": "audiocraft/quantization/vq.py",
    "chars": 4649,
    "preview": "# Copyright (c) Meta Platforms, Inc. and affiliates.\n# All rights reserved.\n#\n# This source code is licensed under the l"
  },
  {
    "path": "audiocraft/solvers/__init__.py",
    "chars": 574,
    "preview": "# Copyright (c) Meta Platforms, Inc. and affiliates.\n# All rights reserved.\n#\n# This source code is licensed under the l"
  },
  {
    "path": "audiocraft/solvers/audiogen.py",
    "chars": 655,
    "preview": "# Copyright (c) Meta Platforms, Inc. and affiliates.\n# All rights reserved.\n#\n# This source code is licensed under the l"
  },
  {
    "path": "audiocraft/solvers/base.py",
    "chars": 31355,
    "preview": "# Copyright (c) Meta Platforms, Inc. and affiliates.\n# All rights reserved.\n#\n# This source code is licensed under the l"
  },
  {
    "path": "audiocraft/solvers/builders.py",
    "chars": 13937,
    "preview": "# Copyright (c) Meta Platforms, Inc. and affiliates.\n# All rights reserved.\n#\n# This source code is licensed under the l"
  },
  {
    "path": "audiocraft/solvers/compression.py",
    "chars": 14774,
    "preview": "# Copyright (c) Meta Platforms, Inc. and affiliates.\n# All rights reserved.\n#\n# This source code is licensed under the l"
  },
  {
    "path": "audiocraft/solvers/diffusion.py",
    "chars": 11336,
    "preview": "# Copyright (c) Meta Platforms, Inc. and affiliates.\n# All rights reserved.\n#\n# This source code is licensed under the l"
  },
  {
    "path": "audiocraft/solvers/musicgen.py",
    "chars": 34733,
    "preview": "# Copyright (c) Meta Platforms, Inc. and affiliates.\n# All rights reserved.\n#\n# This source code is licensed under the l"
  },
  {
    "path": "audiocraft/train.py",
    "chars": 6453,
    "preview": "# Copyright (c) Meta Platforms, Inc. and affiliates.\n# All rights reserved.\n#\n# This source code is licensed under the l"
  },
  {
    "path": "audiocraft/utils/__init__.py",
    "chars": 215,
    "preview": "# Copyright (c) Meta Platforms, Inc. and affiliates.\n# All rights reserved.\n#\n# This source code is licensed under the l"
  },
  {
    "path": "audiocraft/utils/autocast.py",
    "chars": 1377,
    "preview": "# Copyright (c) Meta Platforms, Inc. and affiliates.\n# All rights reserved.\n#\n# This source code is licensed under the l"
  },
  {
    "path": "audiocraft/utils/best_state.py",
    "chars": 3694,
    "preview": "# Copyright (c) Meta Platforms, Inc. and affiliates.\n# All rights reserved.\n#\n# This source code is licensed under the l"
  },
  {
    "path": "audiocraft/utils/cache.py",
    "chars": 14289,
    "preview": "# Copyright (c) Meta Platforms, Inc. and affiliates.\n# All rights reserved.\n#\n# This source code is licensed under the l"
  },
  {
    "path": "audiocraft/utils/checkpoint.py",
    "chars": 6129,
    "preview": "# Copyright (c) Meta Platforms, Inc. and affiliates.\n# All rights reserved.\n#\n# This source code is licensed under the l"
  },
  {
    "path": "audiocraft/utils/cluster.py",
    "chars": 2044,
    "preview": "# Copyright (c) Meta Platforms, Inc. and affiliates.\n# All rights reserved.\n#\n# This source code is licensed under the l"
  },
  {
    "path": "audiocraft/utils/deadlock.py",
    "chars": 1710,
    "preview": "# Copyright (c) Meta Platforms, Inc. and affiliates.\n# All rights reserved.\n#\n# This source code is licensed under the l"
  },
  {
    "path": "audiocraft/utils/export.py",
    "chars": 2677,
    "preview": "# Copyright (c) Meta Platforms, Inc. and affiliates.\n# All rights reserved.\n#\n# This source code is licensed under the l"
  },
  {
    "path": "audiocraft/utils/export_legacy.py",
    "chars": 1906,
    "preview": "# Copyright (c) Meta Platforms, Inc. and affiliates.\n# All rights reserved.\n#\n# This source code is licensed under the l"
  },
  {
    "path": "audiocraft/utils/notebook.py",
    "chars": 885,
    "preview": "# Copyright (c) Meta Platforms, Inc. and affiliates.\n# All rights reserved.\n#\n# This source code is licensed under the l"
  },
  {
    "path": "audiocraft/utils/profiler.py",
    "chars": 1209,
    "preview": "# Copyright (c) Meta Platforms, Inc. and affiliates.\n# All rights reserved.\n#\n# This source code is licensed under the l"
  },
  {
    "path": "audiocraft/utils/samples/__init__.py",
    "chars": 198,
    "preview": "# Copyright (c) Meta Platforms, Inc. and affiliates.\n# All rights reserved.\n#\n# This source code is licensed under the l"
  },
  {
    "path": "audiocraft/utils/samples/manager.py",
    "chars": 19385,
    "preview": "# Copyright (c) Meta Platforms, Inc. and affiliates.\n# All rights reserved.\n#\n# This source code is licensed under the l"
  },
  {
    "path": "audiocraft/utils/ui.py",
    "chars": 939,
    "preview": "from pathlib import Path\n\nimport gradio as gr\nimport torch\n\nrefresh_symbol = '\\U0001f504'  # 🔄\n\nclass ToolButton(gr.Butt"
  },
  {
    "path": "audiocraft/utils/utils.py",
    "chars": 10599,
    "preview": "# Copyright (c) Meta Platforms, Inc. and affiliates.\n# All rights reserved.\n#\n# This source code is licensed under the l"
  },
  {
    "path": "config/conditioner/chroma2music.yaml",
    "chars": 817,
    "preview": "# @package __global__\n\nclassifier_free_guidance:\n  training_dropout: 0.2\n  inference_coef: 3.0\n\nattribute_dropout:\n  arg"
  },
  {
    "path": "config/conditioner/clapemb2music.yaml",
    "chars": 922,
    "preview": "# @package __global__\n\nclassifier_free_guidance:\n  training_dropout: 0.3\n  inference_coef: 3.0\n\nattribute_dropout:\n  tex"
  },
  {
    "path": "config/conditioner/none.yaml",
    "chars": 239,
    "preview": "# @package __global__\n\n# No conditioning\n\nclassifier_free_guidance:\n  training_dropout: 0\n  inference_coef: 1\n\nattribute"
  },
  {
    "path": "config/conditioner/text2music.yaml",
    "chars": 496,
    "preview": "# @package __global__\n\nclassifier_free_guidance:\n  training_dropout: 0.3\n  inference_coef: 3.0\n\nattribute_dropout: {}\n\nf"
  },
  {
    "path": "config/conditioner/text2sound.yaml",
    "chars": 411,
    "preview": "# @package __global__\n\nclassifier_free_guidance:\n  training_dropout: 0.1\n  inference_coef: 3.0\n\nattribute_dropout: {}\n\nf"
  },
  {
    "path": "config/config.yaml",
    "chars": 2733,
    "preview": "# WARNING: This is the base configuration file shared across ALL solvers in AudioCraft\n# Please don't update this file d"
  },
  {
    "path": "config/dset/audio/audiocaps_16khz.yaml",
    "chars": 281,
    "preview": "# @package __global__\n\n# AudioCaps dataset\ndatasource:\n  max_sample_rate: 16000\n  max_channels: 1\n\n  train: null  # only"
  },
  {
    "path": "config/dset/audio/default.yaml",
    "chars": 138,
    "preview": "# @package __global__\n\ndatasource:\n  max_sample_rate: ???\n  max_channels: ???\n\n  train: ???\n  valid: ???\n  evaluate: ???"
  },
  {
    "path": "config/dset/audio/example.yaml",
    "chars": 169,
    "preview": "# @package __global__\n\ndatasource:\n  max_sample_rate: 44100\n  max_channels: 2\n\n  train: egs/example\n  valid: egs/example"
  },
  {
    "path": "config/dset/audio/musiccaps_32khz.yaml",
    "chars": 358,
    "preview": "# @package __global__\n\n# total samples obtained from MusicCaps = 5469\n# (out of 5521 due to AudioSet corrupted samples)\n"
  },
  {
    "path": "config/dset/default.yaml",
    "chars": 300,
    "preview": "# @package __global__\n\n# WARNING: This is a base configuration file shared across ALL solvers in AudioCraft\n# Please don"
  },
  {
    "path": "config/dset/internal/music_10k_32khz.yaml",
    "chars": 338,
    "preview": "# @package __global__\n\n# high quality music dataset with no artist overlap between splits\ndatasource:\n  max_sample_rate:"
  },
  {
    "path": "config/dset/internal/music_400k_32khz.yaml",
    "chars": 275,
    "preview": "# @package __global__\n\ndatasource:\n  max_sample_rate: 32000\n  max_channels: 1\n\n  train: egs/music/music_400k_32khz/train"
  },
  {
    "path": "config/dset/internal/sounds_16khz.yaml",
    "chars": 344,
    "preview": "# @package __global__\n\n# environmental sounds dataset compiling all datasets\n# with applied filters on tags\ndatasource:\n"
  },
  {
    "path": "config/model/encodec/default.yaml",
    "chars": 1112,
    "preview": "# @package __global__\n\ncompression_model: encodec\n\nencodec:\n  autoencoder: seanet\n  quantizer: rvq\n  sample_rate: ${samp"
  },
  {
    "path": "config/model/encodec/encodec_base_causal.yaml",
    "chars": 112,
    "preview": "# @package __global__\n\ndefaults:\n  - encodec/default\n\nencodec:\n  causal: true\n\nrvq:\n  n_q: 32\n  q_dropout: true\n"
  },
  {
    "path": "config/model/encodec/encodec_large_nq4_s320.yaml",
    "chars": 161,
    "preview": "# @package __global__\n\ndefaults:\n  - encodec/default\n\nseanet:\n  # default ratios are [8, 5, 4, 2]\n  n_filters: 64\n\nrvq:\n"
  },
  {
    "path": "config/model/encodec/encodec_large_nq4_s640.yaml",
    "chars": 148,
    "preview": "# @package __global__\n\ndefaults:\n  - encodec/default\n\nseanet:\n  ratios: [8, 5, 4, 4]\n  n_filters: 64\n\nrvq:\n  bins: 2048\n"
  },
  {
    "path": "config/model/lm/audiogen_lm.yaml",
    "chars": 731,
    "preview": "# @package __global__\n\ndefaults:\n  - lm/default\n  - override /conditioner: text2sound\n  - override /model/lm/model_scale"
  },
  {
    "path": "config/model/lm/default.yaml",
    "chars": 1992,
    "preview": "# @package __global__\ndefaults:\n  - _self_\n  - /model/lm/model_scale: base # prefer this group to set model scale instea"
  },
  {
    "path": "config/model/lm/model_scale/base.yaml",
    "chars": 102,
    "preview": "# @package __global__\n\n# overrides nothing because default is already transformer base (~ 60M params)\n"
  },
  {
    "path": "config/model/lm/model_scale/large.yaml",
    "chars": 126,
    "preview": "# @package _global_\n\n# gpt2 inspired, even bigger (~3.3B params)\ntransformer_lm:\n  dim: 2048\n  num_heads: 32\n  num_layer"
  },
  {
    "path": "config/model/lm/model_scale/medium.yaml",
    "chars": 109,
    "preview": "# @package _global_\n\n# gpt2 like (~1.5B params)\ntransformer_lm:\n  dim: 1536\n  num_heads: 24\n  num_layers: 48\n"
  },
  {
    "path": "config/model/lm/model_scale/small.yaml",
    "chars": 97,
    "preview": "# @package _global_\n\n# 300M Param.\n\ntransformer_lm:\n  dim: 1024\n  num_heads: 16\n  num_layers: 24\n"
  },
  {
    "path": "config/model/lm/model_scale/xsmall.yaml",
    "chars": 181,
    "preview": "# @package _global_\n# just used for debugging or when we just want to populate the cache\n# and do not care about trainin"
  },
  {
    "path": "config/model/lm/musicgen_lm.yaml",
    "chars": 731,
    "preview": "# @package __global__\n\ndefaults:\n  - lm/default\n  - override /conditioner: text2music\n  - override /model/lm/model_scale"
  },
  {
    "path": "config/model/none.yaml",
    "chars": 157,
    "preview": "# @package __global__\n\n# This file exist so that model is recognized as a config group\n# by Hydra, and Dora. A bit weird"
  },
  {
    "path": "config/model/score/basic.yaml",
    "chars": 269,
    "preview": "# @package _global_\n\ndiffusion_unet:\n  hidden: 48\n  depth: 4\n  res_blocks: 1\n  norm_groups: 4\n  kernel: 8\n  stride: 4\n  "
  },
  {
    "path": "config/solver/audiogen/audiogen_base_16khz.yaml",
    "chars": 1772,
    "preview": "# @package __global__\n\n# This is the training loop solver\n# for the base AudioGen model (text-to-sound)\n# on monophonic "
  },
  {
    "path": "config/solver/audiogen/debug.yaml",
    "chars": 862,
    "preview": "# @package __global__\n\n# This is a minimal debugging configuration\n# for MusicGen training solver\ndefaults:\n  - audiogen"
  },
  {
    "path": "config/solver/audiogen/default.yaml",
    "chars": 769,
    "preview": "# @package __global__\n\ndefaults:\n  - /solver/musicgen/default\n  - _self_\n  - /solver/audiogen/evaluation: none\n  - overr"
  },
  {
    "path": "config/solver/audiogen/evaluation/none.yaml",
    "chars": 67,
    "preview": "# @package __global__\n\ndataset:\n  evaluate:\n    num_samples: 10000\n"
  },
  {
    "path": "config/solver/audiogen/evaluation/objective_eval.yaml",
    "chars": 725,
    "preview": "# @package __global__\n\n# Setup for execute only on audiocaps for audio generation\n# evaluation with objective metrics\n# "
  },
  {
    "path": "config/solver/compression/debug.yaml",
    "chars": 812,
    "preview": "# @package __global__\n\ndefaults:\n  - compression/default\n  - /model: encodec/encodec_base_causal\n  - override /dset: aud"
  },
  {
    "path": "config/solver/compression/default.yaml",
    "chars": 2955,
    "preview": "# @package __global__\n\ndefaults:\n  - ../default\n  - override /dset: audio/default\n  - _self_\n\nsolver: compression\nsample"
  },
  {
    "path": "config/solver/compression/encodec_audiogen_16khz.yaml",
    "chars": 177,
    "preview": "# @package __global__\n\ndefaults:\n  - compression/default\n  - /model: encodec/encodec_large_nq4_s320\n  - override /dset: "
  },
  {
    "path": "config/solver/compression/encodec_base_24khz.yaml",
    "chars": 174,
    "preview": "# @package __global__\n\ndefaults:\n  - compression/default\n  - /model: encodec/encodec_base_causal\n  - override /dset: aud"
  },
  {
    "path": "config/solver/compression/encodec_musicgen_32khz.yaml",
    "chars": 177,
    "preview": "# @package __global__\n\ndefaults:\n  - compression/default\n  - /model: encodec/encodec_large_nq4_s640\n  - override /dset: "
  },
  {
    "path": "config/solver/default.yaml",
    "chars": 2534,
    "preview": "# @package __global__\n\n# WARNING: This is a base configuration file shared across ALL solvers in AudioCraft\n# Please don"
  },
  {
    "path": "config/solver/diffusion/debug.yaml",
    "chars": 1677,
    "preview": "# @package __global__\n\ndefaults:\n  - /solver/default\n  - /model: score/basic\n  - override /dset: audio/default\n  - _self"
  },
  {
    "path": "config/solver/diffusion/default.yaml",
    "chars": 1690,
    "preview": "# @package __global__\n\ndefaults:\n  - /solver/default\n  - /model: score/basic\n  - override /dset: audio/default\n  - _self"
  },
  {
    "path": "config/solver/diffusion/encodec_24khz.yaml",
    "chars": 197,
    "preview": "# @package __global__\n\ndefaults:\n  - diffusion/default\n  - _self_\n\n\nsample_rate: 24000\nchannels: 1\ncompression_model_che"
  },
  {
    "path": "config/solver/musicgen/debug.yaml",
    "chars": 931,
    "preview": "# @package __global__\n\n# This is a minimal debugging configuration\n# for MusicGen training solver\ndefaults:\n  - musicgen"
  },
  {
    "path": "config/solver/musicgen/default.yaml",
    "chars": 2343,
    "preview": "# @package __global__\n\ndefaults:\n  - /solver/default\n  - /conditioner: none\n  - _self_\n  - /solver/musicgen/evaluation: "
  },
  {
    "path": "config/solver/musicgen/evaluation/none.yaml",
    "chars": 67,
    "preview": "# @package __global__\n\ndataset:\n  evaluate:\n    num_samples: 10000\n"
  },
  {
    "path": "config/solver/musicgen/evaluation/objective_eval.yaml",
    "chars": 619,
    "preview": "# @package __global__\n\n# Setup for execute only on musiccaps for audio generation\n# evaluation with objective metrics\n# "
  },
  {
    "path": "config/solver/musicgen/musicgen_base_32khz.yaml",
    "chars": 1125,
    "preview": "# @package __global__\n\n# This is the training loop solver\n# for the base MusicGen model (text-to-music)\n# on monophonic "
  },
  {
    "path": "config/solver/musicgen/musicgen_melody_32khz.yaml",
    "chars": 1174,
    "preview": "# @package __global__\n\n# This is the training loop solver\n# for the melody MusicGen model (text+chroma to music)\n# on mo"
  },
  {
    "path": "config/teams/default.yaml",
    "chars": 324,
    "preview": "default:\n  dora_dir: /tmp/audiocraft_${oc.env:USER}\n  partitions:\n    global: debug\n    team: debug\n  reference_dir: /tm"
  },
  {
    "path": "config/teams/labs.yaml",
    "chars": 836,
    "preview": "aws:\n  dora_dir: /fsx-audio-craft-llm/${oc.env:USER}/experiments/audiocraft/outputs\n  partitions:\n    global: learnlab\n "
  },
  {
    "path": "dataset/example/electro_1.json",
    "chars": 322,
    "preview": "{\"key\": \"\", \"artist\": \"Voyager I\", \"sample_rate\": 48000, \"file_extension\": \"mp3\", \"description\": \"A cool song from Voyag"
  },
  {
    "path": "dataset/example/electro_2.json",
    "chars": 300,
    "preview": "{\"key\": \"\", \"artist\": \"Voyager I\", \"sample_rate\": 44100, \"file_extension\": \"mp3\", \"description\": \"This is an electronic "
  },
  {
    "path": "demos/audiogen_demo.ipynb",
    "chars": 5463,
    "preview": "{\n \"cells\": [\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"# AudioGen\\n\",\n    \"Welcome to Audi"
  },
  {
    "path": "demos/musicgen_app.py",
    "chars": 18223,
    "preview": "# Copyright (c) Meta Platforms, Inc. and affiliates.\n# All rights reserved.\n\n# This source code is licensed under the li"
  },
  {
    "path": "demos/musicgen_demo.ipynb",
    "chars": 7855,
    "preview": "{\n \"cells\": [\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"# MusicGen\\n\",\n    \"Welcome to Musi"
  },
  {
    "path": "dockerignore",
    "chars": 6,
    "preview": "cache/"
  },
  {
    "path": "docs/AUDIOGEN.md",
    "chars": 6553,
    "preview": "# AudioGen: Textually-guided audio generation\n\nAudioCraft provides the code and a model re-implementing AudioGen, a [tex"
  },
  {
    "path": "docs/CONDITIONING.md",
    "chars": 7369,
    "preview": "# AudioCraft conditioning modules\n\nAudioCraft provides a\n[modular implementation of conditioning modules](../audiocraft/"
  },
  {
    "path": "docs/DATASETS.md",
    "chars": 3073,
    "preview": "# AudioCraft datasets\n\nOur dataset manifest files consist in 1-json-per-line files, potentially gzipped,\nas `data.jsons`"
  },
  {
    "path": "docs/ENCODEC.md",
    "chars": 6803,
    "preview": "# EnCodec: High Fidelity Neural Audio Compression\n\nAudioCraft provides the training code for EnCodec, a state-of-the-art"
  },
  {
    "path": "docs/MBD.md",
    "chars": 5101,
    "preview": "# MultiBand Diffusion\n\nAudioCraft provides the code and models for MultiBand Diffusion, [From Discrete Tokens to High Fi"
  },
  {
    "path": "docs/METRICS.md",
    "chars": 5412,
    "preview": "# AudioCraft objective metrics\n\nIn addition to training losses, AudioCraft provides a set of objective metrics\nfor audio"
  },
  {
    "path": "docs/MUSICGEN.md",
    "chars": 16299,
    "preview": "# MusicGen: Simple and Controllable Music Generation\n\nAudioCraft provides the code and models for MusicGen, [a simple an"
  },
  {
    "path": "docs/TRAINING.md",
    "chars": 14421,
    "preview": "# AudioCraft training pipelines\n\nAudioCraft training pipelines are built on top of PyTorch as our core deep learning lib"
  },
  {
    "path": "egs/example/data.jsonl",
    "chars": 288,
    "preview": "{\"path\": \"dataset/example/electro_1.mp3\", \"duration\": 15.024, \"sample_rate\": 48000, \"amplitude\": null, \"weight\": null, \""
  },
  {
    "path": "model_cards/AUDIOGEN_MODEL_CARD.md",
    "chars": 5788,
    "preview": "# AudioGen Model Card\n\n## Model details\n**Organization developing the model:** The FAIR team of Meta AI.\n\n**Model date:*"
  },
  {
    "path": "model_cards/MUSICGEN_MODEL_CARD.md",
    "chars": 6760,
    "preview": "# MusicGen Model Card\n\n## Model details\n\n**Organization developing the model:** The FAIR team of Meta AI.\n\n**Model date:"
  },
  {
    "path": "models/Put your models here.txt",
    "chars": 12,
    "preview": "nothing here"
  },
  {
    "path": "mypy.ini",
    "chars": 169,
    "preview": "[mypy]\n\n[mypy-treetable,torchaudio.*,soundfile,einops.*,av.*,tqdm.*,num2words.*,spacy,xformers.*,scipy,huggingface_hub,t"
  },
  {
    "path": "requirements.txt",
    "chars": 337,
    "preview": "# please make sure you have already a pytorch install that is cuda enabled!\nav\neinops\nflashy>=0.0.1\nhydra-core>=1.1\nhydr"
  },
  {
    "path": "scripts/__init__.py",
    "chars": 198,
    "preview": "# Copyright (c) Meta Platforms, Inc. and affiliates.\n# All rights reserved.\n#\n# This source code is licensed under the l"
  },
  {
    "path": "scripts/mos.py",
    "chars": 9641,
    "preview": "# Copyright (c) Meta Platforms, Inc. and affiliates.\n# All rights reserved.\n#\n# This source code is licensed under the l"
  },
  {
    "path": "scripts/resample_dataset.py",
    "chars": 7478,
    "preview": "# Copyright (c) Meta Platforms, Inc. and affiliates.\n# All rights reserved.\n#\n# This source code is licensed under the l"
  },
  {
    "path": "scripts/static/style.css",
    "chars": 1475,
    "preview": "body {\n    background-color: #fbfbfb;\n    margin: 0;\n}\n\nselect, input {\n    font-size: 1em;\n    max-width: 100%;\n}\n\n.xp_"
  },
  {
    "path": "scripts/templates/base.html",
    "chars": 402,
    "preview": "<!DOCTYPE html>\n<html lang=\"en\">\n<head>\n    {% block head %}\n    <meta name=\"viewport\" content=\"width=device-width, init"
  },
  {
    "path": "scripts/templates/index.html",
    "chars": 833,
    "preview": "{% extends \"base.html\" %}\n{% block content %}\n\n<p>\n    Welcome <span class=\"special\">{{session['user']}}</span> to the i"
  },
  {
    "path": "scripts/templates/login.html",
    "chars": 469,
    "preview": "{% extends \"base.html\" %}\n{% block content %}\n\n<p>\n    You must identify yourself first! We use a highly secured protoco"
  },
  {
    "path": "scripts/templates/results.html",
    "chars": 496,
    "preview": "{% extends \"base.html\" %}\n{% block content %}\n\n<h1>Results for survey #{{signature}}</h1>\n<p>Checkout <a href=\"{{url_for"
  },
  {
    "path": "scripts/templates/survey.html",
    "chars": 5076,
    "preview": "{% extends \"base.html\" %}\n{% block content %}\n<h1>Survey #{{signature}}</h1>\n{% if success %}\n<p class=\"success\"> Your r"
  },
  {
    "path": "setup.cfg",
    "chars": 245,
    "preview": "[pep8]\nmax-line-length = 120\n\n[flake8]\nmax-line-length = 120\n\n[coverage:report]\ninclude = audiocraft/*\nomit =\n    audioc"
  },
  {
    "path": "setup.py",
    "chars": 1807,
    "preview": "# Copyright (c) Meta Platforms, Inc. and affiliates.\n# All rights reserved.\n#\n# This source code is licensed under the l"
  }
]

// ... and 27 more files (download for full content)

About this extraction

This page contains the full source code of the GrandaddyShmax/audiocraft_plus GitHub repository, extracted and formatted as plain text for AI agents and large language models (LLMs). The extraction includes 227 files (1.1 MB), approximately 268.6k tokens, and a symbol index with 1239 extracted functions, classes, methods, constants, and types. Use this with OpenClaw, Claude, ChatGPT, Cursor, Windsurf, or any other AI tool that accepts text input. You can copy the full output to your clipboard or download it as a .txt file.

Extracted by GitExtract — free GitHub repo to text converter for AI. Built by Nikandr Surkov.

Copied to clipboard!