Repository: GrandaddyShmax/audiocraft_plus Branch: main Commit: c9ff851e7b50 Files: 227 Total size: 1.1 MB Directory structure: gitextract_7_a5iyu7/ ├── .github/ │ └── actions/ │ └── audiocraft_build/ │ └── action.yml ├── .gitignore ├── CHANGELOG.md ├── CODE_OF_CONDUCT.md ├── CONTRIBUTING.md ├── Dockerfile ├── LICENSE ├── LICENSE_weights ├── MANIFEST.in ├── Makefile ├── README.md ├── app.py ├── audiocraft/ │ ├── __init__.py │ ├── adversarial/ │ │ ├── __init__.py │ │ ├── discriminators/ │ │ │ ├── __init__.py │ │ │ ├── base.py │ │ │ ├── mpd.py │ │ │ ├── msd.py │ │ │ └── msstftd.py │ │ └── losses.py │ ├── data/ │ │ ├── __init__.py │ │ ├── audio.py │ │ ├── audio_dataset.py │ │ ├── audio_utils.py │ │ ├── info_audio_dataset.py │ │ ├── music_dataset.py │ │ ├── sound_dataset.py │ │ └── zip.py │ ├── environment.py │ ├── grids/ │ │ ├── __init__.py │ │ ├── _base_explorers.py │ │ ├── audiogen/ │ │ │ ├── __init__.py │ │ │ ├── audiogen_base_16khz.py │ │ │ └── audiogen_pretrained_16khz_eval.py │ │ ├── compression/ │ │ │ ├── __init__.py │ │ │ ├── _explorers.py │ │ │ ├── debug.py │ │ │ ├── encodec_audiogen_16khz.py │ │ │ ├── encodec_base_24khz.py │ │ │ └── encodec_musicgen_32khz.py │ │ ├── diffusion/ │ │ │ ├── 4_bands_base_32khz.py │ │ │ ├── __init__.py │ │ │ └── _explorers.py │ │ └── musicgen/ │ │ ├── __init__.py │ │ ├── _explorers.py │ │ ├── musicgen_base_32khz.py │ │ ├── musicgen_base_cached_32khz.py │ │ ├── musicgen_clapemb_32khz.py │ │ ├── musicgen_melody_32khz.py │ │ └── musicgen_pretrained_32khz_eval.py │ ├── losses/ │ │ ├── __init__.py │ │ ├── balancer.py │ │ ├── sisnr.py │ │ ├── specloss.py │ │ └── stftloss.py │ ├── metrics/ │ │ ├── __init__.py │ │ ├── chroma_cosinesim.py │ │ ├── clap_consistency.py │ │ ├── fad.py │ │ ├── kld.py │ │ ├── rvm.py │ │ └── visqol.py │ ├── models/ │ │ ├── __init__.py │ │ ├── audiogen.py │ │ ├── builders.py │ │ ├── encodec.py │ │ ├── lm.py │ │ ├── loaders.py │ │ ├── multibanddiffusion.py │ │ ├── musicgen.py │ │ └── unet.py │ ├── modules/ │ │ ├── 
__init__.py │ │ ├── activations.py │ │ ├── chroma.py │ │ ├── codebooks_patterns.py │ │ ├── conditioners.py │ │ ├── conv.py │ │ ├── diffusion_schedule.py │ │ ├── lstm.py │ │ ├── rope.py │ │ ├── seanet.py │ │ ├── streaming.py │ │ └── transformer.py │ ├── optim/ │ │ ├── __init__.py │ │ ├── cosine_lr_scheduler.py │ │ ├── dadam.py │ │ ├── ema.py │ │ ├── fsdp.py │ │ ├── inverse_sqrt_lr_scheduler.py │ │ ├── linear_warmup_lr_scheduler.py │ │ └── polynomial_decay_lr_scheduler.py │ ├── py.typed │ ├── quantization/ │ │ ├── __init__.py │ │ ├── base.py │ │ ├── core_vq.py │ │ └── vq.py │ ├── solvers/ │ │ ├── __init__.py │ │ ├── audiogen.py │ │ ├── base.py │ │ ├── builders.py │ │ ├── compression.py │ │ ├── diffusion.py │ │ └── musicgen.py │ ├── train.py │ └── utils/ │ ├── __init__.py │ ├── autocast.py │ ├── best_state.py │ ├── cache.py │ ├── checkpoint.py │ ├── cluster.py │ ├── deadlock.py │ ├── export.py │ ├── export_legacy.py │ ├── notebook.py │ ├── profiler.py │ ├── samples/ │ │ ├── __init__.py │ │ └── manager.py │ ├── ui.py │ └── utils.py ├── config/ │ ├── conditioner/ │ │ ├── chroma2music.yaml │ │ ├── clapemb2music.yaml │ │ ├── none.yaml │ │ ├── text2music.yaml │ │ └── text2sound.yaml │ ├── config.yaml │ ├── dset/ │ │ ├── audio/ │ │ │ ├── audiocaps_16khz.yaml │ │ │ ├── default.yaml │ │ │ ├── example.yaml │ │ │ └── musiccaps_32khz.yaml │ │ ├── default.yaml │ │ └── internal/ │ │ ├── music_10k_32khz.yaml │ │ ├── music_400k_32khz.yaml │ │ └── sounds_16khz.yaml │ ├── model/ │ │ ├── encodec/ │ │ │ ├── default.yaml │ │ │ ├── encodec_base_causal.yaml │ │ │ ├── encodec_large_nq4_s320.yaml │ │ │ └── encodec_large_nq4_s640.yaml │ │ ├── lm/ │ │ │ ├── audiogen_lm.yaml │ │ │ ├── default.yaml │ │ │ ├── model_scale/ │ │ │ │ ├── base.yaml │ │ │ │ ├── large.yaml │ │ │ │ ├── medium.yaml │ │ │ │ ├── small.yaml │ │ │ │ └── xsmall.yaml │ │ │ └── musicgen_lm.yaml │ │ ├── none.yaml │ │ └── score/ │ │ └── basic.yaml │ ├── solver/ │ │ ├── audiogen/ │ │ │ ├── audiogen_base_16khz.yaml │ │ │ ├── 
debug.yaml │ │ │ ├── default.yaml │ │ │ └── evaluation/ │ │ │ ├── none.yaml │ │ │ └── objective_eval.yaml │ │ ├── compression/ │ │ │ ├── debug.yaml │ │ │ ├── default.yaml │ │ │ ├── encodec_audiogen_16khz.yaml │ │ │ ├── encodec_base_24khz.yaml │ │ │ └── encodec_musicgen_32khz.yaml │ │ ├── default.yaml │ │ ├── diffusion/ │ │ │ ├── debug.yaml │ │ │ ├── default.yaml │ │ │ └── encodec_24khz.yaml │ │ └── musicgen/ │ │ ├── debug.yaml │ │ ├── default.yaml │ │ ├── evaluation/ │ │ │ ├── none.yaml │ │ │ └── objective_eval.yaml │ │ ├── musicgen_base_32khz.yaml │ │ └── musicgen_melody_32khz.yaml │ └── teams/ │ ├── default.yaml │ └── labs.yaml ├── dataset/ │ └── example/ │ ├── electro_1.json │ └── electro_2.json ├── demos/ │ ├── audiogen_demo.ipynb │ ├── musicgen_app.py │ └── musicgen_demo.ipynb ├── dockerignore ├── docs/ │ ├── AUDIOGEN.md │ ├── CONDITIONING.md │ ├── DATASETS.md │ ├── ENCODEC.md │ ├── MBD.md │ ├── METRICS.md │ ├── MUSICGEN.md │ └── TRAINING.md ├── egs/ │ └── example/ │ └── data.jsonl ├── model_cards/ │ ├── AUDIOGEN_MODEL_CARD.md │ └── MUSICGEN_MODEL_CARD.md ├── models/ │ └── Put your models here.txt ├── mypy.ini ├── requirements.txt ├── scripts/ │ ├── __init__.py │ ├── mos.py │ ├── resample_dataset.py │ ├── static/ │ │ └── style.css │ └── templates/ │ ├── base.html │ ├── index.html │ ├── login.html │ ├── results.html │ └── survey.html ├── setup.cfg ├── setup.py └── tests/ ├── __init__.py ├── adversarial/ │ ├── __init__.py │ ├── test_discriminators.py │ └── test_losses.py ├── common_utils/ │ ├── __init__.py │ ├── temp_utils.py │ └── wav_utils.py ├── data/ │ ├── __init__.py │ ├── test_audio.py │ ├── test_audio_dataset.py │ └── test_audio_utils.py ├── losses/ │ ├── __init__.py │ └── test_losses.py ├── models/ │ ├── test_audiogen.py │ ├── test_encodec_model.py │ ├── test_multibanddiffusion.py │ └── test_musicgen.py ├── modules/ │ ├── __init__.py │ ├── test_activations.py │ ├── test_codebooks_patterns.py │ ├── test_conv.py │ ├── test_lstm.py │ ├── test_rope.py │ ├── 
test_seanet.py │ └── test_transformer.py ├── quantization/ │ └── test_vq.py └── utils/ └── __init__.py ================================================ FILE CONTENTS ================================================ ================================================ FILE: .github/actions/audiocraft_build/action.yml ================================================ name: audiocraft_build description: 'Build audiocraft env.' runs: using: "composite" steps: - uses: actions/setup-python@v2 with: python-version: 3.8 - uses: actions/cache@v2 id: cache with: path: env key: audiocraft_env-${{ hashFiles('**/requirements.txt') }} - if: ${{ steps.cache.outputs.cache-hit != 'true' }} name: Install dependencies shell: bash run: | sudo apt-get update sudo apt-get install libsndfile1-dev ffmpeg python3 -m venv env . env/bin/activate python -m pip install --upgrade pip pip install -e '.[dev]' - name: System Dependencies shell: bash run: | sudo apt-get update sudo apt-get install libsndfile1-dev ffmpeg ================================================ FILE: .gitignore ================================================ # Byte-compiled / optimized / DLL files __pycache__ *.py[cod] *$py.class # C extensions *.so # macOS dir files .DS_Store # Distribution / packaging .Python env/ build/ develop-eggs/ dist/ downloads/ eggs/ .eggs/ lib/ lib64/ parts/ sdist/ var/ wheels/ *.egg-info/ .installed.cfg *.egg .ipynb_checkpoints # Tests and linter .pytest_cache/ .mypy_cache/ .coverage # docs /api_docs # dotenv .env .envrc # virtualenv .venv venv/ ENV/ # egs with manifest files egs/* !egs/example # local datasets dataset/* !dataset/example # personal notebooks & scripts */local_scripts */notes .vscode/ /notebooks /local_scripts /notes /cache ================================================ FILE: CHANGELOG.md ================================================ # Changelog All notable changes to this project will be documented in this file. 
The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/). ## [1.0.0] - 2023-08-02 Major revision, added training code for EnCodec, AudioGen, MusicGen, and MultiBandDiffusion. Added pretrained model for AudioGen and MultiBandDiffusion. ## [0.0.2] - 2023-08-01 Improved demo, fixed top p (thanks @jnordberg). Compressor tanh on output to avoid clipping with some style (especially piano). Now repeating the conditioning periodically if it is too short. More options when launching Gradio app locally (thanks @ashleykleynhans). Testing out PyTorch 2.0 memory efficient attention. Added extended generation (infinite length) by slowly moving the windows. Note that other implementations exist: https://github.com/camenduru/MusicGen-colab. ## [0.0.1] - 2023-06-09 Initial release, with model evaluation only. ================================================ FILE: CODE_OF_CONDUCT.md ================================================ # Code of Conduct ## Our Pledge In the interest of fostering an open and welcoming environment, we as contributors and maintainers pledge to make participation in our project and our community a harassment-free experience for everyone, regardless of age, body size, disability, ethnicity, sex characteristics, gender identity and expression, level of experience, education, socio-economic status, nationality, personal appearance, race, religion, or sexual identity and orientation. 
## Our Standards Examples of behavior that contributes to creating a positive environment include: * Using welcoming and inclusive language * Being respectful of differing viewpoints and experiences * Gracefully accepting constructive criticism * Focusing on what is best for the community * Showing empathy towards other community members Examples of unacceptable behavior by participants include: * The use of sexualized language or imagery and unwelcome sexual attention or advances * Trolling, insulting/derogatory comments, and personal or political attacks * Public or private harassment * Publishing others' private information, such as a physical or electronic address, without explicit permission * Other conduct which could reasonably be considered inappropriate in a professional setting ## Our Responsibilities Project maintainers are responsible for clarifying the standards of acceptable behavior and are expected to take appropriate and fair corrective action in response to any instances of unacceptable behavior. Project maintainers have the right and responsibility to remove, edit, or reject comments, commits, code, wiki edits, issues, and other contributions that are not aligned to this Code of Conduct, or to ban temporarily or permanently any contributor for other behaviors that they deem inappropriate, threatening, offensive, or harmful. ## Scope This Code of Conduct applies within all project spaces, and it also applies when an individual is representing the project or its community in public spaces. Examples of representing a project or community include using an official project e-mail address, posting via an official social media account, or acting as an appointed representative at an online or offline event. Representation of a project may be further defined and clarified by project maintainers. 
This Code of Conduct also applies outside the project spaces when there is a reasonable belief that an individual's behavior may have a negative impact on the project or its community. ## Enforcement Instances of abusive, harassing, or otherwise unacceptable behavior may be reported by contacting the project team at <opensource-conduct@fb.com>. All complaints will be reviewed and investigated and will result in a response that is deemed necessary and appropriate to the circumstances. The project team is obligated to maintain confidentiality with regard to the reporter of an incident. Further details of specific enforcement policies may be posted separately. Project maintainers who do not follow or enforce the Code of Conduct in good faith may face temporary or permanent repercussions as determined by other members of the project's leadership. ## Attribution This Code of Conduct is adapted from the [Contributor Covenant][homepage], version 1.4, available at https://www.contributor-covenant.org/version/1/4/code-of-conduct.html [homepage]: https://www.contributor-covenant.org For answers to common questions about this code of conduct, see https://www.contributor-covenant.org/faq ================================================ FILE: CONTRIBUTING.md ================================================ # Contributing to AudioCraft We want to make contributing to this project as easy and transparent as possible. ## Pull Requests AudioCraft is the implementation of a research paper. Therefore, we do not plan on accepting many pull requests for new features. We certainly welcome them for bug fixes. 1. Fork the repo and create your branch from `main`. 2. If you've added code that should be tested, add tests. 3. If you've changed APIs, update the documentation. 4. Ensure the test suite passes. 5. Make sure your code lints. 6. If you haven't already, complete the Contributor License Agreement ("CLA"). 
## Contributor License Agreement ("CLA") In order to accept your pull request, we need you to submit a CLA. You only need to do this once to work on any of Meta's open source projects. Complete your CLA here: <https://code.facebook.com/cla> ## Issues We use GitHub issues to track public bugs. Please ensure your description is clear and has sufficient instructions to be able to reproduce the issue. Meta has a [bounty program](https://www.facebook.com/whitehat/) for the safe disclosure of security bugs. In those cases, please go through the process outlined on that page and do not file a public issue. ## License By contributing to AudioCraft, you agree that your contributions will be licensed under the LICENSE file in the root directory of this source tree. ================================================ FILE: Dockerfile ================================================ FROM nvidia/cuda:11.8.0-base-ubuntu22.04 ENV DEBIAN_FRONTEND=noninteractive \ PYTHONUNBUFFERED=1 \ PYTHONIOENCODING=UTF-8 RUN --mount=type=cache,target=/var/cache/apt --mount=type=cache,target=/var/lib/apt apt update &&\ apt install -y \ wget \ git \ pkg-config \ python3 \ python3-pip \ python-is-python3 \ ffmpeg \ libnvrtc11.2 \ libtcmalloc-minimal4 RUN useradd -m -u 1000 ac RUN --mount=type=cache,target=/root/.cache python -m pip install --upgrade pip wheel ENV TORCH_COMMAND="pip install torch==2.0.1+cu118 torchaudio --extra-index-url https://download.pytorch.org/whl/cu118" RUN --mount=type=cache,target=/root/.cache python -m $TORCH_COMMAND RUN ln -s /usr/lib/x86_64-linux-gnu/libnvrtc.so.11.2 /usr/lib/x86_64-linux-gnu/libnvrtc.so USER 1000 RUN mkdir ~/.cache RUN --mount=type=cache,target=/home/ac/.cache --mount=source=.,target=/home/ac/audiocraft python -m pip install -r /home/ac/audiocraft/requirements.txt WORKDIR /home/ac/audiocraft ================================================ FILE: LICENSE ================================================ MIT License Copyright (c) Meta Platforms, Inc. and affiliates. 
Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions: The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software. THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. ================================================ FILE: LICENSE_weights ================================================ Attribution-NonCommercial 4.0 International ======================================================================= Creative Commons Corporation ("Creative Commons") is not a law firm and does not provide legal services or legal advice. Distribution of Creative Commons public licenses does not create a lawyer-client or other relationship. Creative Commons makes its licenses and related information available on an "as-is" basis. Creative Commons gives no warranties regarding its licenses, any material licensed under their terms and conditions, or any related information. Creative Commons disclaims all liability for damages resulting from their use to the fullest extent possible. 
Using Creative Commons Public Licenses Creative Commons public licenses provide a standard set of terms and conditions that creators and other rights holders may use to share original works of authorship and other material subject to copyright and certain other rights specified in the public license below. The following considerations are for informational purposes only, are not exhaustive, and do not form part of our licenses. Considerations for licensors: Our public licenses are intended for use by those authorized to give the public permission to use material in ways otherwise restricted by copyright and certain other rights. Our licenses are irrevocable. Licensors should read and understand the terms and conditions of the license they choose before applying it. Licensors should also secure all rights necessary before applying our licenses so that the public can reuse the material as expected. Licensors should clearly mark any material not subject to the license. This includes other CC- licensed material, or material used under an exception or limitation to copyright. More considerations for licensors: wiki.creativecommons.org/Considerations_for_licensors Considerations for the public: By using one of our public licenses, a licensor grants the public permission to use the licensed material under specified terms and conditions. If the licensor's permission is not necessary for any reason--for example, because of any applicable exception or limitation to copyright--then that use is not regulated by the license. Our licenses grant only permissions under copyright and certain other rights that a licensor has authority to grant. Use of the licensed material may still be restricted for other reasons, including because others have copyright or other rights in the material. A licensor may make special requests, such as asking that all changes be marked or described. Although not required by our licenses, you are encouraged to respect those requests where reasonable. 
More_considerations for the public: wiki.creativecommons.org/Considerations_for_licensees ======================================================================= Creative Commons Attribution-NonCommercial 4.0 International Public License By exercising the Licensed Rights (defined below), You accept and agree to be bound by the terms and conditions of this Creative Commons Attribution-NonCommercial 4.0 International Public License ("Public License"). To the extent this Public License may be interpreted as a contract, You are granted the Licensed Rights in consideration of Your acceptance of these terms and conditions, and the Licensor grants You such rights in consideration of benefits the Licensor receives from making the Licensed Material available under these terms and conditions. Section 1 -- Definitions. a. Adapted Material means material subject to Copyright and Similar Rights that is derived from or based upon the Licensed Material and in which the Licensed Material is translated, altered, arranged, transformed, or otherwise modified in a manner requiring permission under the Copyright and Similar Rights held by the Licensor. For purposes of this Public License, where the Licensed Material is a musical work, performance, or sound recording, Adapted Material is always produced where the Licensed Material is synched in timed relation with a moving image. b. Adapter's License means the license You apply to Your Copyright and Similar Rights in Your contributions to Adapted Material in accordance with the terms and conditions of this Public License. c. Copyright and Similar Rights means copyright and/or similar rights closely related to copyright including, without limitation, performance, broadcast, sound recording, and Sui Generis Database Rights, without regard to how the rights are labeled or categorized. For purposes of this Public License, the rights specified in Section 2(b)(1)-(2) are not Copyright and Similar Rights. d. 
Effective Technological Measures means those measures that, in the absence of proper authority, may not be circumvented under laws fulfilling obligations under Article 11 of the WIPO Copyright Treaty adopted on December 20, 1996, and/or similar international agreements. e. Exceptions and Limitations means fair use, fair dealing, and/or any other exception or limitation to Copyright and Similar Rights that applies to Your use of the Licensed Material. f. Licensed Material means the artistic or literary work, database, or other material to which the Licensor applied this Public License. g. Licensed Rights means the rights granted to You subject to the terms and conditions of this Public License, which are limited to all Copyright and Similar Rights that apply to Your use of the Licensed Material and that the Licensor has authority to license. h. Licensor means the individual(s) or entity(ies) granting rights under this Public License. i. NonCommercial means not primarily intended for or directed towards commercial advantage or monetary compensation. For purposes of this Public License, the exchange of the Licensed Material for other material subject to Copyright and Similar Rights by digital file-sharing or similar means is NonCommercial provided there is no payment of monetary compensation in connection with the exchange. j. Share means to provide material to the public by any means or process that requires permission under the Licensed Rights, such as reproduction, public display, public performance, distribution, dissemination, communication, or importation, and to make material available to the public including in ways that members of the public may access the material from a place and at a time individually chosen by them. k. 
Sui Generis Database Rights means rights other than copyright resulting from Directive 96/9/EC of the European Parliament and of the Council of 11 March 1996 on the legal protection of databases, as amended and/or succeeded, as well as other essentially equivalent rights anywhere in the world. l. You means the individual or entity exercising the Licensed Rights under this Public License. Your has a corresponding meaning. Section 2 -- Scope. a. License grant. 1. Subject to the terms and conditions of this Public License, the Licensor hereby grants You a worldwide, royalty-free, non-sublicensable, non-exclusive, irrevocable license to exercise the Licensed Rights in the Licensed Material to: a. reproduce and Share the Licensed Material, in whole or in part, for NonCommercial purposes only; and b. produce, reproduce, and Share Adapted Material for NonCommercial purposes only. 2. Exceptions and Limitations. For the avoidance of doubt, where Exceptions and Limitations apply to Your use, this Public License does not apply, and You do not need to comply with its terms and conditions. 3. Term. The term of this Public License is specified in Section 6(a). 4. Media and formats; technical modifications allowed. The Licensor authorizes You to exercise the Licensed Rights in all media and formats whether now known or hereafter created, and to make technical modifications necessary to do so. The Licensor waives and/or agrees not to assert any right or authority to forbid You from making technical modifications necessary to exercise the Licensed Rights, including technical modifications necessary to circumvent Effective Technological Measures. For purposes of this Public License, simply making modifications authorized by this Section 2(a) (4) never produces Adapted Material. 5. Downstream recipients. a. Offer from the Licensor -- Licensed Material. 
Every recipient of the Licensed Material automatically receives an offer from the Licensor to exercise the Licensed Rights under the terms and conditions of this Public License. b. No downstream restrictions. You may not offer or impose any additional or different terms or conditions on, or apply any Effective Technological Measures to, the Licensed Material if doing so restricts exercise of the Licensed Rights by any recipient of the Licensed Material. 6. No endorsement. Nothing in this Public License constitutes or may be construed as permission to assert or imply that You are, or that Your use of the Licensed Material is, connected with, or sponsored, endorsed, or granted official status by, the Licensor or others designated to receive attribution as provided in Section 3(a)(1)(A)(i). b. Other rights. 1. Moral rights, such as the right of integrity, are not licensed under this Public License, nor are publicity, privacy, and/or other similar personality rights; however, to the extent possible, the Licensor waives and/or agrees not to assert any such rights held by the Licensor to the limited extent necessary to allow You to exercise the Licensed Rights, but not otherwise. 2. Patent and trademark rights are not licensed under this Public License. 3. To the extent possible, the Licensor waives any right to collect royalties from You for the exercise of the Licensed Rights, whether directly or through a collecting society under any voluntary or waivable statutory or compulsory licensing scheme. In all other cases the Licensor expressly reserves any right to collect such royalties, including when the Licensed Material is used other than for NonCommercial purposes. Section 3 -- License Conditions. Your exercise of the Licensed Rights is expressly made subject to the following conditions. a. Attribution. 1. If You Share the Licensed Material (including in modified form), You must: a. retain the following if it is supplied by the Licensor with the Licensed Material: i. 
identification of the creator(s) of the Licensed Material and any others designated to receive attribution, in any reasonable manner requested by the Licensor (including by pseudonym if designated); ii. a copyright notice; iii. a notice that refers to this Public License; iv. a notice that refers to the disclaimer of warranties; v. a URI or hyperlink to the Licensed Material to the extent reasonably practicable; b. indicate if You modified the Licensed Material and retain an indication of any previous modifications; and c. indicate the Licensed Material is licensed under this Public License, and include the text of, or the URI or hyperlink to, this Public License. 2. You may satisfy the conditions in Section 3(a)(1) in any reasonable manner based on the medium, means, and context in which You Share the Licensed Material. For example, it may be reasonable to satisfy the conditions by providing a URI or hyperlink to a resource that includes the required information. 3. If requested by the Licensor, You must remove any of the information required by Section 3(a)(1)(A) to the extent reasonably practicable. 4. If You Share Adapted Material You produce, the Adapter's License You apply must not prevent recipients of the Adapted Material from complying with this Public License. Section 4 -- Sui Generis Database Rights. Where the Licensed Rights include Sui Generis Database Rights that apply to Your use of the Licensed Material: a. for the avoidance of doubt, Section 2(a)(1) grants You the right to extract, reuse, reproduce, and Share all or a substantial portion of the contents of the database for NonCommercial purposes only; b. if You include all or a substantial portion of the database contents in a database in which You have Sui Generis Database Rights, then the database in which You have Sui Generis Database Rights (but not its individual contents) is Adapted Material; and c. 
You must comply with the conditions in Section 3(a) if You Share all or a substantial portion of the contents of the database. For the avoidance of doubt, this Section 4 supplements and does not replace Your obligations under this Public License where the Licensed Rights include other Copyright and Similar Rights. Section 5 -- Disclaimer of Warranties and Limitation of Liability. a. UNLESS OTHERWISE SEPARATELY UNDERTAKEN BY THE LICENSOR, TO THE EXTENT POSSIBLE, THE LICENSOR OFFERS THE LICENSED MATERIAL AS-IS AND AS-AVAILABLE, AND MAKES NO REPRESENTATIONS OR WARRANTIES OF ANY KIND CONCERNING THE LICENSED MATERIAL, WHETHER EXPRESS, IMPLIED, STATUTORY, OR OTHER. THIS INCLUDES, WITHOUT LIMITATION, WARRANTIES OF TITLE, MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE, NON-INFRINGEMENT, ABSENCE OF LATENT OR OTHER DEFECTS, ACCURACY, OR THE PRESENCE OR ABSENCE OF ERRORS, WHETHER OR NOT KNOWN OR DISCOVERABLE. WHERE DISCLAIMERS OF WARRANTIES ARE NOT ALLOWED IN FULL OR IN PART, THIS DISCLAIMER MAY NOT APPLY TO YOU. b. TO THE EXTENT POSSIBLE, IN NO EVENT WILL THE LICENSOR BE LIABLE TO YOU ON ANY LEGAL THEORY (INCLUDING, WITHOUT LIMITATION, NEGLIGENCE) OR OTHERWISE FOR ANY DIRECT, SPECIAL, INDIRECT, INCIDENTAL, CONSEQUENTIAL, PUNITIVE, EXEMPLARY, OR OTHER LOSSES, COSTS, EXPENSES, OR DAMAGES ARISING OUT OF THIS PUBLIC LICENSE OR USE OF THE LICENSED MATERIAL, EVEN IF THE LICENSOR HAS BEEN ADVISED OF THE POSSIBILITY OF SUCH LOSSES, COSTS, EXPENSES, OR DAMAGES. WHERE A LIMITATION OF LIABILITY IS NOT ALLOWED IN FULL OR IN PART, THIS LIMITATION MAY NOT APPLY TO YOU. c. The disclaimer of warranties and limitation of liability provided above shall be interpreted in a manner that, to the extent possible, most closely approximates an absolute disclaimer and waiver of all liability. Section 6 -- Term and Termination. a. This Public License applies for the term of the Copyright and Similar Rights licensed here. 
However, if You fail to comply with this Public License, then Your rights under this Public License terminate automatically. b. Where Your right to use the Licensed Material has terminated under Section 6(a), it reinstates: 1. automatically as of the date the violation is cured, provided it is cured within 30 days of Your discovery of the violation; or 2. upon express reinstatement by the Licensor. For the avoidance of doubt, this Section 6(b) does not affect any right the Licensor may have to seek remedies for Your violations of this Public License. c. For the avoidance of doubt, the Licensor may also offer the Licensed Material under separate terms or conditions or stop distributing the Licensed Material at any time; however, doing so will not terminate this Public License. d. Sections 1, 5, 6, 7, and 8 survive termination of this Public License. Section 7 -- Other Terms and Conditions. a. The Licensor shall not be bound by any additional or different terms or conditions communicated by You unless expressly agreed. b. Any arrangements, understandings, or agreements regarding the Licensed Material not stated herein are separate from and independent of the terms and conditions of this Public License. Section 8 -- Interpretation. a. For the avoidance of doubt, this Public License does not, and shall not be interpreted to, reduce, limit, restrict, or impose conditions on any use of the Licensed Material that could lawfully be made without permission under this Public License. b. To the extent possible, if any provision of this Public License is deemed unenforceable, it shall be automatically reformed to the minimum extent necessary to make it enforceable. If the provision cannot be reformed, it shall be severed from this Public License without affecting the enforceability of the remaining terms and conditions. c. No term or condition of this Public License will be waived and no failure to comply consented to unless expressly agreed to by the Licensor. d. 
Nothing in this Public License constitutes or may be interpreted as a limitation upon, or waiver of, any privileges and immunities that apply to the Licensor or You, including from the legal processes of any jurisdiction or authority. ======================================================================= Creative Commons is not a party to its public licenses. Notwithstanding, Creative Commons may elect to apply one of its public licenses to material it publishes and in those instances will be considered the “Licensor.” The text of the Creative Commons public licenses is dedicated to the public domain under the CC0 Public Domain Dedication. Except for the limited purpose of indicating that material is shared under a Creative Commons public license or as otherwise permitted by the Creative Commons policies published at creativecommons.org/policies, Creative Commons does not authorize the use of the trademark "Creative Commons" or any other trademark or logo of Creative Commons without its prior written consent including, without limitation, in connection with any unauthorized modifications to any of its public licenses or any other arrangements, understandings, or agreements concerning use of licensed material. For the avoidance of doubt, this paragraph does not form part of the public licenses. Creative Commons may be contacted at creativecommons.org. 
================================================ FILE: MANIFEST.in ================================================ include Makefile include LICENSE include LICENSE_weights include *.md include *.ini include requirements.txt include audiocraft/py.typed include assets/*.mp3 recursive-include conf *.yaml ================================================ FILE: Makefile ================================================ INTEG=AUDIOCRAFT_DORA_DIR="/tmp/magma_$(USER)" python3 -m dora -v run --clear device=cpu dataset.num_workers=0 optim.epochs=1 \ dataset.train.num_samples=10 dataset.valid.num_samples=10 \ dataset.evaluate.num_samples=10 dataset.generate.num_samples=2 sample_rate=16000 \ logging.level=DEBUG INTEG_COMPRESSION = $(INTEG) solver=compression/debug rvq.n_q=2 rvq.bins=48 checkpoint.save_last=true # SIG is 5091833e INTEG_MUSICGEN = $(INTEG) solver=musicgen/debug dset=audio/example compression_model_checkpoint=//sig/5091833e \ transformer_lm.n_q=2 transformer_lm.card=48 transformer_lm.dim=16 checkpoint.save_last=false # Using compression model from 5091833e INTEG_AUDIOGEN = $(INTEG) solver=audiogen/debug dset=audio/example compression_model_checkpoint=//sig/5091833e \ transformer_lm.n_q=2 transformer_lm.card=48 transformer_lm.dim=16 checkpoint.save_last=false # Using compression model from 5091833e INTEG_MBD = $(INTEG) solver=diffusion/debug dset=audio/example \ checkpoint.save_last=false # Using compression model from 616d7b3c default: linter tests install: pip install -U pip pip install -U -e '.[dev]' linter: flake8 audiocraft && mypy audiocraft flake8 tests && mypy tests tests: coverage run -m pytest tests coverage report tests_integ: $(INTEG_COMPRESSION) $(INTEG_MBD) $(INTEG_MUSICGEN) $(INTEG_AUDIOGEN) api_docs: pdoc3 --html -o api_docs -f audiocraft dist: python setup.py sdist .PHONY: linter tests api_docs dist ================================================ FILE: README.md ================================================ # AudioCraft Plus ![docs 
badge](https://github.com/facebookresearch/audiocraft/workflows/audiocraft_docs/badge.svg) ![linter badge](https://github.com/facebookresearch/audiocraft/workflows/audiocraft_linter/badge.svg) ![tests badge](https://github.com/facebookresearch/audiocraft/workflows/audiocraft_tests/badge.svg) AudioCraft is a PyTorch library for deep learning research on audio generation. AudioCraft contains inference and training code for two state-of-the-art AI generative models producing high-quality audio: AudioGen and MusicGen. Open In Colab Open in Hugging Face

![image](https://github.com/GrandaddyShmax/audiocraft_plus/assets/52707645/043fc037-54a9-48c4-bb5c-bf9b7440d146) ## Features AudioCraft Plus is an all-in-one WebUI for the original AudioCraft, adding many quality features on top. - AudioGen Model - Multiband Diffusion - Custom Model Support - Generation Metadata and Audio Info tab - Mono to Stereo - Multiprompt/Prompt Segmentation with Structure Prompts - Video Output Customization - Music Continuation ## Installation If you are updating from the previous version of AudioCraft Plus, do the following steps in the AudioCraft Plus folder: ```shell git pull pip install transformers --upgrade pip install torchmetrics --upgrade ``` #### Otherwise: Clean Installation AudioCraft requires Python 3.9, PyTorch 2.0.0. To install AudioCraft, you can run the following: ```shell # Best to make sure you have torch installed first, in particular before installing xformers. # Don't run this if you already have PyTorch installed. pip install 'torch>=2.0' # Then proceed to one of the following pip install -U audiocraft # stable release pip install -U git+https://git@github.com/GrandaddyShmax/audiocraft_plus#egg=audiocraft # bleeding edge pip install -e . # or if you cloned the repo locally (mandatory if you want to train). ``` We also recommend having `ffmpeg` installed, either through your system or Anaconda: ```bash sudo apt-get install ffmpeg # Or if you are using Anaconda or Miniconda conda install 'ffmpeg<5' -c conda-forge ``` Installation video thanks to Pogs Cafe: [![Untitled](http://img.youtube.com/vi/WjGk4bcbUOI/0.jpg)](http://www.youtube.com/watch?v=WjGk4bcbUOI "Installing MusicGen+ Locally") Additional installation guide by [radaevm](https://github.com/radaevm) can be found [HERE](https://github.com/GrandaddyShmax/audiocraft_plus/discussions/31) ## Models At the moment, AudioCraft contains the training code and inference code for: * [MusicGen](./docs/MUSICGEN.md): A state-of-the-art controllable text-to-music model. 
* [AudioGen](./docs/AUDIOGEN.md): A state-of-the-art text-to-sound model. * [EnCodec](./docs/ENCODEC.md): A state-of-the-art high fidelity neural audio codec. * [Multi Band Diffusion](./docs/MBD.md): An EnCodec compatible decoder using diffusion. ## Training code AudioCraft contains PyTorch components for deep learning research in audio and training pipelines for the developed models. For a general introduction of AudioCraft design principles and instructions to develop your own training pipeline, refer to the [AudioCraft training documentation](./docs/TRAINING.md). For reproducing existing work and using the developed training pipelines, refer to the instructions for each specific model that provides pointers to configuration, example grids and model/task-specific information and FAQ. ## API documentation We provide some [API documentation](https://facebookresearch.github.io/audiocraft/api_docs/audiocraft/index.html) for AudioCraft. ## FAQ #### Is the training code available? Yes! We provide the training code for [EnCodec](./docs/ENCODEC.md), [MusicGen](./docs/MUSICGEN.md) and [Multi Band Diffusion](./docs/MBD.md). #### Where are the models stored? Hugging Face stores the models in a specific location, which can be overridden by setting the `AUDIOCRAFT_CACHE_DIR` environment variable. ## License * The code in this repository is released under the MIT license as found in the [LICENSE file](LICENSE). * The model weights in this repository are released under the CC-BY-NC 4.0 license as found in the [LICENSE_weights file](LICENSE_weights). ## Citation For the general framework of AudioCraft, please cite the following. 
``` @article{copet2023simple, title={Simple and Controllable Music Generation}, author={Jade Copet and Felix Kreuk and Itai Gat and Tal Remez and David Kant and Gabriel Synnaeve and Yossi Adi and Alexandre Défossez}, year={2023}, journal={arXiv preprint arXiv:2306.05284}, } ``` When referring to a specific model, please cite as mentioned in the model specific README, e.g [./docs/MUSICGEN.md](./docs/MUSICGEN.md), [./docs/AUDIOGEN.md](./docs/AUDIOGEN.md), etc. ================================================ FILE: app.py ================================================ # Copyright (c) Meta Platforms, Inc. and affiliates. # All rights reserved. # This source code is licensed under the license found in the # LICENSE file in the root directory of this source tree. # Updated to account for UI changes from https://github.com/rkfg/audiocraft/blob/long/app.py # also released under the MIT license. import argparse from concurrent.futures import ProcessPoolExecutor import os from pathlib import Path import subprocess as sp from tempfile import NamedTemporaryFile import time import warnings import glob import re from PIL import Image from pydub import AudioSegment from datetime import datetime import json import shutil import taglib import torch import torchaudio import gradio as gr import numpy as np import typing as tp from audiocraft.data.audio_utils import convert_audio from audiocraft.data.audio import audio_write from audiocraft.models import AudioGen, MusicGen, MultiBandDiffusion from audiocraft.utils import ui import random, string version = "2.0.1" theme = gr.themes.Base( primary_hue="lime", secondary_hue="lime", neutral_hue="neutral", ).set( button_primary_background_fill_hover='*primary_500', button_primary_background_fill_hover_dark='*primary_500', button_secondary_background_fill_hover='*primary_500', button_secondary_background_fill_hover_dark='*primary_500' ) MODEL = None # Last used model MODELS = None UNLOAD_MODEL = False MOVE_TO_CPU = False IS_BATCHED = 
class FileCleaner:
    """Track generated files and delete them once they are old enough.

    Entries are stored in ``self.files`` as ``(timestamp, Path)`` tuples in
    insertion order, so the oldest entry is always at the front.
    """

    def __init__(self, file_lifetime: float = 3600):
        # Age in seconds after which a registered file is removed.
        self.file_lifetime = file_lifetime
        self.files = []

    def add(self, path: tp.Union[str, Path]):
        """Purge expired files, then register ``path`` with the current time."""
        self._cleanup()
        entry = (time.time(), Path(path))
        self.files.append(entry)

    def _cleanup(self):
        """Delete and forget every leading entry older than the lifetime."""
        current = time.time()
        while self.files:
            registered_at, target = self.files[0]
            # Entries are in insertion order: the first one that is still
            # young enough ends the sweep.
            if not (current - registered_at > self.file_lifetime):
                return
            if target.exists():
                target.unlink()
            self.files.pop(0)
def load_model(version='GrandaddyShmax/musicgen-melody', custom_model=None, gen_type="music"):
    """Load the requested model into the module-level ``MODEL`` global.

    Args:
        version: Model identifier passed to ``get_pretrained``. The special
            value ``'GrandaddyShmax/musicgen-custom'`` selects a custom model.
        custom_model: Argument forwarded to ``MusicGen.get_pretrained`` when
            the custom version is requested and no cache is in use.
        gen_type: ``"music"`` loads through ``MusicGen``, ``"audio"`` through
            ``AudioGen``.

    When ``MODELS`` is ``None`` the model is simply loaded. Otherwise
    ``MODELS`` is used as a per-version cache: the current model is moved to
    the CPU, and the requested one is either loaded and stored in the cache
    or fetched from it and moved to CUDA.
    """
    global MODEL, MODELS
    print("Loading model", version)
    is_custom = version == 'GrandaddyShmax/musicgen-custom'

    if MODELS is None:
        # No cache configured: load directly and replace the current model.
        if is_custom:
            MODEL = MusicGen.get_pretrained(custom_model)
        elif gen_type == "music":
            MODEL = MusicGen.get_pretrained(version)
        elif gen_type == "audio":
            MODEL = AudioGen.get_pretrained(version)
        return

    started = time.monotonic()
    if MODEL is not None:
        MODEL.to('cpu')  # move to cache
        print("Previous model moved to CPU in %.2fs" % (time.monotonic() - started))
        started = time.monotonic()

    if not is_custom and MODELS.get(version) is None:
        print("Loading model %s from disk" % version)
        if gen_type == "music":
            loaded = MusicGen.get_pretrained(version)
        elif gen_type == "audio":
            loaded = AudioGen.get_pretrained(version)
        MODELS[version] = loaded
        print("Model loaded in %.2fs" % (time.monotonic() - started))
        MODEL = loaded
        return

    # NOTE(review): the custom version also reaches this lookup, and
    # custom_model is not used on this path - confirm the cache is never
    # enabled together with custom models.
    loaded = MODELS[version].to('cuda')
    print("Cached model loaded in %.2fs" % (time.monotonic() - started))
    MODEL = loaded
(Discord removes metadata from mp4 and wav files, so you can't use them)" json_string = song.tags['COMMENT'][0] data = json.loads(json_string) global_prompt = str("\nGlobal Prompt: " + (data['global_prompt'] if data['global_prompt'] != "" else "none")) if 'global_prompt' in data else "" bpm = str("\nBPM: " + data['bpm']) if 'bpm' in data else "" key = str("\nKey: " + data['key']) if 'key' in data else "" scale = str("\nScale: " + data['scale']) if 'scale' in data else "" prompts = str("\nPrompts: " + (data['texts'] if data['texts'] != "['']" else "none")) if 'texts' in data else "" duration = str("\nDuration: " + data['duration']) if 'duration' in data else "" overlap = str("\nOverlap: " + data['overlap']) if 'overlap' in data else "" seed = str("\nSeed: " + data['seed']) if 'seed' in data else "" audio_mode = str("\nAudio Mode: " + data['audio_mode']) if 'audio_mode' in data else "" input_length = str("\nInput Length: " + data['input_length']) if 'input_length' in data else "" channel = str("\nChannel: " + data['channel']) if 'channel' in data else "" sr_select = str("\nSample Rate: " + data['sr_select']) if 'sr_select' in data else "" gen_type = str(data['generator'] + "gen-") if 'generator' in data else "" model = str("\nModel: " + gen_type + data['model']) if 'model' in data else "" custom_model = str("\nCustom Model: " + data['custom_model']) if 'custom_model' in data else "" decoder = str("\nDecoder: " + data['decoder']) if 'decoder' in data else "" topk = str("\nTopk: " + data['topk']) if 'topk' in data else "" topp = str("\nTopp: " + data['topp']) if 'topp' in data else "" temperature = str("\nTemperature: " + data['temperature']) if 'temperature' in data else "" cfg_coef = str("\nClassifier Free Guidance: " + data['cfg_coef']) if 'cfg_coef' in data else "" version = str("Version: " + data['version']) if 'version' in data else "Version: Unknown" info = str(version + global_prompt + bpm + key + scale + prompts + duration + overlap + seed + audio_mode + 
input_length + channel + sr_select + model + custom_model + decoder + topk + topp + temperature + cfg_coef) if info == "": return "No tags found. Either the file is not generated by MusicGen+ V1.2.7 and higher or the tags are corrupted. (Discord removes metadata from mp4 and wav files, so you can't use them)" return info else: with open(audio_path.name) as json_file: data = json.load(json_file) #if 'global_prompt' not in data: #return "No tags found. Either the file is not generated by MusicGen+ V1.2.8a and higher or the tags are corrupted." global_prompt = str("\nGlobal Prompt: " + (data['global_prompt'] if data['global_prompt'] != "" else "none")) if 'global_prompt' in data else "" bpm = str("\nBPM: " + data['bpm']) if 'bpm' in data else "" key = str("\nKey: " + data['key']) if 'key' in data else "" scale = str("\nScale: " + data['scale']) if 'scale' in data else "" prompts = str("\nPrompts: " + (data['texts'] if data['texts'] != "['']" else "none")) if 'texts' in data else "" duration = str("\nDuration: " + data['duration']) if 'duration' in data else "" overlap = str("\nOverlap: " + data['overlap']) if 'overlap' in data else "" seed = str("\nSeed: " + data['seed']) if 'seed' in data else "" audio_mode = str("\nAudio Mode: " + data['audio_mode']) if 'audio_mode' in data else "" input_length = str("\nInput Length: " + data['input_length']) if 'input_length' in data else "" channel = str("\nChannel: " + data['channel']) if 'channel' in data else "" sr_select = str("\nSample Rate: " + data['sr_select']) if 'sr_select' in data else "" gen_type = str(data['generator'] + "gen-") if 'generator' in data else "" model = str("\nModel: " + gen_type + data['model']) if 'model' in data else "" custom_model = str("\nCustom Model: " + data['custom_model']) if 'custom_model' in data else "" decoder = str("\nDecoder: " + data['decoder']) if 'decoder' in data else "" topk = str("\nTopk: " + data['topk']) if 'topk' in data else "" topp = str("\nTopp: " + data['topp']) if 'topp' in 
def _music_default_params():
    """Return the fallback value for every field ``info_to_params`` produces."""
    return ("Default", False, "", 120, "C", "Major", "large", None, 1,
            *([""] * 10), *([1] * 10),
            "sample", 10, 250, 0, 1.0, 5.0, -1, 12, "stereo", "48000")


def _split_prompt_texts(raw):
    """Parse the stored ``texts`` string into (prompts, repeats, unique_count).

    Consecutive identical prompts are collapsed into one prompt with a repeat
    count; blank prompts are skipped. Both lists are padded to ten entries.
    """
    import ast  # local import: only this helper needs it

    # The stored value appears to be the repr of a Python list of strings
    # (get_audio_info compares it with "['']"). Parsing it as a literal keeps
    # prompts containing apostrophes intact, which the quote-matching regex
    # cannot do. Anything that is not such a literal falls back to the regex.
    try:
        parsed = ast.literal_eval(raw)
    except (ValueError, SyntaxError, TypeError, MemoryError, RecursionError):
        parsed = None
    if isinstance(parsed, list) and all(isinstance(item, str) for item in parsed):
        segments = parsed
    else:
        segments = re.findall(r"'(.*?)'", raw)

    prompts = []
    repeats = []
    for index, segment in enumerate(segments):
        if not segment.strip():
            continue
        if index == 0 or segment != segments[index - 1]:
            prompts.append(segment)
            repeats.append(1)
        else:
            # Same prompt as the previous segment: count it as a repetition.
            repeats[-1] += 1
    prompts.extend([""] * (10 - len(prompts)))
    repeats.extend([1] * (10 - len(repeats)))
    unique_prompts = len([prompt for prompt in prompts if prompt])
    return prompts, repeats, unique_prompts


def _music_params_from_metadata(data):
    """Map a decoded metadata dict onto the tuple ``info_to_params`` returns."""
    if not isinstance(data, dict):
        return _music_default_params()

    # A BPM other than "none" means the structure prompt was in use.
    has_bpm = 'bpm' in data and data['bpm'] != "none"
    struc_prompt = True if has_bpm else False
    bpm = int(data['bpm']) if has_bpm else 120
    global_prompt = data['global_prompt'] if 'global_prompt' in data else ""
    key = data['key'] if 'key' in data and data['key'] != "none" else "C"
    scale = data['scale'] if 'scale' in data and data['scale'] != "none" else "Major"
    model = data['model'] if 'model' in data else "large"
    custom_model = None
    # Only offer the custom model if it is still available.
    if 'custom_model' in data and data['custom_model'] in get_available_folders():
        custom_model = data['custom_model']
    decoder = data['decoder'] if 'decoder' in data else "Default"

    if 'texts' in data:
        text, repeat, unique_prompts = _split_prompt_texts(data['texts'])
    else:
        text, repeat, unique_prompts = [""] * 10, [1] * 10, 1

    audio_mode = "sample"
    if 'audio_mode' in data and data['audio_mode'] != "none":
        audio_mode = data['audio_mode']
    duration = int(data['duration']) if 'duration' in data else 10
    topk = float(data['topk']) if 'topk' in data else 250
    topp = float(data['topp']) if 'topp' in data else 0
    temperature = float(data['temperature']) if 'temperature' in data else 1.0
    cfg_coef = float(data['cfg_coef']) if 'cfg_coef' in data else 5.0
    seed = int(data['seed']) if 'seed' in data else -1
    overlap = int(data['overlap']) if 'overlap' in data else 12
    channel = data['channel'] if 'channel' in data else "stereo"
    sr_select = data['sr_select'] if 'sr_select' in data else "48000"

    return (decoder, struc_prompt, global_prompt, bpm, key, scale, model,
            custom_model, unique_prompts, *text[:10], *repeat[:10],
            audio_mode, duration, topk, topp, temperature, cfg_coef, seed,
            overlap, channel, sr_select)


def info_to_params(audio_path):
    """Read generation metadata from a file and return it as UI parameters.

    Args:
        audio_path: ``None`` or an object whose ``name`` attribute is a file
            path. ``.wav`` and ``.mp4`` files are read through their
            ``COMMENT`` tag; ``.json`` files are read directly.

    Returns:
        A 39-item tuple: decoder, struc_prompt, global_prompt, bpm, key,
        scale, model, custom_model, unique_prompts, ten prompt texts, ten
        repeat counts, audio_mode, duration, topk, topp, temperature,
        cfg_coef, seed, overlap, channel, sr_select. The defaults are
        returned when no file is given, the extension is unsupported, the
        ``COMMENT`` tag is missing, or the metadata is not valid JSON.
    """
    if audio_path is None:
        return _music_default_params()
    if not audio_path.name.endswith((".wav", ".mp4", ".json")):
        return _music_default_params()

    if audio_path.name.endswith(".json"):
        with open(audio_path.name) as json_file:
            raw = json_file.read()
    else:
        with taglib.File(audio_path.name, save_on_exit=False) as song:
            if 'COMMENT' not in song.tags:
                return _music_default_params()
            raw = song.tags['COMMENT'][0]

    try:
        data = json.loads(raw)
    except ValueError:
        # A comment tag or file that is not JSON is treated like a file
        # without metadata instead of raising into the UI.
        return _music_default_params()
    return _music_params_from_metadata(data)
text[0], text[1], text[2], text[3], text[4], text[5], text[6], text[7], text[8], text[9], repeat[0], repeat[1], repeat[2], repeat[3], repeat[4], repeat[5], repeat[6], repeat[7], repeat[8], repeat[9], duration, topk, topp, temperature, cfg_coef, seed, overlap, channel, sr_select else: with open(audio_path.name) as json_file: data = json.load(json_file) struc_prompt = (False if data['global_prompt'] == "" else True) if 'global_prompt' in data else False global_prompt = data['global_prompt'] if 'global_prompt' in data else "" decoder = data['decoder'] if 'decoder' in data else "Default" if 'texts' not in data: unique_prompts = 1 text = ["", "", "", "", "", "", "", "", "", ""] repeat = [1, 1, 1, 1, 1, 1, 1, 1, 1, 1] else: s = data['texts'] s = re.findall(r"'(.*?)'", s) text = [] repeat = [] i = 0 for elem in s: if elem.strip(): if i == 0 or elem != s[i-1]: text.append(elem) repeat.append(1) else: repeat[-1] += 1 i += 1 text.extend([""] * (10 - len(text))) repeat.extend([1] * (10 - len(repeat))) unique_prompts = len([t for t in text if t]) duration = int(data['duration']) if 'duration' in data else 10 topk = float(data['topk']) if 'topk' in data else 250 topp = float(data['topp']) if 'topp' in data else 0 temperature = float(data['temperature']) if 'temperature' in data else 1.0 cfg_coef = float(data['cfg_coef']) if 'cfg_coef' in data else 5.0 seed = int(data['seed']) if 'seed' in data else -1 overlap = int(data['overlap']) if 'overlap' in data else 12 channel = data['channel'] if 'channel' in data else "stereo" sr_select = data['sr_select'] if 'sr_select' in data else "48000" return decoder, struc_prompt, global_prompt, unique_prompts, text[0], text[1], text[2], text[3], text[4], text[5], text[6], text[7], text[8], text[9], repeat[0], repeat[1], repeat[2], repeat[3], repeat[4], repeat[5], repeat[6], repeat[7], repeat[8], repeat[9], duration, topk, topp, temperature, cfg_coef, seed, overlap, channel, sr_select else: return "Default", False, "", 1, "", "", "", "", "", 
def normalize_audio(audio_data):
    """Return ``audio_data`` as float32, scaled so its peak magnitude is 1.

    Args:
        audio_data: NumPy array of samples (any numeric dtype).

    Returns:
        A new float32 array. Empty input and all-zero (silent) input are
        returned unscaled, because dividing by a zero peak would fill the
        result with NaN.
    """
    audio_data = audio_data.astype(np.float32)
    if audio_data.size == 0:
        # np.max raises on an empty array; there is nothing to scale.
        return audio_data
    max_value = np.max(np.abs(audio_data))
    if max_value > 0:
        audio_data /= max_value
    return audio_data
sample_length - 0.5 if trim_end >= sample_length: trim_end = sample_length - 0.5 if trim_start + trim_end >= sample_length: tmp = sample_length - 0.5 trim_start = tmp / 2 trim_end = tmp / 2 sampleM = sampleM[..., int(globalSR * trim_start):int(globalSR * (sample_length - trim_end))] sample_length = sample_length - (trim_start + trim_end) if sample_length > maximum_size: cut_size = sample_length - maximum_size sampleP = sampleM[..., :int(globalSR * cut_size)] sampleM = sampleM[..., int(globalSR * cut_size):] if sample_length >= duration: duration = sample_length + 0.5 input_length = sample_length global MODEL MODEL.set_generation_params(duration=(duration - cut_size), **gen_kwargs) print("new batch", len(texts), texts, [None if m is None else (m[0], m[1].shape) for m in melodies], [None if sample is None else (sample[0], sample[1].shape)]) be = time.time() processed_melodies = [] if gen_type == "music": target_sr = 32000 elif gen_type == "audio": target_sr = 16000 target_ac = 1 for melody in melodies: if melody is None: processed_melodies.append(None) else: sr, melody = melody[0], torch.from_numpy(melody[1]).to(MODEL.device).float().t() if melody.dim() == 1: melody = melody[None] melody = melody[..., :int(sr * duration)] melody = convert_audio(melody, sr, target_sr, target_ac) processed_melodies.append(melody) if sample is not None: if sampleP is None: if gen_type == "music": outputs = MODEL.generate_continuation( prompt=sampleM, prompt_sample_rate=globalSR, descriptions=texts, progress=progress, return_tokens=USE_DIFFUSION ) elif gen_type == "audio": outputs = MODEL.generate_continuation( prompt=sampleM, prompt_sample_rate=globalSR, descriptions=texts, progress=progress ) else: if sampleP.dim() > 1: sampleP = convert_audio(sampleP, globalSR, target_sr, target_ac) sampleP = sampleP.to(MODEL.device).float().unsqueeze(0) if gen_type == "music": outputs = MODEL.generate_continuation( prompt=sampleM, prompt_sample_rate=globalSR, descriptions=texts, progress=progress, 
return_tokens=USE_DIFFUSION ) elif gen_type == "audio": outputs = MODEL.generate_continuation( prompt=sampleM, prompt_sample_rate=globalSR, descriptions=texts, progress=progress ) outputs = torch.cat([sampleP, outputs], 2) elif any(m is not None for m in processed_melodies): if gen_type == "music": outputs = MODEL.generate_with_chroma( descriptions=texts, melody_wavs=processed_melodies, melody_sample_rate=target_sr, progress=progress, return_tokens=USE_DIFFUSION ) elif gen_type == "audio": outputs = MODEL.generate_with_chroma( descriptions=texts, melody_wavs=processed_melodies, melody_sample_rate=target_sr, progress=progress ) else: if gen_type == "music": outputs = MODEL.generate(texts, progress=progress, return_tokens=USE_DIFFUSION) elif gen_type == "audio": outputs = MODEL.generate(texts, progress=progress) if USE_DIFFUSION: print("outputs: " + str(outputs)) outputs_diffusion = MBD.tokens_to_wav(outputs[1]) outputs = torch.cat([outputs[0], outputs_diffusion], dim=0) outputs = outputs.detach().cpu().float() backups = outputs if channel == "stereo": outputs = convert_audio(outputs, target_sr, int(sr_select), 2) elif channel == "mono" and sr_select != "32000": outputs = convert_audio(outputs, target_sr, int(sr_select), 1) out_files = [] out_audios = [] out_backup = [] for output in outputs: with NamedTemporaryFile("wb", suffix=".wav", delete=False) as file: audio_write( file.name, output, (MODEL.sample_rate if channel == "stereo effect" else int(sr_select)), strategy="loudness", loudness_headroom_db=16, loudness_compressor=True, add_suffix=False) if channel == "stereo effect": make_pseudo_stereo(file.name, sr_select, pan=True, delay=True); out_files.append(pool.submit(make_waveform, file.name, bg_image=image, bg_color=background, bars_color=(bar1, bar2), fg_alpha=1.0, bar_count=75, height=height, width=width)) out_audios.append(file.name) file_cleaner.add(file.name) print(f'wav: {file.name}') for backup in backups: with NamedTemporaryFile("wb", suffix=".wav", 
def predict_batched(texts, melodies):
    """Truncate each prompt to 512 characters and run one batched prediction.

    NOTE(review): ``load_model`` and ``_do_predictions`` are called here with
    different argument lists than ``predict_full`` uses in this file
    (``load_model(model, custom_model, gen_type)`` and
    ``_do_predictions(gen_type, [texts], [melody], sample, ...)``). This looks
    like a stale code path -- confirm before relying on it.
    """
    max_text_length = 512
    texts = [text[:max_text_length] for text in texts]
    load_model('melody')
    res = _do_predictions(texts, melodies, BATCHED_DURATION)
    return res


def add_tags(filename, tags):
    """Embed generation metadata in ``filename`` and dump it to a JSON file.

    Args:
        filename: Media file to tag. The tag is only written when the file
            exists; the JSON file is written either way.
        tags: Sequence of 20 values, read by index in the order of the keys
            of ``data`` below. ``tags[7]`` (the seed) names the JSON file.

    Returns:
        Path of the JSON file that was written, ``"<seed>.json"``, relative
        to the current working directory.
    """
    data = {
        "global_prompt": tags[0],
        "bpm": tags[1],
        "key": tags[2],
        "scale": tags[3],
        "texts": tags[4],
        "duration": tags[5],
        "overlap": tags[6],
        "seed": tags[7],
        "audio_mode": tags[8],
        "input_length": tags[9],
        "channel": tags[10],
        "sr_select": tags[11],
        "model": tags[12],
        "custom_model": tags[13],
        "decoder": tags[14],
        "topk": tags[15],
        "topp": tags[16],
        "temperature": tags[17],
        "cfg_coef": tags[18],
        "generator": tags[19],
        "version": version
    }
    json_string = json.dumps(data)

    if os.path.exists(filename):
        with taglib.File(filename, save_on_exit=True) as song:
            song.tags = {'COMMENT': json_string}

    # str() mirrors how save_outputs builds its file names from tags[7].
    json_path = str(tags[7]) + '.json'
    # Context manager so the handle is closed even if the write fails.
    with open(json_path, 'w') as json_file:
        json_file.write(json_string)
    return json_path


def save_outputs(mp4, wav_tmp, tags, gen_type):
    """Copy one generated wav/mp4 pair plus its metadata into ``./output``.

    Files are stored as ``output/<YYYYMMDD>/<gen_type>/{wav,mp4,json}/<seed>.<ext>``
    under the current working directory. When a wav with the same seed already
    exists, ``(1)``, ``(2)``, ... is appended to the name; only the wav folder
    is checked for collisions and the mp4/json files reuse the chosen name.

    Args:
        mp4: Name of the generated .mp4 file, relative to the working directory.
        wav_tmp: Path of the temporary .wav file.
        tags: Metadata sequence passed on to ``add_tags``; ``tags[7]`` is the
            seed used as the base file name.
        gen_type: Sub-folder name for the generator (callers in this file pass
            ``"music"`` or ``"audio"``).

    Returns:
        Tuple ``(wav_target, mp4_target, json_target)`` of the stored paths.
    """
    current_date = datetime.now().strftime("%Y%m%d")
    base_directory = os.path.join(os.getcwd(), 'output', current_date, gen_type)
    wav_directory = os.path.join(base_directory, 'wav')
    mp4_directory = os.path.join(base_directory, 'mp4')
    json_directory = os.path.join(base_directory, 'json')
    os.makedirs(wav_directory, exist_ok=True)
    os.makedirs(mp4_directory, exist_ok=True)
    os.makedirs(json_directory, exist_ok=True)

    # Pick the first free "<seed>" / "<seed>(n)" name in the wav folder.
    stem = str(tags[7])
    counter = 1
    while os.path.exists(os.path.join(wav_directory, stem + '.wav')):
        stem = str(tags[7]) + f'({counter})'
        counter += 1

    # Build each target from its own directory instead of str.replace() on the
    # full path, which would also rewrite "wav"/"mp4" occurring anywhere in the
    # working-directory path.
    wav_target = os.path.join(wav_directory, stem + '.wav')
    mp4_target = os.path.join(mp4_directory, stem + '.mp4')
    json_target = os.path.join(json_directory, stem + '.json')

    shutil.copyfile(wav_tmp, wav_target)
    json_file = add_tags(wav_target, tags)

    shutil.copyfile(r'./' + mp4, mp4_target)
    # Tags the mp4 too; this rewrites the same "<seed>.json" scratch file.
    add_tags(mp4_target, tags)

    shutil.copyfile(json_file, json_target)
    os.remove(json_file)  # drop the scratch JSON left in the working directory

    return wav_target, mp4_target, json_target


def clear_cash():
    """Delete temporary media files that were created today.

    Removes ``*.mp4`` from the current working directory and ``tmp*.mp4``,
    ``tmp*.wav`` and ``tmp*.png`` from the temp directory, keeping any file
    whose creation date (``os.path.getctime``) is not today.

    The temp directory is the ``TEMP`` environment variable when set, else
    ``tempfile.gettempdir()``, so an unset ``TEMP`` no longer raises
    ``TypeError``.
    """
    # Local import: the module object may not be imported at file level.
    import tempfile

    today = datetime.now().date()

    def _remove_created_today(pattern):
        # Delete every file matching the glob pattern that was created today.
        for path in glob.glob(pattern):
            if datetime.fromtimestamp(os.path.getctime(path)).date() == today:
                os.remove(path)

    _remove_created_today(os.path.join(os.getcwd(), '*.mp4'))

    temp_directory = os.environ.get('TEMP') or tempfile.gettempdir()
    for extension in ('mp4', 'wav', 'png'):
        _remove_created_today(os.path.join(temp_directory, 'tmp*.' + extension))
    return


def s2t(seconds, seconds2):
    """Format a segment as ``"MM:SS - MM:SS"``.

    Args:
        seconds: Segment start in seconds.
        seconds2: Segment end in seconds.

    Returns:
        The formatted range. A start that is non-zero and strictly before the
        end is shown one second later, so consecutive segments do not share a
        boundary second.
    """
    # Shift before splitting into minutes/seconds so that a start of 59
    # becomes "01:00" and not the invalid "00:60".
    if seconds != 0 and seconds < seconds2:
        seconds = seconds + 1
    m, s = divmod(seconds, 60)
    m2, s2 = divmod(seconds2, 60)
    return ("%02d:%02d - %02d:%02d" % (m, s, m2, s2))


def calc_time(gen_type, s, duration, overlap, d0, d1, d2, d3, d4, d5, d6, d7, d8, d9):
    """Compute the displayed time range of each of the ten prompt segments.

    Args:
        gen_type: ``"music"`` (30 s window) or ``"audio"`` (10 s window); any
            other value gives a window of 0.
        s: Number of prompts in use; the last one runs until ``duration``.
        duration: Total length in seconds.
        overlap: Seconds by which each window overlaps the previous one.
        d0..d9: Repeat count of each prompt (converted with ``int``).

    Returns:
        Tuple of ten ``"MM:SS - MM:SS"`` strings, one per prompt slot. Slots
        after the total duration is reached show ``duration - duration``.
    """
    repeats = [int(d) for d in (d0, d1, d2, d3, d4, d5, d6, d7, d8, d9)]

    if gen_type == "music":
        max_limit = 30
    elif gen_type == "audio":
        max_limit = 10
    else:
        max_limit = 0

    step = max_limit - overlap
    last_index = s - 1

    # The first window is full length; each further window or repeat only
    # adds (window - overlap) seconds.
    lengths = [max_limit + ((repeats[0] - 1) * step)]
    lengths.extend(count * step for count in repeats[1:])

    labels = []
    elapsed = 0
    for index, length in enumerate(lengths):
        if elapsed + length >= duration or index == last_index:
            labels.append(s2t(elapsed, duration))
            elapsed = duration
        else:
            labels.append(s2t(elapsed, elapsed + length))
            elapsed = elapsed + length
    return tuple(labels)
d7, d8, d9, audio, mode, trim_start, trim_end, duration, topk, topp, temperature, cfg_coef, seed, overlap, image, height, width, background, bar1, bar2, channel, sr_select, progress=gr.Progress()): global INTERRUPTING global USE_DIFFUSION INTERRUPTING = False if gen_type == "audio": custom_model = None custom_model_shrt = "none" elif gen_type == "music": custom_model_shrt = custom_model custom_model = "models/" + custom_model if temperature < 0: raise gr.Error("Temperature must be >= 0.") if topk < 0: raise gr.Error("Topk must be non-negative.") if topp < 0: raise gr.Error("Topp must be non-negative.") if trim_start < 0: trim_start = 0 if trim_end < 0: trim_end = 0 topk = int(topk) if decoder == "MultiBand_Diffusion": USE_DIFFUSION = True load_diffusion() else: USE_DIFFUSION = False unload_diffusion() if gen_type == "music": model_shrt = model model = "GrandaddyShmax/musicgen-" + model elif gen_type == "audio": model_shrt = model model = "GrandaddyShmax/audiogen-" + model if MODEL is None or MODEL.name != (model): load_model(model, custom_model, gen_type) else: if MOVE_TO_CPU: MODEL.to('cuda') if seed < 0: seed = random.randint(0, 0xffff_ffff_ffff) torch.manual_seed(seed) def _progress(generated, to_generate): progress((min(generated, to_generate), to_generate)) if INTERRUPTING: raise gr.Error("Interrupted.") MODEL.set_custom_progress_callback(_progress) audio_mode = "none" melody = None sample = None if audio: audio_mode = mode if mode == "sample": sample = audio elif mode == "melody": melody = audio custom_model_shrt = "none" if model != "GrandaddyShmax/musicgen-custom" else custom_model_shrt text_cat = [p0, p1, p2, p3, p4, p5, p6, p7, p8, p9] drag_cat = [d0, d1, d2, d3, d4, d5, d6, d7, d8, d9] texts = [] raw_texts = [] ind = 0 ind2 = 0 while ind < prompt_amount: for ind2 in range(int(drag_cat[ind])): if not struc_prompt: texts.append(text_cat[ind]) global_prompt = "none" bpm = "none" key = "none" scale = "none" raw_texts.append(text_cat[ind]) else: if gen_type 
== "music": bpm_str = str(bpm) + " bpm" key_str = ", " + str(key) + " " + str(scale) global_str = (", " + str(global_prompt)) if str(global_prompt) != "" else "" elif gen_type == "audio": bpm_str = "" key_str = "" global_str = (str(global_prompt)) if str(global_prompt) != "" else "" texts_str = (", " + str(text_cat[ind])) if str(text_cat[ind]) != "" else "" texts.append(bpm_str + key_str + global_str + texts_str) raw_texts.append(text_cat[ind]) ind2 = 0 ind = ind + 1 outs, outs_audio, outs_backup, input_length = _do_predictions( gen_type, [texts], [melody], sample, trim_start, trim_end, duration, image, height, width, background, bar1, bar2, channel, sr_select, progress=True, top_k=topk, top_p=topp, temperature=temperature, cfg_coef=cfg_coef, extend_stride=MODEL.max_duration-overlap) tags = [str(global_prompt), str(bpm), str(key), str(scale), str(raw_texts), str(duration), str(overlap), str(seed), str(audio_mode), str(input_length), str(channel), str(sr_select), str(model_shrt), str(custom_model_shrt), str(decoder), str(topk), str(topp), str(temperature), str(cfg_coef), str(gen_type)] wav_target, mp4_target, json_target = save_outputs(outs[0], outs_audio[0], tags, gen_type); # Removes the temporary files. 
# Number of prompt rows the UI can show.
max_textboxes = 10

#def get_available_models():
    #return sorted([re.sub('.pt$', '', item.name) for item in list(Path('models/').glob('*')) if item.name.endswith('.pt')])


def get_available_folders():
    """Return the sorted names of the sub-directories of ``./models``.

    Plain files inside ``models`` are ignored. When the ``models`` directory
    does not exist an empty list is returned instead of raising, so callers
    can treat "no folder" the same as "no custom models".
    """
    models_dir = "models"
    if not os.path.isdir(models_dir):
        return []
    return sorted(
        entry for entry in os.listdir(models_dir)
        if os.path.isdir(os.path.join(models_dir, entry))
    )


def toggle_audio_src(choice):
    """Return a Gradio update that switches the audio input source.

    Args:
        choice: ``"mic"`` selects the microphone; any other value selects
            file upload.

    Returns:
        The ``gr.update(...)`` result, which also resets the component value
        to ``None`` and sets a matching label.
    """
    if choice == "mic":
        return gr.update(source="microphone", value=None, label="Microphone")
    else:
        return gr.update(source="upload", value=None, label="File")
scale=4) prompts.append(text0) drag0 = gr.Number(label="Repeat", value=1, interactive=True, scale=1) repeats.append(drag0) calc0 = gr.Text(interactive=False, value="00:00 - 00:00", scale=1, label="Time") calcs.append(calc0) for i in range(max_textboxes): with gr.Row(visible=False) as t: text = gr.Text(label="Input Text", interactive=True, scale=3) repeat = gr.Number(label="Repeat", minimum=1, value=1, interactive=True, scale=1) calc = gr.Text(interactive=False, value="00:00 - 00:00", scale=1, label="Time") textboxes.append(t) prompts.append(text) repeats.append(repeat) calcs.append(calc) to_calc = gr.Button("Calculate Timings", variant="secondary") with gr.Row(): duration = gr.Slider(minimum=1, maximum=300, value=10, step=1, label="Duration", interactive=True) with gr.Row(): overlap = gr.Slider(minimum=1, maximum=29, value=12, step=1, label="Overlap", interactive=True) with gr.Row(): seed = gr.Number(label="Seed", value=-1, scale=4, precision=0, interactive=True) gr.Button('\U0001f3b2\ufe0f', scale=1).click(fn=lambda: -1, outputs=[seed], queue=False) reuse_seed = gr.Button('\u267b\ufe0f', scale=1) with gr.Tab("Audio"): with gr.Row(): with gr.Column(): input_type = gr.Radio(["file", "mic"], value="file", label="Input Type (optional)", interactive=True) mode = gr.Radio(["melody", "sample"], label="Input Audio Mode (optional)", value="sample", interactive=True) with gr.Row(): trim_start = gr.Number(label="Trim Start", value=0, interactive=True) trim_end = gr.Number(label="Trim End", value=0, interactive=True) audio = gr.Audio(source="upload", type="numpy", label="Input Audio (optional)", interactive=True) with gr.Tab("Customization"): with gr.Row(): with gr.Column(): background = gr.ColorPicker(value="#0f0f0f", label="background color", interactive=True, scale=0) bar1 = gr.ColorPicker(value="#84cc16", label="bar color start", interactive=True, scale=0) bar2 = gr.ColorPicker(value="#10b981", label="bar color end", interactive=True, scale=0) with gr.Column(): image = 
gr.Image(label="Background Image", type="filepath", interactive=True, scale=4) with gr.Row(): height = gr.Number(label="Height", value=512, interactive=True) width = gr.Number(label="Width", value=768, interactive=True) with gr.Tab("Settings"): with gr.Row(): channel = gr.Radio(["mono", "stereo", "stereo effect"], label="Output Audio Channels", value="stereo", interactive=True, scale=1) sr_select = gr.Dropdown(["11025", "16000", "22050", "24000", "32000", "44100", "48000"], label="Output Audio Sample Rate", value="48000", interactive=True) with gr.Row(): model = gr.Radio(["melody", "small", "medium", "large", "custom"], label="Model", value="large", interactive=True, scale=1) with gr.Column(): dropdown = gr.Dropdown(choices=get_available_folders(), value=("No models found" if len(get_available_folders()) < 1 else get_available_folders()[0]), label='Custom Model (models folder)', elem_classes='slim-dropdown', interactive=True) ui.create_refresh_button(dropdown, lambda: None, lambda: {'choices': get_available_folders()}, 'refresh-button') with gr.Row(): decoder = gr.Radio(["Default", "MultiBand_Diffusion"], label="Decoder", value="Default", interactive=True) with gr.Row(): topk = gr.Number(label="Top-k", value=250, interactive=True) topp = gr.Number(label="Top-p", value=0, interactive=True) temperature = gr.Number(label="Temperature", value=1.0, interactive=True) cfg_coef = gr.Number(label="Classifier Free Guidance", value=3.0, interactive=True) with gr.Row(): submit = gr.Button("Generate", variant="primary") # Adapted from https://github.com/rkfg/audiocraft/blob/long/app.py, MIT license. 
_ = gr.Button("Interrupt").click(fn=interrupt, queue=False) with gr.Column() as c: with gr.Tab("Output"): output = gr.Video(label="Generated Music", scale=0) with gr.Row(): audio_only = gr.Audio(type="numpy", label="Audio Only", interactive=False) backup_only = gr.Audio(type="numpy", label="Backup Audio", interactive=False, visible=False) send_audio = gr.Button("Send to Input Audio") seed_used = gr.Number(label='Seed used', value=-1, interactive=False) download = gr.File(label="Generated Files", interactive=False) with gr.Tab("Wiki"): gr.Markdown( """ - **[Generate (button)]:** Generates the music with the given settings and prompts. - **[Interrupt (button)]:** Stops the music generation as soon as it can, providing an incomplete output. --- ### Generation Tab: #### Structure Prompts: This feature helps reduce repetetive prompts by allowing you to set global prompts that will be used for all prompt segments. - **[Structure Prompts (checkbox)]:** Enable/Disable the structure prompts feature. - **[BPM (number)]:** Beats per minute of the generated music. - **[Key (dropdown)]:** The key of the generated music. - **[Scale (dropdown)]:** The scale of the generated music. - **[Global Prompt (text)]:** Here write the prompt that you wish to be used for all prompt segments. #### Multi-Prompt: This feature allows you to control the music, adding variation to different time segments. You have up to 10 prompt segments. the first prompt will always be 30s long the other prompts will be [30s - overlap]. for example if the overlap is 10s, each prompt segment will be 20s. - **[Prompt Segments (number)]:** Amount of unique prompt to generate throughout the music generation. - **[Prompt/Input Text (prompt)]:** Here describe the music you wish the model to generate. - **[Repeat (number)]:** Write how many times this prompt will repeat (instead of wasting another prompt segment on the same prompt). - **[Time (text)]:** The time of the prompt segment. 
- **[Calculate Timings (button)]:** Calculates the timings of the prompt segments. - **[Duration (number)]:** How long you want the generated music to be (in seconds). - **[Overlap (number)]:** How much each new segment will reference the previous segment (in seconds). For example, if you choose 20s: Each new segment after the first one will reference the previous segment 20s and will generate only 10s of new music. The model can only process 30s of music. - **[Seed (number)]:** Your generated music id. If you wish to generate the exact same music, place the exact seed with the exact prompts (This way you can also extend specific song that was generated short). - **[Random Seed (button)]:** Gives "-1" as a seed, which counts as a random seed. - **[Copy Previous Seed (button)]:** Copies the seed from the output seed (if you don't feel like doing it manualy). --- ### Audio Tab: - **[Input Type (selection)]:** `File` mode allows you to upload an audio file to use as input `Mic` mode allows you to use your microphone as input - **[Input Audio Mode (selection)]:** `Melody` mode only works with the melody model: it conditions the music generation to reference the melody `Sample` mode works with any model: it gives a music sample to the model to generate its continuation. - **[Trim Start and Trim End (numbers)]:** `Trim Start` set how much you'd like to trim the input audio from the start `Trim End` same as the above but from the end - **[Input Audio (audio file)]:** Input here the audio you wish to use with "melody" or "sample" mode. --- ### Customization Tab: - **[Background Color (color)]:** Works only if you don't upload image. Color of the background of the waveform. - **[Bar Color Start (color)]:** First color of the waveform bars. - **[Bar Color End (color)]:** Second color of the waveform bars. - **[Background Image (image)]:** Background image that you wish to be attached to the generated video along with the waveform. 
- **[Height and Width (numbers)]:** Output video resolution, only works with image. (minimum height and width is 256). --- ### Settings Tab: - **[Output Audio Channels (selection)]:** With this you can select the amount of channels that you wish for your output audio. `mono` is a straightforward single channel audio `stereo` is a dual channel audio but it will sound more or less like mono `stereo effect` this one is also dual channel but uses tricks to simulate a stereo audio. - **[Output Audio Sample Rate (dropdown)]:** The output audio sample rate, the model default is 32000. - **[Model (selection)]:** Here you can choose which model you wish to use: `melody` model is based on the medium model with a unique feature that lets you use melody conditioning `small` model is trained on 300M parameters `medium` model is trained on 1.5B parameters `large` model is trained on 3.3B parameters `custom` model runs the custom model that you provided. - **[Custom Model (selection)]:** This dropdown will show you models that are placed in the `models` folder you must select `custom` in the model options in order to use it. - **[Refresh (button)]:** Refreshes the dropdown list for custom model. - **[Decoder (selection)]:** Choose here the decoder that you wish to use: `Default` is the default decoder `MultiBand_Diffusion` is a decoder that uses diffusion to generate the audio. - **[Top-k (number)]:** is a parameter used in text generation models, including music generation models. It determines the number of most likely next tokens to consider at each step of the generation process. The model ranks all possible tokens based on their predicted probabilities, and then selects the top-k tokens from the ranked list. The model then samples from this reduced set of tokens to determine the next token in the generated sequence. A smaller value of k results in a more focused and deterministic output, while a larger value of k allows for more diversity in the generated music. 
- **[Top-p (number)]:** also known as nucleus sampling or probabilistic sampling, is another method used for token selection during text generation. Instead of specifying a fixed number like top-k, top-p considers the cumulative probability distribution of the ranked tokens. It selects the smallest possible set of tokens whose cumulative probability exceeds a certain threshold (usually denoted as p). The model then samples from this set to choose the next token. This approach ensures that the generated output maintains a balance between diversity and coherence, as it allows for a varying number of tokens to be considered based on their probabilities. - **[Temperature (number)]:** is a parameter that controls the randomness of the generated output. It is applied during the sampling process, where a higher temperature value results in more random and diverse outputs, while a lower temperature value leads to more deterministic and focused outputs. In the context of music generation, a higher temperature can introduce more variability and creativity into the generated music, but it may also lead to less coherent or structured compositions. On the other hand, a lower temperature can produce more repetitive and predictable music. - **[Classifier Free Guidance (number)]:** refers to a technique used in some music generation models where a separate classifier network is trained to provide guidance or control over the generated music. This classifier is trained on labeled data to recognize specific musical characteristics or styles. During the generation process, the output of the generator model is evaluated by the classifier, and the generator is encouraged to produce music that aligns with the desired characteristics or style. This approach allows for more fine-grained control over the generated music, enabling users to specify certain attributes they want the model to capture. 
""" ) with gr.Tab("AudioGen"): gr.Markdown( """ ### AudioGen """ ) with gr.Row(): with gr.Column(): with gr.Tab("Generation"): with gr.Accordion("Structure Prompts", open=False): with gr.Row(): struc_prompts_a = gr.Checkbox(label="Enable", value=False, interactive=True, container=False) global_prompt_a = gr.Text(label="Global Prompt", interactive=True, scale=3) with gr.Row(): s_a = gr.Slider(1, max_textboxes, value=1, step=1, label="Prompts:", interactive=True, scale=2) with gr.Column(): textboxes_a = [] prompts_a = [] repeats_a = [] calcs_a = [] with gr.Row(): text0_a = gr.Text(label="Input Text", interactive=True, scale=4) prompts_a.append(text0_a) drag0_a = gr.Number(label="Repeat", value=1, interactive=True, scale=1) repeats_a.append(drag0_a) calc0_a = gr.Text(interactive=False, value="00:00 - 00:00", scale=1, label="Time") calcs_a.append(calc0_a) for i in range(max_textboxes): with gr.Row(visible=False) as t_a: text_a = gr.Text(label="Input Text", interactive=True, scale=3) repeat_a = gr.Number(label="Repeat", minimum=1, value=1, interactive=True, scale=1) calc_a = gr.Text(interactive=False, value="00:00 - 00:00", scale=1, label="Time") textboxes_a.append(t_a) prompts_a.append(text_a) repeats_a.append(repeat_a) calcs_a.append(calc_a) to_calc_a = gr.Button("Calculate Timings", variant="secondary") with gr.Row(): duration_a = gr.Slider(minimum=1, maximum=300, value=10, step=1, label="Duration", interactive=True) with gr.Row(): overlap_a = gr.Slider(minimum=1, maximum=9, value=2, step=1, label="Overlap", interactive=True) with gr.Row(): seed_a = gr.Number(label="Seed", value=-1, scale=4, precision=0, interactive=True) gr.Button('\U0001f3b2\ufe0f', scale=1).click(fn=lambda: -1, outputs=[seed_a], queue=False) reuse_seed_a = gr.Button('\u267b\ufe0f', scale=1) with gr.Tab("Audio"): with gr.Row(): with gr.Column(): input_type_a = gr.Radio(["file", "mic"], value="file", label="Input Type (optional)", interactive=True) mode_a = gr.Radio(["sample"], label="Input Audio 
Mode (optional)", value="sample", interactive=False, visible=False) with gr.Row(): trim_start_a = gr.Number(label="Trim Start", value=0, interactive=True) trim_end_a = gr.Number(label="Trim End", value=0, interactive=True) audio_a = gr.Audio(source="upload", type="numpy", label="Input Audio (optional)", interactive=True) with gr.Tab("Customization"): with gr.Row(): with gr.Column(): background_a = gr.ColorPicker(value="#0f0f0f", label="background color", interactive=True, scale=0) bar1_a = gr.ColorPicker(value="#84cc16", label="bar color start", interactive=True, scale=0) bar2_a = gr.ColorPicker(value="#10b981", label="bar color end", interactive=True, scale=0) with gr.Column(): image_a = gr.Image(label="Background Image", type="filepath", interactive=True, scale=4) with gr.Row(): height_a = gr.Number(label="Height", value=512, interactive=True) width_a = gr.Number(label="Width", value=768, interactive=True) with gr.Tab("Settings"): with gr.Row(): channel_a = gr.Radio(["mono", "stereo", "stereo effect"], label="Output Audio Channels", value="stereo", interactive=True, scale=1) sr_select_a = gr.Dropdown(["11025", "16000", "22050", "24000", "32000", "44100", "48000"], label="Output Audio Sample Rate", value="48000", interactive=True) with gr.Row(): model_a = gr.Radio(["medium"], label="Model", value="medium", interactive=False, visible=False) decoder_a = gr.Radio(["Default"], label="Decoder", value="Default", interactive=False, visible=False) with gr.Row(): topk_a = gr.Number(label="Top-k", value=250, interactive=True) topp_a = gr.Number(label="Top-p", value=0, interactive=True) temperature_a = gr.Number(label="Temperature", value=1.0, interactive=True) cfg_coef_a = gr.Number(label="Classifier Free Guidance", value=3.0, interactive=True) with gr.Row(): submit_a = gr.Button("Generate", variant="primary") _ = gr.Button("Interrupt").click(fn=interrupt, queue=False) with gr.Column(): with gr.Tab("Output"): output_a = gr.Video(label="Generated Audio", scale=0) with 
gr.Row(): audio_only_a = gr.Audio(type="numpy", label="Audio Only", interactive=False) backup_only_a = gr.Audio(type="numpy", label="Backup Audio", interactive=False, visible=False) send_audio_a = gr.Button("Send to Input Audio") seed_used_a = gr.Number(label='Seed used', value=-1, interactive=False) download_a = gr.File(label="Generated Files", interactive=False) with gr.Tab("Wiki"): gr.Markdown( """ - **[Generate (button)]:** Generates the audio with the given settings and prompts. - **[Interrupt (button)]:** Stops the audio generation as soon as it can, providing an incomplete output. --- ### Generation Tab: #### Structure Prompts: This feature helps reduce repetetive prompts by allowing you to set global prompts that will be used for all prompt segments. - **[Structure Prompts (checkbox)]:** Enable/Disable the structure prompts feature. - **[Global Prompt (text)]:** Here write the prompt that you wish to be used for all prompt segments. #### Multi-Prompt: This feature allows you to control the audio, adding variation to different time segments. You have up to 10 prompt segments. the first prompt will always be 10s long the other prompts will be [10s - overlap]. for example if the overlap is 2s, each prompt segment will be 8s. - **[Prompt Segments (number)]:** Amount of unique prompt to generate throughout the audio generation. - **[Prompt/Input Text (prompt)]:** Here describe the audio you wish the model to generate. - **[Repeat (number)]:** Write how many times this prompt will repeat (instead of wasting another prompt segment on the same prompt). - **[Time (text)]:** The time of the prompt segment. - **[Calculate Timings (button)]:** Calculates the timings of the prompt segments. - **[Duration (number)]:** How long you want the generated audio to be (in seconds). - **[Overlap (number)]:** How much each new segment will reference the previous segment (in seconds). 
For example, if you choose 2s: Each new segment after the first one will reference the previous segment 2s and will generate only 8s of new audio. The model can only process 10s of music. - **[Seed (number)]:** Your generated audio id. If you wish to generate the exact same audio, place the exact seed with the exact prompts (This way you can also extend specific song that was generated short). - **[Random Seed (button)]:** Gives "-1" as a seed, which counts as a random seed. - **[Copy Previous Seed (button)]:** Copies the seed from the output seed (if you don't feel like doing it manualy). --- ### Audio Tab: - **[Input Type (selection)]:** `File` mode allows you to upload an audio file to use as input `Mic` mode allows you to use your microphone as input - **[Trim Start and Trim End (numbers)]:** `Trim Start` set how much you'd like to trim the input audio from the start `Trim End` same as the above but from the end - **[Input Audio (audio file)]:** Input here the audio you wish to use. --- ### Customization Tab: - **[Background Color (color)]:** Works only if you don't upload image. Color of the background of the waveform. - **[Bar Color Start (color)]:** First color of the waveform bars. - **[Bar Color End (color)]:** Second color of the waveform bars. - **[Background Image (image)]:** Background image that you wish to be attached to the generated video along with the waveform. - **[Height and Width (numbers)]:** Output video resolution, only works with image. (minimum height and width is 256). --- ### Settings Tab: - **[Output Audio Channels (selection)]:** With this you can select the amount of channels that you wish for your output audio. `mono` is a straightforward single channel audio `stereo` is a dual channel audio but it will sound more or less like mono `stereo effect` this one is also dual channel but uses tricks to simulate a stereo audio. - **[Output Audio Sample Rate (dropdown)]:** The output audio sample rate, the model default is 32000. 
- **[Top-k (number)]:** is a parameter used in text generation models, including music generation models. It determines the number of most likely next tokens to consider at each step of the generation process. The model ranks all possible tokens based on their predicted probabilities, and then selects the top-k tokens from the ranked list. The model then samples from this reduced set of tokens to determine the next token in the generated sequence. A smaller value of k results in a more focused and deterministic output, while a larger value of k allows for more diversity in the generated music. - **[Top-p (number)]:** also known as nucleus sampling or probabilistic sampling, is another method used for token selection during text generation. Instead of specifying a fixed number like top-k, top-p considers the cumulative probability distribution of the ranked tokens. It selects the smallest possible set of tokens whose cumulative probability exceeds a certain threshold (usually denoted as p). The model then samples from this set to choose the next token. This approach ensures that the generated output maintains a balance between diversity and coherence, as it allows for a varying number of tokens to be considered based on their probabilities. - **[Temperature (number)]:** is a parameter that controls the randomness of the generated output. It is applied during the sampling process, where a higher temperature value results in more random and diverse outputs, while a lower temperature value leads to more deterministic and focused outputs. In the context of music generation, a higher temperature can introduce more variability and creativity into the generated music, but it may also lead to less coherent or structured compositions. On the other hand, a lower temperature can produce more repetitive and predictable music. 
- **[Classifier Free Guidance (number)]:** refers to a technique used in some music generation models where a separate classifier network is trained to provide guidance or control over the generated music. This classifier is trained on labeled data to recognize specific musical characteristics or styles. During the generation process, the output of the generator model is evaluated by the classifier, and the generator is encouraged to produce music that aligns with the desired characteristics or style. This approach allows for more fine-grained control over the generated music, enabling users to specify certain attributes they want the model to capture. """ ) with gr.Tab("Audio Info"): gr.Markdown( """ ### Audio Info """ ) with gr.Row(): with gr.Column(): in_audio = gr.File(type="file", label="Input Any Audio", interactive=True) with gr.Row(): send_gen = gr.Button("Send to MusicGen", variant="primary") send_gen_a = gr.Button("Send to AudioGen", variant="primary") with gr.Column(): info = gr.Textbox(label="Audio Info", lines=10, interactive=False) with gr.Tab("Changelog"): gr.Markdown( """ ## Changelog: ### v2.0.1 - Changed custom model loading to support the official trained models - Additional changes from the main facebookresearch repo ### v2.0.0a - Forgot to move all the update to app.py from temp2.py... 
oops ### v2.0.0 - Changed name from MusicGen+ to AudioCraft Plus - Complete overhaul of the repo "backend" with the latest changes from the main facebookresearch repo - Added a new decoder: MultiBand_Diffusion - Added AudioGen: a new tab for generating audio ### v1.2.8c - Implemented Reverse compatibility for audio info tab with previous versions ### v1.2.8b - Fixed the error when loading default models ### v1.2.8a - Adapted Audio info tab to work with the new structure prompts feature - Now custom models actually work, make sure you select the correct base model ### v1.2.8 - Now you will also recieve json file with metadata of generated audio - Added error messages in Audio Info tab - Added structure prompts: you can select bpm, key and global prompt for all prompts - Added time display next to each prompt, can be calculated with "Calculate Timings" button ### v1.2.7 - When sending generated audio to Input Audio, it will send a backup audio with default settings (best for continuos generation) - Added Metadata to generated audio (Thanks to AlexHK ♥) - Added Audio Info tab that will display the metadata of the input audio - Added "send to Text2Audio" button in Audio Info tab - Generated audio is now stored in the "output" folder (Thanks to AlexHK ♥) - Added an output area with generated files and download buttons - Enhanced Stereo effect (Thanks to AlexHK ♥) ### v1.2.6 - Added option to generate in stereo (instead of only mono) - Added dropdown for selecting output sample rate (model default is 32000) ### v1.2.5a - Added file cleaner (This comes from the main facebookresearch repo) - Reorganized a little, moved audio to a seperate tab ### v1.2.5 - Gave a unique lime theme to the webui - Added additional output for audio only - Added button to send generated audio to Input Audio - Added option to trim Input Audio ### v1.2.4 - Added mic input (This comes from the main facebookresearch repo) ### v1.2.3 - Added option to change video size to fit the image you upload 
### v1.2.2 - Added Wiki, Changelog and About tabs ### v1.2.1 - Added tabs and organized the entire interface - Added option to attach image to the output video - Added option to load fine-tuned models (Yet to be tested) ### v1.2.0 - Added Multi-Prompt ### v1.1.3 - Added customization options for generated waveform ### v1.1.2 - Removed sample length limit: now you can input audio of any length as music sample ### v1.1.1 - Improved music sample audio quality when using music continuation ### v1.1.0 - Rebuilt the repo on top of the latest structure of the main MusicGen repo - Improved Music continuation feature ### v1.0.0 - Stable Version - Added Music continuation """ ) with gr.Tab("About"): gen_type = gr.Text(value="music", interactive=False, visible=False) gen_type_a = gr.Text(value="audio", interactive=False, visible=False) gr.Markdown( """ This is your private demo for [MusicGen](https://github.com/facebookresearch/audiocraft), a simple and controllable model for music generation presented at: ["Simple and Controllable Music Generation"](https://huggingface.co/papers/2306.05284) ## MusicGen+ is an extended version of the original MusicGen by facebookresearch. 
### Repo: https://github.com/GrandaddyShmax/audiocraft_plus/tree/plus --- ### This project was possible thanks to: #### GrandaddyShmax - https://github.com/GrandaddyShmax #### Camenduru - https://github.com/camenduru #### rkfg - https://github.com/rkfg #### oobabooga - https://github.com/oobabooga #### AlexHK - https://github.com/alanhk147 """ ) send_gen.click(info_to_params, inputs=[in_audio], outputs=[decoder, struc_prompts, global_prompt, bpm, key, scale, model, dropdown, s, prompts[0], prompts[1], prompts[2], prompts[3], prompts[4], prompts[5], prompts[6], prompts[7], prompts[8], prompts[9], repeats[0], repeats[1], repeats[2], repeats[3], repeats[4], repeats[5], repeats[6], repeats[7], repeats[8], repeats[9], mode, duration, topk, topp, temperature, cfg_coef, seed, overlap, channel, sr_select], queue=False) reuse_seed.click(fn=lambda x: x, inputs=[seed_used], outputs=[seed], queue=False) send_audio.click(fn=lambda x: x, inputs=[backup_only], outputs=[audio], queue=False) submit.click(predict_full, inputs=[gen_type, model, decoder, dropdown, s, struc_prompts, bpm, key, scale, global_prompt, prompts[0], prompts[1], prompts[2], prompts[3], prompts[4], prompts[5], prompts[6], prompts[7], prompts[8], prompts[9], repeats[0], repeats[1], repeats[2], repeats[3], repeats[4], repeats[5], repeats[6], repeats[7], repeats[8], repeats[9], audio, mode, trim_start, trim_end, duration, topk, topp, temperature, cfg_coef, seed, overlap, image, height, width, background, bar1, bar2, channel, sr_select], outputs=[output, audio_only, backup_only, download, seed_used]) input_type.change(toggle_audio_src, input_type, [audio], queue=False, show_progress=False) to_calc.click(calc_time, inputs=[gen_type, s, duration, overlap, repeats[0], repeats[1], repeats[2], repeats[3], repeats[4], repeats[5], repeats[6], repeats[7], repeats[8], repeats[9]], outputs=[calcs[0], calcs[1], calcs[2], calcs[3], calcs[4], calcs[5], calcs[6], calcs[7], calcs[8], calcs[9]], queue=False) 
send_gen_a.click(info_to_params_a, inputs=[in_audio], outputs=[decoder_a, struc_prompts_a, global_prompt_a, s_a, prompts_a[0], prompts_a[1], prompts_a[2], prompts_a[3], prompts_a[4], prompts_a[5], prompts_a[6], prompts_a[7], prompts_a[8], prompts_a[9], repeats_a[0], repeats_a[1], repeats_a[2], repeats_a[3], repeats_a[4], repeats_a[5], repeats_a[6], repeats_a[7], repeats_a[8], repeats_a[9], duration_a, topk_a, topp_a, temperature_a, cfg_coef_a, seed_a, overlap_a, channel_a, sr_select_a], queue=False) reuse_seed_a.click(fn=lambda x: x, inputs=[seed_used_a], outputs=[seed_a], queue=False) send_audio_a.click(fn=lambda x: x, inputs=[backup_only_a], outputs=[audio_a], queue=False) submit_a.click(predict_full, inputs=[gen_type_a, model_a, decoder_a, dropdown, s_a, struc_prompts_a, bpm, key, scale, global_prompt_a, prompts_a[0], prompts_a[1], prompts_a[2], prompts_a[3], prompts_a[4], prompts_a[5], prompts_a[6], prompts_a[7], prompts_a[8], prompts_a[9], repeats_a[0], repeats_a[1], repeats_a[2], repeats_a[3], repeats_a[4], repeats_a[5], repeats_a[6], repeats_a[7], repeats_a[8], repeats_a[9], audio_a, mode_a, trim_start_a, trim_end_a, duration_a, topk_a, topp_a, temperature_a, cfg_coef_a, seed_a, overlap_a, image_a, height_a, width_a, background_a, bar1_a, bar2_a, channel_a, sr_select_a], outputs=[output_a, audio_only_a, backup_only_a, download_a, seed_used_a]) input_type_a.change(toggle_audio_src, input_type_a, [audio_a], queue=False, show_progress=False) to_calc_a.click(calc_time, inputs=[gen_type_a, s_a, duration_a, overlap_a, repeats_a[0], repeats_a[1], repeats_a[2], repeats_a[3], repeats_a[4], repeats_a[5], repeats_a[6], repeats_a[7], repeats_a[8], repeats_a[9]], outputs=[calcs_a[0], calcs_a[1], calcs_a[2], calcs_a[3], calcs_a[4], calcs_a[5], calcs_a[6], calcs_a[7], calcs_a[8], calcs_a[9]], queue=False) in_audio.change(get_audio_info, in_audio, outputs=[info]) def variable_outputs(k): k = int(k) - 1 return [gr.Textbox.update(visible=True)]*k + 
def get_size(image):
    """Return the ``(height, width)`` pair derived from a background image.

    Args:
        image: Value accepted by ``Image.open`` identifying the uploaded
            image, or ``None`` when no image is set.

    Returns:
        A ``(height, width)`` tuple. When ``image`` is ``None`` the fixed
        default ``(512, 768)`` is returned. Otherwise the image dimensions
        are returned with any odd value increased by one so both are even.
    """
    if image is None:
        return 512, 768
    # Open the image as a context manager so the underlying file handle is
    # released as soon as the dimensions have been read, instead of staying
    # open until the object happens to be garbage collected.
    with Image.open(image) as img:
        img_height, img_width = img.height, img.width
    # ``n % 2`` is 1 for odd values and 0 for even ones, so adding it rounds
    # each dimension up to the next even number.
    return img_height + (img_height % 2), img_width + (img_width % 2)
Duplicate Space for longer sequences, more control and no queue.

""" ) with gr.Row(): with gr.Column(): with gr.Row(): text = gr.Text(label="Describe your music", lines=2, interactive=True) with gr.Column(): radio = gr.Radio(["file", "mic"], value="file", label="Condition on a melody (optional) File or Mic") melody = gr.Audio(source="upload", type="numpy", label="File", interactive=True, elem_id="melody-input") with gr.Row(): submit = gr.Button("Generate") with gr.Column(): output = gr.Video(label="Generated Music") audio_output = gr.Audio(label="Generated Music (wav)", type='filepath') submit.click(predict_batched, inputs=[text, melody], outputs=[output, audio_output], batch=True, max_batch_size=MAX_BATCH_SIZE) radio.change(toggle_audio_src, radio, [melody], queue=False, show_progress=False) gr.Examples( fn=predict_batched, examples=[ [ "An 80s driving pop song with heavy drums and synth pads in the background", "./assets/bach.mp3", ], [ "A cheerful country song with acoustic guitars", "./assets/bolero_ravel.mp3", ], [ "90s rock song with electric guitar and heavy drums", None, ], [ "a light and cheerly EDM track, with syncopated drums, aery pads, and strong emotions bpm: 130", "./assets/bach.mp3", ], [ "lofi slow bpm electro chill with organic samples", None, ], ], inputs=[text, melody], outputs=[output] ) gr.Markdown(""" ### More details The model will generate 12 seconds of audio based on the description you provided. You can optionally provide a reference audio from which a broad melody will be extracted. The model will then try to follow both the description and melody provided. All samples are generated with the `melody` model. You can also use your own GPU or a Google Colab by following the instructions on our repo. See [github.com/facebookresearch/audiocraft](https://github.com/facebookresearch/audiocraft) for more details. 
""") demo.queue(max_size=8 * 4).launch(**launch_kwargs) if __name__ == "__main__": parser = argparse.ArgumentParser() parser.add_argument( '--listen', type=str, default='0.0.0.0' if 'SPACE_ID' in os.environ else '127.0.0.1', help='IP to listen on for connections to Gradio', ) parser.add_argument( '--username', type=str, default='', help='Username for authentication' ) parser.add_argument( '--password', type=str, default='', help='Password for authentication' ) parser.add_argument( '--server_port', type=int, default=0, help='Port to run the server listener on', ) parser.add_argument( '--inbrowser', action='store_true', help='Open in browser' ) parser.add_argument( '--share', action='store_true', help='Share the gradio UI' ) parser.add_argument( '--unload_model', action='store_true', help='Unload the model after every generation to save GPU memory' ) parser.add_argument( '--unload_to_cpu', action='store_true', help='Move the model to main RAM after every generation to save GPU memory but reload faster than after full unload (see above)' ) parser.add_argument( '--cache', action='store_true', help='Cache models in RAM to quickly switch between them' ) args = parser.parse_args() UNLOAD_MODEL = args.unload_model MOVE_TO_CPU = args.unload_to_cpu if args.cache: MODELS = {} launch_kwargs = {} launch_kwargs['server_name'] = args.listen if args.username and args.password: launch_kwargs['auth'] = (args.username, args.password) if args.server_port: launch_kwargs['server_port'] = args.server_port if args.inbrowser: launch_kwargs['inbrowser'] = args.inbrowser if args.share: launch_kwargs['share'] = args.share # Show the interface if IS_BATCHED: global USE_DIFFUSION USE_DIFFUSION = False ui_batched(launch_kwargs) else: ui_full(launch_kwargs) ================================================ FILE: audiocraft/__init__.py ================================================ # Copyright (c) Meta Platforms, Inc. and affiliates. # All rights reserved. 
# # This source code is licensed under the license found in the # LICENSE file in the root directory of this source tree. """ AudioCraft is a general framework for training audio generative models. At the moment we provide the training code for: - [MusicGen](https://arxiv.org/abs/2306.05284), a state-of-the-art text-to-music and melody+text autoregressive generative model. For the solver, see `audiocraft.solvers.musicgen.MusicGenSolver`, and for the model, `audiocraft.models.musicgen.MusicGen`. - [AudioGen](https://arxiv.org/abs/2209.15352), a state-of-the-art text-to-general-audio generative model. - [EnCodec](https://arxiv.org/abs/2210.13438), efficient and high fidelity neural audio codec which provides an excellent tokenizer for autoregressive language models. See `audiocraft.solvers.compression.CompressionSolver`, and `audiocraft.models.encodec.EncodecModel`. - [MultiBandDiffusion](TODO), alternative diffusion-based decoder compatible with EnCodec that improves the perceived quality and reduces the artifacts coming from adversarial decoders. """ # flake8: noqa from . import data, modules, models __version__ = '1.0.0' ================================================ FILE: audiocraft/adversarial/__init__.py ================================================ # Copyright (c) Meta Platforms, Inc. and affiliates. # All rights reserved. # # This source code is licensed under the license found in the # LICENSE file in the root directory of this source tree. """Adversarial losses and discriminator architectures.""" # flake8: noqa from .discriminators import ( MultiPeriodDiscriminator, MultiScaleDiscriminator, MultiScaleSTFTDiscriminator ) from .losses import ( AdversarialLoss, AdvLossType, get_adv_criterion, get_fake_criterion, get_real_criterion, FeatLossType, FeatureMatchingLoss ) ================================================ FILE: audiocraft/adversarial/discriminators/__init__.py ================================================ # Copyright (c) Meta Platforms, Inc. 
class PeriodDiscriminator(nn.Module):
    """Sub-discriminator operating on the waveform folded by a fixed period.

    The input of shape ``[B, C, T]`` is right-padded (reflection) to a multiple
    of ``period`` and viewed as ``[B, C, T // period, period]`` before going
    through a stack of 2D convolutions.

    Args:
        period (int): Period between samples of audio.
        in_channels (int): Number of input channels.
        out_channels (int): Number of output channels.
        n_layers (int): Number of convolutional layers.
        kernel_sizes (list of int): Kernel sizes for convolutions; the first
            entry is used by the stacked layers, the second by the final layer.
        stride (int): Stride for convolutions (the last stacked layer uses 1).
        filters (int): Initial number of filters in convolutions.
        filters_scale (int): Multiplier of number of filters as we increase depth.
        max_filters (int): Maximum number of filters.
        norm (str): Normalization method.
        activation (str): Name of the activation class looked up in ``torch.nn``.
        activation_params (dict): Parameters to provide to the activation function.
    """
    def __init__(self, period: int, in_channels: int = 1, out_channels: int = 1,
                 n_layers: int = 5, kernel_sizes: tp.List[int] = [5, 3], stride: int = 3, filters: int = 8,
                 filters_scale: int = 4, max_filters: int = 1024, norm: str = 'weight_norm',
                 activation: str = 'LeakyReLU', activation_params: dict = {'negative_slope': 0.2}):
        super().__init__()
        self.period = period
        self.n_layers = n_layers
        self.activation = getattr(torch.nn, activation)(**activation_params)
        self.convs = nn.ModuleList()
        final_index = self.n_layers - 1
        channels = in_channels
        for depth in range(self.n_layers):
            # Channel count grows geometrically with depth, capped at max_filters.
            width = min(filters * (filters_scale ** (depth + 1)), max_filters)
            # The last stacked layer does not downsample.
            step = stride if depth != final_index else 1
            layer = NormConv2d(channels, width, kernel_size=(kernel_sizes[0], 1), stride=(step, 1),
                               padding=((kernel_sizes[0] - 1) // 2, 0), norm=norm)
            self.convs.append(layer)
            channels = width
        self.conv_post = NormConv2d(channels, out_channels, kernel_size=(kernel_sizes[1], 1), stride=1,
                                    padding=((kernel_sizes[1] - 1) // 2, 0), norm=norm)

    def forward(self, x: torch.Tensor):
        """Return ``(logits, feature_maps)`` for a ``[B, C, T]`` input.

        ``feature_maps`` holds the activated output of every stacked layer
        followed by the output of the final convolution.
        """
        features = []
        batch, channels, length = x.shape
        remainder = length % self.period
        if remainder:
            # Pad on the right so the time axis splits evenly into periods.
            extra = self.period - remainder
            x = F.pad(x, (0, extra), 'reflect')
            length = length + extra
        # 1d to 2d
        x = x.view(batch, channels, length // self.period, self.period)

        for conv in self.convs:
            x = self.activation(conv(x))
            features.append(x)
        x = self.conv_post(x)
        features.append(x)
        return x, features
strides (Sequence[int] or None): Strides for inner convolutions. paddings (Sequence[int] or None): Paddings for inner convolutions. norm (str): Normalization method. activation (str): Activation function. activation_params (dict): Parameters to provide to the activation function. pad (str): Padding for initial convolution. pad_params (dict): Parameters to provide to the padding module. """ def __init__(self, in_channels=1, out_channels=1, kernel_sizes: tp.Sequence[int] = [5, 3], filters: int = 16, max_filters: int = 1024, downsample_scales: tp.Sequence[int] = [4, 4, 4, 4], inner_kernel_sizes: tp.Optional[tp.Sequence[int]] = None, groups: tp.Optional[tp.Sequence[int]] = None, strides: tp.Optional[tp.Sequence[int]] = None, paddings: tp.Optional[tp.Sequence[int]] = None, norm: str = 'weight_norm', activation: str = 'LeakyReLU', activation_params: dict = {'negative_slope': 0.2}, pad: str = 'ReflectionPad1d', pad_params: dict = {}): super().__init__() assert len(kernel_sizes) == 2 assert kernel_sizes[0] % 2 == 1 assert kernel_sizes[1] % 2 == 1 assert (inner_kernel_sizes is None or len(inner_kernel_sizes) == len(downsample_scales)) assert (groups is None or len(groups) == len(downsample_scales)) assert (strides is None or len(strides) == len(downsample_scales)) assert (paddings is None or len(paddings) == len(downsample_scales)) self.activation = getattr(torch.nn, activation)(**activation_params) self.convs = nn.ModuleList() self.convs.append( nn.Sequential( getattr(torch.nn, pad)((np.prod(kernel_sizes) - 1) // 2, **pad_params), NormConv1d(in_channels, filters, kernel_size=np.prod(kernel_sizes), stride=1, norm=norm) ) ) in_chs = filters for i, downsample_scale in enumerate(downsample_scales): out_chs = min(in_chs * downsample_scale, max_filters) default_kernel_size = downsample_scale * 10 + 1 default_stride = downsample_scale default_padding = (default_kernel_size - 1) // 2 default_groups = in_chs // 4 self.convs.append( NormConv1d(in_chs, out_chs, 
class MultiScaleDiscriminator(MultiDiscriminator):
    """Multi-Scale (MSD) Discriminator.

    Runs one ``ScaleDiscriminator`` per entry of ``scale_norms``. The first
    sub-discriminator receives the input unchanged; before each following one
    the signal is average-pooled once more, so every scale sees a
    progressively downsampled version of the input.

    Args:
        in_channels (int): Number of input channels.
        out_channels (int): Number of output channels.
        downsample_factor (int): Downsampling factor between the different scales.
        scale_norms (Sequence[str]): Normalization for each sub-discriminator.
        **kwargs: Additional args for ScaleDiscriminator.
    """
    def __init__(self, in_channels: int = 1, out_channels: int = 1, downsample_factor: int = 2,
                 scale_norms: tp.Sequence[str] = ['weight_norm', 'weight_norm', 'weight_norm'], **kwargs):
        super().__init__()
        self.discriminators = nn.ModuleList([
            ScaleDiscriminator(in_channels, out_channels, norm=norm, **kwargs) for norm in scale_norms
        ])
        self.downsample = nn.AvgPool1d(downsample_factor * 2, downsample_factor, padding=downsample_factor)

    @property
    def num_discriminators(self):
        """Number of sub-discriminators."""
        return len(self.discriminators)

    def forward(self, x: torch.Tensor) -> MultiDiscriminatorOutputType:
        """Return the logits and feature maps of every sub-discriminator.

        Args:
            x (torch.Tensor): Input signal passed to the first sub-discriminator.

        Returns:
            A tuple ``(logits, fmaps)`` with one entry per sub-discriminator.
        """
        logits = []
        fmaps = []
        for i, disc in enumerate(self.discriminators):
            if i != 0:
                # The pooling module returns a new tensor and does not modify
                # its input, so the result must be kept; otherwise every scale
                # would process the same full-resolution signal.
                x = self.downsample(x)
            logit, fmap = disc(x)
            logits.append(logit)
            fmaps.append(fmap)
        return logits, fmaps
win_length (int): Window size for each scale. normalized (bool): Whether to normalize by magnitude after stft. norm (str): Normalization method. activation (str): Activation function. activation_params (dict): Parameters to provide to the activation function. growth (int): Growth factor for the filters. """ def __init__(self, filters: int, in_channels: int = 1, out_channels: int = 1, n_fft: int = 1024, hop_length: int = 256, win_length: int = 1024, max_filters: int = 1024, filters_scale: int = 1, kernel_size: tp.Tuple[int, int] = (3, 9), dilations: tp.List = [1, 2, 4], stride: tp.Tuple[int, int] = (1, 2), normalized: bool = True, norm: str = 'weight_norm', activation: str = 'LeakyReLU', activation_params: dict = {'negative_slope': 0.2}): super().__init__() assert len(kernel_size) == 2 assert len(stride) == 2 self.filters = filters self.in_channels = in_channels self.out_channels = out_channels self.n_fft = n_fft self.hop_length = hop_length self.win_length = win_length self.normalized = normalized self.activation = getattr(torch.nn, activation)(**activation_params) self.spec_transform = torchaudio.transforms.Spectrogram( n_fft=self.n_fft, hop_length=self.hop_length, win_length=self.win_length, window_fn=torch.hann_window, normalized=self.normalized, center=False, pad_mode=None, power=None) spec_channels = 2 * self.in_channels self.convs = nn.ModuleList() self.convs.append( NormConv2d(spec_channels, self.filters, kernel_size=kernel_size, padding=get_2d_padding(kernel_size)) ) in_chs = min(filters_scale * self.filters, max_filters) for i, dilation in enumerate(dilations): out_chs = min((filters_scale ** (i + 1)) * self.filters, max_filters) self.convs.append(NormConv2d(in_chs, out_chs, kernel_size=kernel_size, stride=stride, dilation=(dilation, 1), padding=get_2d_padding(kernel_size, (dilation, 1)), norm=norm)) in_chs = out_chs out_chs = min((filters_scale ** (len(dilations) + 1)) * self.filters, max_filters) self.convs.append(NormConv2d(in_chs, out_chs, 
class MultiScaleSTFTDiscriminator(MultiDiscriminator):
    """Multi-Scale STFT (MS-STFT) discriminator.

    Holds one ``DiscriminatorSTFT`` per STFT configuration and applies each of
    them to the same input.

    Args:
        filters (int): Number of filters in convolutions.
        in_channels (int): Number of input channels.
        out_channels (int): Number of output channels.
        sep_channels (bool): Separate channels to distinct samples for stereo support.
            NOTE(review): the flag is stored on the instance, but ``forward``
            below does not read it.
        n_ffts (Sequence[int]): Size of FFT for each scale.
        hop_lengths (Sequence[int]): Length of hop between STFT windows for each scale.
        win_lengths (Sequence[int]): Window size for each scale.
        **kwargs: Additional args for ``DiscriminatorSTFT``.
    """
    def __init__(self, filters: int, in_channels: int = 1, out_channels: int = 1, sep_channels: bool = False,
                 n_ffts: tp.List[int] = [1024, 2048, 512], hop_lengths: tp.List[int] = [256, 512, 128],
                 win_lengths: tp.List[int] = [1024, 2048, 512], **kwargs):
        super().__init__()
        assert len(n_ffts) == len(hop_lengths) == len(win_lengths)
        self.sep_channels = sep_channels
        per_scale = []
        for scale in range(len(n_ffts)):
            # One sub-discriminator per (n_fft, win_length, hop_length) triple.
            per_scale.append(
                DiscriminatorSTFT(filters, in_channels=in_channels, out_channels=out_channels,
                                  n_fft=n_ffts[scale], win_length=win_lengths[scale],
                                  hop_length=hop_lengths[scale], **kwargs)
            )
        self.discriminators = nn.ModuleList(per_scale)

    @property
    def num_discriminators(self):
        """Number of sub-discriminators."""
        return len(self.discriminators)

    def _separate_channels(self, x: torch.Tensor) -> torch.Tensor:
        """View a ``[B, C, T]`` tensor as ``[B * C, 1, T]``."""
        _, _, length = x.shape
        return x.view(-1, 1, length)

    def forward(self, x: torch.Tensor) -> MultiDiscriminatorOutputType:
        """Return ``(logits, fmaps)`` with one entry per sub-discriminator."""
        results = [disc(x) for disc in self.discriminators]
        logits = [logit for logit, _ in results]
        fmaps = [fmap for _, fmap in results]
        return logits, fmaps
We assume here the adversary output is ``Tuple[List[torch.Tensor], List[List[torch.Tensor]]]`` where the first item is a list of logits and the second item is a list of feature maps. optimizer (torch.optim.Optimizer): Optimizer used for training the given module. loss (AdvLossType): Loss function for generator training. loss_real (AdvLossType): Loss function for adversarial training on logits from real samples. loss_fake (AdvLossType): Loss function for adversarial training on logits from fake samples. loss_feat (FeatLossType): Feature matching loss function for generator training. normalize (bool): Whether to normalize by number of sub-discriminators. Example of usage: adv_loss = AdversarialLoss(adversaries, optimizer, loss, loss_real, loss_fake) for real in loader: noise = torch.randn(...) fake = model(noise) adv_loss.train_adv(fake, real) loss, _ = adv_loss(fake, real) loss.backward() """ def __init__(self, adversary: nn.Module, optimizer: torch.optim.Optimizer, loss: AdvLossType, loss_real: AdvLossType, loss_fake: AdvLossType, loss_feat: tp.Optional[FeatLossType] = None, normalize: bool = True): super().__init__() self.adversary: nn.Module = adversary flashy.distrib.broadcast_model(self.adversary) self.optimizer = optimizer self.loss = loss self.loss_real = loss_real self.loss_fake = loss_fake self.loss_feat = loss_feat self.normalize = normalize def _save_to_state_dict(self, destination, prefix, keep_vars): # Add the optimizer state dict inside our own. super()._save_to_state_dict(destination, prefix, keep_vars) destination[prefix + 'optimizer'] = self.optimizer.state_dict() return destination def _load_from_state_dict(self, state_dict, prefix, *args, **kwargs): # Load optimizer state. 
def get_adversary_pred(self, x):
    """Run the adversary on ``x`` and check the structure of what it returns.

    Returns:
        tuple: ``(logits, fmaps)`` exactly as produced by ``self.adversary``, where ``logits``
            is a list of tensors and ``fmaps`` is a list of lists of tensors.
    Raises:
        AssertionError: If the adversary output does not have that structure.
    """
    logits, fmaps = self.adversary(x)

    logits_ok = isinstance(logits, list) and all(isinstance(t, torch.Tensor) for t in logits)
    assert logits_ok, f'Expecting a list of tensors as logits but {type(logits)} found.'

    assert isinstance(fmaps, list), f'Expecting a list of features maps but {type(fmaps)} found.'
    for fmap in fmaps:
        fmap_ok = isinstance(fmap, list) and all(isinstance(f, torch.Tensor) for f in fmap)
        assert fmap_ok, f'Expecting a list of tensors as feature maps but {type(fmap)} found.'
    return logits, fmaps
""" adv = torch.tensor(0., device=fake.device) feat = torch.tensor(0., device=fake.device) with flashy.utils.readonly(self.adversary): all_logits_fake_is_fake, all_fmap_fake = self.get_adversary_pred(fake) all_logits_real_is_fake, all_fmap_real = self.get_adversary_pred(real) n_sub_adversaries = len(all_logits_fake_is_fake) for logit_fake_is_fake in all_logits_fake_is_fake: adv += self.loss(logit_fake_is_fake) if self.loss_feat: for fmap_fake, fmap_real in zip(all_fmap_fake, all_fmap_real): feat += self.loss_feat(fmap_fake, fmap_real) if self.normalize: adv /= n_sub_adversaries feat /= n_sub_adversaries return adv, feat def get_adv_criterion(loss_type: str) -> tp.Callable: assert loss_type in ADVERSARIAL_LOSSES if loss_type == 'mse': return mse_loss elif loss_type == 'hinge': return hinge_loss elif loss_type == 'hinge2': return hinge2_loss raise ValueError('Unsupported loss') def get_fake_criterion(loss_type: str) -> tp.Callable: assert loss_type in ADVERSARIAL_LOSSES if loss_type == 'mse': return mse_fake_loss elif loss_type in ['hinge', 'hinge2']: return hinge_fake_loss raise ValueError('Unsupported loss') def get_real_criterion(loss_type: str) -> tp.Callable: assert loss_type in ADVERSARIAL_LOSSES if loss_type == 'mse': return mse_real_loss elif loss_type in ['hinge', 'hinge2']: return hinge_real_loss raise ValueError('Unsupported loss') def mse_real_loss(x: torch.Tensor) -> torch.Tensor: return F.mse_loss(x, torch.tensor(1., device=x.device).expand_as(x)) def mse_fake_loss(x: torch.Tensor) -> torch.Tensor: return F.mse_loss(x, torch.tensor(0., device=x.device).expand_as(x)) def hinge_real_loss(x: torch.Tensor) -> torch.Tensor: return -torch.mean(torch.min(x - 1, torch.tensor(0., device=x.device).expand_as(x))) def hinge_fake_loss(x: torch.Tensor) -> torch.Tensor: return -torch.mean(torch.min(-x - 1, torch.tensor(0., device=x.device).expand_as(x))) def mse_loss(x: torch.Tensor) -> torch.Tensor: if x.numel() == 0: return torch.tensor([0.0], device=x.device) 
def hinge2_loss(x: torch.Tensor) -> torch.Tensor:
    """Clipped hinge loss used for generator training (the ``'hinge2'`` criterion).

    Evaluates ``-mean(min(x - 1, 0))`` over all entries of ``x``.

    Args:
        x (torch.Tensor): Logits produced by the adversary on generated samples.
    Returns:
        torch.Tensor: The loss; for an empty input, a zero tensor of shape ``[1]``
            located on the same device as ``x``.
    """
    if x.numel() == 0:
        # Build the placeholder on the device of ``x`` (as `mse_loss` and `hinge_loss`
        # do) so the result can be combined with other losses living on that device.
        return torch.zeros(1, device=x.device)
    zero = torch.tensor(0., device=x.device).expand_as(x)
    return -torch.mean(torch.min(x - 1, zero))
@dataclass(frozen=True)
class AudioFileInfo:
    """Immutable summary of an audio file as reported by one of the info backends.

    Args:
        sample_rate (int): Sample rate of the file.
        duration (float): Duration of the file, in seconds.
        channels (int): Number of audio channels.
    """
    # Field order matters: the backends build instances positionally.
    sample_rate: int
    duration: float
    channels: int
def audio_read(filepath: tp.Union[str, Path], seek_time: float = 0.,
               duration: float = -1., pad: bool = False) -> tp.Tuple[torch.Tensor, int]:
    """Read audio by picking the most appropriate backend tool based on the audio format.

    The backend is chosen from the file extension, compared case-insensitively:
    `find_audio_files` also lower-cases suffixes, so e.g. ``.FLAC`` files reach this
    function and must take the same route as ``.flac`` ones.

    Args:
        filepath (str or Path): Path to audio file to read.
        seek_time (float): Time at which to start reading in the file.
        duration (float): Duration to read from the file. If set to -1, the whole file is read.
        pad (bool): Pad output audio if not reaching expected duration.
    Returns:
        tuple of torch.Tensor, int: Tuple containing audio data and sample rate.
    """
    fp = Path(filepath)
    suffix = fp.suffix.lower()
    if suffix in ['.flac', '.ogg']:
        # TODO: check if we can safely use av_read for .ogg
        # There is some bug with ffmpeg and reading flac
        info = _soundfile_info(filepath)
        frames = -1 if duration <= 0 else int(duration * info.sample_rate)
        frame_offset = int(seek_time * info.sample_rate)
        wav, sr = soundfile.read(filepath, start=frame_offset, frames=frames, dtype=np.float32)
        assert info.sample_rate == sr, f"Mismatch of sample rates {info.sample_rate} {sr}"
        wav = torch.from_numpy(wav).t().contiguous()
        # Make sure the result is 2-D: a 1-D read gets a leading dimension.
        if len(wav.shape) == 1:
            wav = torch.unsqueeze(wav, 0)
    elif (
        suffix in ['.wav', '.mp3'] and suffix[1:] in ta.utils.sox_utils.list_read_formats()
        and duration <= 0 and seek_time == 0
    ):
        # Torchaudio is faster if we load an entire file at once.
        wav, sr = ta.load(fp)
    else:
        wav, sr = _av_read(filepath, seek_time, duration)
    if pad and duration > 0:
        # Right-pad with zeros up to the requested duration.
        expected_frames = int(duration * sr)
        wav = F.pad(wav, (0, expected_frames - wav.shape[-1]))
    return wav, sr
mp3_rate (int): kbps when using mp3s. normalize (bool): if `True` (default), normalizes according to the prescribed strategy (see after). If `False`, the strategy is only used in case clipping would happen. strategy (str): Can be either 'clip', 'peak', or 'rms'. Default is 'peak', i.e. audio is normalized by its largest value. RMS normalizes by root-mean-square with extra headroom to avoid clipping. 'clip' just clips. peak_clip_headroom_db (float): Headroom in dB when doing 'peak' or 'clip' strategy. rms_headroom_db (float): Headroom in dB when doing 'rms' strategy. This must be much larger than the `peak_clip` one to avoid further clipping. loudness_headroom_db (float): Target loudness for loudness normalization. loudness_compressor (bool): Uses tanh for soft clipping when strategy is 'loudness'. when strategy is 'loudness' log_clipping (bool): If True, basic logging on stderr when clipping still occurs despite strategy (only for 'rms'). make_parent_dir (bool): Make parent directory if it doesn't exist. Returns: Path: Path of the saved audio. """ assert wav.dtype.is_floating_point, "wav is not floating point" if wav.dim() == 1: wav = wav[None] elif wav.dim() > 2: raise ValueError("Input wav should be at most 2 dimension.") assert wav.isfinite().all() wav = normalize_audio(wav, normalize, strategy, peak_clip_headroom_db, rms_headroom_db, loudness_headroom_db, loudness_compressor, log_clipping=log_clipping, sample_rate=sample_rate, stem_name=str(stem_name)) kwargs: dict = {} if format == 'mp3': suffix = '.mp3' kwargs.update({"compression": mp3_rate}) elif format == 'wav': wav = i16_pcm(wav) suffix = '.wav' kwargs.update({"encoding": "PCM_S", "bits_per_sample": 16}) else: raise RuntimeError(f"Invalid format {format}. 
@dataclass(order=True)
class BaseInfo:
    """Base dataclass providing conversion from and to plain dictionaries."""

    @classmethod
    def _dict2fields(cls, dictionary: dict):
        """Return the entries of ``dictionary`` whose keys are dataclass fields of ``cls``."""
        known_names = (f.name for f in fields(cls))
        return {name: dictionary[name] for name in known_names if name in dictionary}

    @classmethod
    def from_dict(cls, dictionary: dict):
        """Build an instance from ``dictionary``, ignoring keys that are not fields."""
        return cls(**cls._dict2fields(dictionary))

    def to_dict(self):
        """Return a dictionary mapping every field name to its current value."""
        as_dict = {}
        for f in fields(self):
            as_dict[f.name] = getattr(self, f.name)
        return as_dict
def _resolve_audio_meta(m: AudioMeta, fast: bool = True) -> AudioMeta:
    """If Dora is available as a dependency, try to resolve potential relative paths
    in list of AudioMeta. This method is expected to be used when loading meta from file.

    Both ``m.path`` and, when present, ``m.info_path.zip_path`` are made absolute.
    ``m`` is modified in place and returned.

    Args:
        m (AudioMeta): Audio meta to resolve.
        fast (bool): If True, uses a really fast check for determining if a file
            is already absolute or not. Only valid on Linux/Mac.
    Returns:
        AudioMeta: Audio meta with resolved path.
    """
    def is_abs(path) -> bool:
        if fast:
            # Cheap POSIX-only check; also tolerates an empty path.
            return str(path).startswith('/')
        return os.path.isabs(str(path))

    if not dora:
        # Without Dora there is nothing to resolve against.
        return m

    if not is_abs(m.path):
        m.path = dora.git_save.to_absolute_path(m.path)
    if m.info_path is not None and not is_abs(m.info_path.zip_path):
        # Resolve the zip path itself, not the audio path.
        m.info_path.zip_path = dora.git_save.to_absolute_path(m.info_path.zip_path)
    return m
""" audio_files = [] futures: tp.List[Future] = [] pool: tp.Optional[ThreadPoolExecutor] = None with ExitStack() as stack: if workers > 0: pool = ThreadPoolExecutor(workers) stack.enter_context(pool) if progress: print("Finding audio files...") for root, folders, files in os.walk(path, followlinks=True): for file in files: full_path = Path(root) / file if full_path.suffix.lower() in exts: audio_files.append(full_path) if pool is not None: futures.append(pool.submit(_get_audio_meta, str(audio_files[-1]), minimal)) if progress: print(format(len(audio_files), " 8d"), end='\r', file=sys.stderr) if progress: print("Getting audio metadata...") meta: tp.List[AudioMeta] = [] for idx, file_path in enumerate(audio_files): try: if pool is None: m = _get_audio_meta(str(file_path), minimal) else: m = futures[idx].result() if resolve: m = _resolve_audio_meta(m) except Exception as err: print("Error with", str(file_path), err, file=sys.stderr) continue meta.append(m) if progress: print(format((1 + idx) / len(audio_files), " 3.1%"), end='\r', file=sys.stderr) meta.sort() return meta def load_audio_meta(path: tp.Union[str, Path], resolve: bool = True, fast: bool = True) -> tp.List[AudioMeta]: """Load list of AudioMeta from an optionally compressed json file. Args: path (str or Path): Path to JSON file. resolve (bool): Whether to resolve the path from AudioMeta (default=True). fast (bool): activates some tricks to make things faster. Returns: list of AudioMeta: List of audio file path and its total duration. """ open_fn = gzip.open if str(path).lower().endswith('.gz') else open with open_fn(path, 'rb') as fp: # type: ignore lines = fp.readlines() meta = [] for line in lines: d = json.loads(line) m = AudioMeta.from_dict(d) if resolve: m = _resolve_audio_meta(m, fast=fast) meta.append(m) return meta def save_audio_meta(path: tp.Union[str, Path], meta: tp.List[AudioMeta]): """Save the audio metadata to the file pointer as json. Args: path (str or Path): Path to JSON file. 
metadata (list of BaseAudioMeta): List of audio meta to save. """ Path(path).parent.mkdir(exist_ok=True, parents=True) open_fn = gzip.open if str(path).lower().endswith('.gz') else open with open_fn(path, 'wb') as fp: # type: ignore for m in meta: json_str = json.dumps(m.to_dict()) + '\n' json_bytes = json_str.encode('utf-8') fp.write(json_bytes) class AudioDataset: """Base audio dataset. The dataset takes a list of AudioMeta and create a dataset composed of segments of audio and potentially additional information, by creating random segments from the list of audio files referenced in the metadata and applying minimal data pre-processing such as resampling, mixing of channels, padding, etc. If no segment_duration value is provided, the AudioDataset will return the full wav for each audio file. Otherwise, it will randomly sample audio files and create a segment of the specified duration, applying padding if required. By default, only the torch Tensor corresponding to the waveform is returned. Setting return_info=True allows to return a tuple containing the torch Tensor and additional metadata on the segment and the original audio meta. Note that you can call `start_epoch(epoch)` in order to get a deterministic "randomization" for `shuffle=True`. For a given epoch and dataset index, this will always return the same extract. You can get back some diversity by setting the `shuffle_seed` param. Args: meta (list of AudioMeta): List of audio files metadata. segment_duration (float, optional): Optional segment duration of audio to load. If not specified, the dataset will load the full audio segment from the file. shuffle (bool): Set to `True` to have the data reshuffled at every epoch. sample_rate (int): Target sample rate of the loaded audio samples. channels (int): Target number of channels of the loaded audio samples. sample_on_duration (bool): Set to `True` to sample segments with probability dependent on audio file duration. 
This is only used if `segment_duration` is provided. sample_on_weight (bool): Set to `True` to sample segments using the `weight` entry of `AudioMeta`. If `sample_on_duration` is also True, the actual weight will be the product of the file duration and file weight. This is only used if `segment_duration` is provided. min_segment_ratio (float): Minimum segment ratio to use when the audio file is shorter than the desired segment. max_read_retry (int): Maximum number of retries to sample an audio segment from the dataset. return_info (bool): Whether to return the wav only or return wav along with segment info and metadata. min_audio_duration (float, optional): Minimum audio file duration, in seconds, if provided audio shorter than this will be filtered out. max_audio_duration (float, optional): Maximal audio file duration in seconds, if provided audio longer than this will be filtered out. shuffle_seed (int): can be used to further randomize load_wav (bool): if False, skip loading the wav but returns a tensor of 0 with the expected segment_duration (which must be provided if load_wav is False). permutation_on_files (bool): only if `sample_on_weight` and `sample_on_duration` are False. Will ensure a permutation on files when going through the dataset. In that case the epoch number must be provided in order for the model to continue the permutation across epochs. In that case, it is assumed that `num_samples = total_batch_size * num_updates_per_epoch`, with `total_batch_size` the overall batch size accounting for all gpus. 
""" def __init__(self, meta: tp.List[AudioMeta], segment_duration: tp.Optional[float] = None, shuffle: bool = True, num_samples: int = 10_000, sample_rate: int = 48_000, channels: int = 2, pad: bool = True, sample_on_duration: bool = True, sample_on_weight: bool = True, min_segment_ratio: float = 0.5, max_read_retry: int = 10, return_info: bool = False, min_audio_duration: tp.Optional[float] = None, max_audio_duration: tp.Optional[float] = None, shuffle_seed: int = 0, load_wav: bool = True, permutation_on_files: bool = False, ): assert len(meta) > 0, "No audio meta provided to AudioDataset. Please check loading of audio meta." assert segment_duration is None or segment_duration > 0 assert segment_duration is None or min_segment_ratio >= 0 self.segment_duration = segment_duration self.min_segment_ratio = min_segment_ratio self.max_audio_duration = max_audio_duration self.min_audio_duration = min_audio_duration if self.min_audio_duration is not None and self.max_audio_duration is not None: assert self.min_audio_duration <= self.max_audio_duration self.meta: tp.List[AudioMeta] = self._filter_duration(meta) assert len(self.meta) # Fail fast if all data has been filtered. 
self.total_duration = sum(d.duration for d in self.meta) if segment_duration is None: num_samples = len(self.meta) self.num_samples = num_samples self.shuffle = shuffle self.sample_rate = sample_rate self.channels = channels self.pad = pad self.sample_on_weight = sample_on_weight self.sample_on_duration = sample_on_duration self.sampling_probabilities = self._get_sampling_probabilities() self.max_read_retry = max_read_retry self.return_info = return_info self.shuffle_seed = shuffle_seed self.current_epoch: tp.Optional[int] = None self.load_wav = load_wav if not load_wav: assert segment_duration is not None self.permutation_on_files = permutation_on_files if permutation_on_files: assert not self.sample_on_duration assert not self.sample_on_weight assert self.shuffle def start_epoch(self, epoch: int): self.current_epoch = epoch def __len__(self): return self.num_samples def _get_sampling_probabilities(self, normalized: bool = True): """Return the sampling probabilities for each file inside `self.meta`.""" scores: tp.List[float] = [] for file_meta in self.meta: score = 1. if self.sample_on_weight and file_meta.weight is not None: score *= file_meta.weight if self.sample_on_duration: score *= file_meta.duration scores.append(score) probabilities = torch.tensor(scores) if normalized: probabilities /= probabilities.sum() return probabilities @staticmethod @lru_cache(16) def _get_file_permutation(num_files: int, permutation_index: int, base_seed: int): # Used to keep the most recent files permutation in memory implicitely. # will work unless someone is using a lot of Datasets in parallel. rng = torch.Generator() rng.manual_seed(base_seed + permutation_index) return torch.randperm(num_files, generator=rng) def sample_file(self, index: int, rng: torch.Generator) -> AudioMeta: """Sample a given file from `self.meta`. Can be overridden in subclasses. This is only called if `segment_duration` is not None. 
def __getitem__(self, index: int) -> tp.Union[torch.Tensor, tp.Tuple[torch.Tensor, SegmentInfo]]:
    """Load the audio for dataset entry ``index``.

    Without ``segment_duration`` the full file ``self.meta[index]`` is returned.
    Otherwise a file and a seek time are sampled with a generator seeded from ``index``
    (plus the epoch and ``shuffle_seed`` when shuffling), and a segment of
    ``segment_duration`` seconds is read, retrying up to ``max_read_retry`` times
    on read errors.

    All reads go through ``self._audio_read`` so that ``load_wav=False`` and
    subclass overrides of that method are honored.

    Returns:
        The waveform converted to ``self.sample_rate`` and ``self.channels``, or a
        ``(waveform, SegmentInfo)`` tuple when ``self.return_info`` is set.
    Raises:
        Exception: Whatever the last read attempt raised once all retries failed.
    """
    if self.segment_duration is None:
        file_meta = self.meta[index]
        out, sr = self._audio_read(file_meta.path)
        out = convert_audio(out, sr, self.sample_rate, self.channels)
        n_frames = out.shape[-1]
        segment_info = SegmentInfo(file_meta, seek_time=0., n_frames=n_frames, total_frames=n_frames,
                                   sample_rate=self.sample_rate, channels=out.shape[0])
    else:
        rng = torch.Generator()
        if self.shuffle:
            # We use index, plus extra randomness, either totally random if we don't know the epoch.
            # otherwise we make use of the epoch number and optional shuffle_seed.
            if self.current_epoch is None:
                rng.manual_seed(index + self.num_samples * random.randint(0, 2**24))
            else:
                rng.manual_seed(index + self.num_samples * (self.current_epoch + self.shuffle_seed))
        else:
            # We only use index
            rng.manual_seed(index)

        for retry in range(self.max_read_retry):
            file_meta = self.sample_file(index, rng)
            # We add some variance in the file position even if audio file is smaller than segment
            # without ending up with empty segments
            max_seek = max(0, file_meta.duration - self.segment_duration * self.min_segment_ratio)
            seek_time = torch.rand(1, generator=rng).item() * max_seek
            try:
                out, sr = self._audio_read(file_meta.path, seek_time, self.segment_duration)
                out = convert_audio(out, sr, self.sample_rate, self.channels)
                n_frames = out.shape[-1]
                target_frames = int(self.segment_duration * self.sample_rate)
                if self.pad:
                    out = F.pad(out, (0, target_frames - n_frames))
                segment_info = SegmentInfo(file_meta, seek_time, n_frames=n_frames, total_frames=target_frames,
                                           sample_rate=self.sample_rate, channels=out.shape[0])
            except Exception as exc:
                logger.warning("Error opening file %s: %r", file_meta.path, exc)
                # Only give up (re-raising the last error) once every retry is used.
                if retry == self.max_read_retry - 1:
                    raise
            else:
                break

    if self.return_info:
        # Returns the wav and additional information on the wave segment
        return out, segment_info
    else:
        return out
to_pad = self.segment_duration is None and self.pad if to_pad: max_len = max([wav.shape[-1] for wav, _ in samples]) def _pad_wav(wav): return F.pad(wav, (0, max_len - wav.shape[-1])) if self.return_info: if len(samples) > 0: assert len(samples[0]) == 2 assert isinstance(samples[0][0], torch.Tensor) assert isinstance(samples[0][1], SegmentInfo) wavs = [wav for wav, _ in samples] segment_infos = [copy.deepcopy(info) for _, info in samples] if to_pad: # Each wav could be of a different duration as they are not segmented. for i in range(len(samples)): # Determines the total length of the signal with padding, so we update here as we pad. segment_infos[i].total_frames = max_len wavs[i] = _pad_wav(wavs[i]) wav = torch.stack(wavs) return wav, segment_infos else: assert isinstance(samples[0], torch.Tensor) if to_pad: samples = [_pad_wav(s) for s in samples] return torch.stack(samples) def _filter_duration(self, meta: tp.List[AudioMeta]) -> tp.List[AudioMeta]: """Filters out audio files with audio durations that will not allow to sample examples from them.""" orig_len = len(meta) # Filter data that is too short. if self.min_audio_duration is not None: meta = [m for m in meta if m.duration >= self.min_audio_duration] # Filter data that is too long. if self.max_audio_duration is not None: meta = [m for m in meta if m.duration <= self.max_audio_duration] filtered_len = len(meta) removed_percentage = 100*(1-float(filtered_len)/orig_len) msg = 'Removed %.2f percent of the data because it was too short or too long.' % removed_percentage if removed_percentage < 10: logging.debug(msg) else: logging.warning(msg) return meta @classmethod def from_meta(cls, root: tp.Union[str, Path], **kwargs): """Instantiate AudioDataset from a path to a directory containing a manifest as a jsonl file. Args: root (str or Path): Path to root folder containing audio files. kwargs: Additional keyword arguments for the AudioDataset. 
@classmethod
def from_path(cls, root: tp.Union[str, Path], minimal_meta: bool = True,
              exts: tp.List[str] = DEFAULT_EXTS, **kwargs):
    """Build an AudioDataset from a metadata file or a folder of audio files.

    Args:
        root (str or Path): Either a file, which is read with `load_audio_meta`,
            or a folder, which is scanned with `find_audio_files`.
        minimal_meta (bool): Forwarded as `minimal` when scanning a folder.
        exts (list of str): Extensions for audio files, used when scanning a folder.
        kwargs: Additional keyword arguments for the AudioDataset.
    """
    source = Path(root)
    if not source.is_file():
        # Anything that is not an existing file is scanned for audio files.
        found = find_audio_files(source, exts, minimal=minimal_meta, resolve=True)
        return cls(found, **kwargs)
    return cls(load_audio_meta(source, resolve=True), **kwargs)
def convert_audio_channels(wav: torch.Tensor, channels: int = 2) -> torch.Tensor:
    """Convert audio to the given number of channels.

    Args:
        wav (torch.Tensor): Audio wave of shape [B, C, T].
        channels (int): Expected number of channels as output.
    Returns:
        torch.Tensor: Downmixed, upmixed, truncated or unchanged audio wave [B, C, T].
    Raises:
        ValueError: If the input has several channels but fewer than requested.
    """
    *leading_dims, n_src, n_samples = wav.shape

    # Already in the requested layout.
    if n_src == channels:
        return wav
    # Mono requested from a multichannel stream: average all channels.
    if channels == 1:
        return wav.mean(dim=-2, keepdim=True)
    # Several channels requested from a mono stream: replicate the single channel.
    if n_src == 1:
        return wav.expand(*leading_dims, channels, n_samples)
    # What is a reasonable choice when the source is neither mono nor wide enough?
    if n_src < channels:
        raise ValueError('The audio file has less channels than requested but is not mono.')
    # More channels than requested: keep the first ones.
    return wav[..., :channels, :]
""" energy = wav.pow(2).mean().sqrt().item() if energy < energy_floor: return wav transform = torchaudio.transforms.Loudness(sample_rate) input_loudness_db = transform(wav).item() # calculate the gain needed to scale to the desired loudness level delta_loudness = -loudness_headroom_db - input_loudness_db gain = 10.0 ** (delta_loudness / 20.0) output = gain * wav if loudness_compressor: output = torch.tanh(output) assert output.isfinite().all(), (input_loudness_db, wav.pow(2).mean().sqrt()) return output def _clip_wav(wav: torch.Tensor, log_clipping: bool = False, stem_name: tp.Optional[str] = None) -> None: """Utility function to clip the audio with logging if specified.""" max_scale = wav.abs().max() if log_clipping and max_scale > 1: clamp_prob = (wav.abs() > 1).float().mean().item() print(f"CLIPPING {stem_name or ''} happening with proba (a bit of clipping is okay):", clamp_prob, "maximum scale: ", max_scale.item(), file=sys.stderr) #wav.clamp_(-1, 1) wav = wav.clone().clamp_(-1, 1) def normalize_audio(wav: torch.Tensor, normalize: bool = True, strategy: str = 'peak', peak_clip_headroom_db: float = 1, rms_headroom_db: float = 18, loudness_headroom_db: float = 14, loudness_compressor: bool = False, log_clipping: bool = False, sample_rate: tp.Optional[int] = None, stem_name: tp.Optional[str] = None) -> torch.Tensor: """Normalize the audio according to the prescribed strategy (see after). Args: wav (torch.Tensor): Audio data. normalize (bool): if `True` (default), normalizes according to the prescribed strategy (see after). If `False`, the strategy is only used in case clipping would happen. strategy (str): Can be either 'clip', 'peak', or 'rms'. Default is 'peak', i.e. audio is normalized by its largest value. RMS normalizes by root-mean-square with extra headroom to avoid clipping. 'clip' just clips. peak_clip_headroom_db (float): Headroom in dB when doing 'peak' or 'clip' strategy. rms_headroom_db (float): Headroom in dB when doing 'rms' strategy. 
def f32_pcm(wav: torch.Tensor) -> torch.Tensor:
    """Convert audio to float 32 bits PCM format.

    Floating point input is returned as is; int16 and int32 input is cast to float
    and divided by 2**15 and 2**31 respectively.

    Raises:
        ValueError: If the dtype is neither floating point, int16 nor int32.
    """
    dtype = wav.dtype
    if dtype.is_floating_point:
        return wav
    # Divisor applied to each supported integer dtype.
    full_scale = {torch.int16: 2**15, torch.int32: 2**31}.get(dtype)
    if full_scale is None:
        raise ValueError(f"Unsupported wav dtype: {wav.dtype}")
    return wav.float() / full_scale
def _clusterify_meta(meta: AudioMeta) -> AudioMeta:
    """Monkey-patch meta to match cluster specificities.

    The audio path, and the zip path of `info_path` when there is one, are passed
    through `AudioCraftEnvironment.apply_dataset_mappers`. The given meta is
    modified in place and returned.
    """
    remap = AudioCraftEnvironment.apply_dataset_mappers
    meta.path = remap(meta.path)
    info_path = meta.info_path
    if info_path is None:
        return meta
    info_path.zip_path = remap(info_path.zip_path)
    return meta
def get_keyword(value: tp.Optional[str]) -> tp.Optional[str]:
    """Preprocess a single keyword.

    Args:
        value (str, optional): Raw keyword.
    Returns:
        str or None: The keyword stripped and lowercased, or None when the value is
            not a string, is empty or blank, or is the literal 'None'.
    """
    if not isinstance(value, str):
        return None
    # Strip before testing for emptiness so that blank strings and a padded 'None'
    # are discarded instead of leaking through as '' or 'none'.
    stripped = value.strip()
    if len(stripped) == 0 or stripped == 'None':
        return None
    return stripped.lower()
@dataclass
class MusicInfo(AudioInfo):
    """Segment info augmented with music metadata.
    """
    # music-specific metadata
    title: tp.Optional[str] = None
    artist: tp.Optional[str] = None  # anonymized artist id, used to ensure no overlap between splits
    key: tp.Optional[str] = None
    bpm: tp.Optional[float] = None
    genre: tp.Optional[str] = None
    moods: tp.Optional[list] = None
    keywords: tp.Optional[list] = None
    description: tp.Optional[str] = None
    name: tp.Optional[str] = None
    instrument: tp.Optional[str] = None
    # original wav accompanying the metadata
    self_wav: tp.Optional[WavCondition] = None
    # dict mapping attributes names to tuple of wav, text and metadata
    joint_embed: tp.Dict[str, JointEmbedCondition] = field(default_factory=dict)

    @property
    def has_music_meta(self) -> bool:
        """Whether a `name` is set on this item."""
        return self.name is not None

    def to_condition_attributes(self) -> ConditioningAttributes:
        """Spread every dataclass field over the wav, joint_embed and text conditions.

        `self_wav` goes to the wav conditions, the entries of `joint_embed` to the
        joint embedding conditions, and every other field to the text conditions,
        list values being joined with a space.
        """
        attributes = ConditioningAttributes()
        for entry in fields(self):
            name = entry.name
            content = getattr(self, name)
            if name == 'self_wav':
                attributes.wav[name] = content
                continue
            if name == 'joint_embed':
                for embed_name, embed_cond in content.items():
                    attributes.joint_embed[embed_name] = embed_cond
                continue
            attributes.text[name] = ' '.join(content) if isinstance(content, list) else content
        return attributes

    @staticmethod
    def attribute_getter(attribute):
        """Return the preprocessing function of an attribute, or None if it has none."""
        if attribute == 'bpm':
            return get_bpm
        if attribute == 'key':
            return get_musical_key
        if attribute in ('moods', 'keywords'):
            return get_keyword_list
        if attribute in ('genre', 'name', 'instrument'):
            return get_keyword
        if attribute in ('title', 'artist', 'description'):
            return get_string
        return None

    @classmethod
    def from_dict(cls, dictionary: dict, fields_required: bool = False):
        """Build a MusicInfo from a dictionary, preprocessing the values found in it.

        Args:
            dictionary (dict): Source values, keyed by field name.
            fields_required (bool): If True, a field missing from the dictionary
                raises a KeyError, unless it is 'keywords'.
        """
        # allow a subset of attributes to not be loaded from the dictionary
        # these attributes may be populated later
        deferred = ('self_wav', 'joint_embed')
        optional = ('keywords',)
        loaded: tp.Dict[str, tp.Any] = {}
        for entry in fields(cls):
            name = entry.name
            if name in deferred:
                continue
            if name not in dictionary:
                if fields_required and name not in optional:
                    raise KeyError(f"Unexpected missing key: {name}")
                continue
            preprocess: tp.Optional[tp.Callable] = cls.attribute_getter(name)
            raw = dictionary[name]
            loaded[name] = preprocess(raw) if preprocess else raw
        return cls(**loaded)
drop_desc_p (float): Probability of dropping the original description on text merge. if provided value is 0, then no drop out is performed. drop_other_p (float): Probability of dropping the other fields used for text augmentation. Returns: MusicInfo: The MusicInfo with augmented textual description. """ def is_valid_field(field_name: str, field_value: tp.Any) -> bool: valid_field_name = field_name in ['key', 'bpm', 'genre', 'moods', 'instrument', 'keywords'] valid_field_value = field_value is not None and isinstance(field_value, (int, float, str, list)) keep_field = random.uniform(0, 1) < drop_other_p return valid_field_name and valid_field_value and keep_field def process_value(v: tp.Any) -> str: if isinstance(v, (int, float, str)): return str(v) if isinstance(v, list): return ", ".join(v) else: raise ValueError(f"Unknown type for text value! ({type(v), v})") description = music_info.description metadata_text = "" if random.uniform(0, 1) < merge_text_p: meta_pairs = [f'{_field.name}: {process_value(getattr(music_info, _field.name))}' for _field in fields(music_info) if is_valid_field(_field.name, getattr(music_info, _field.name))] random.shuffle(meta_pairs) metadata_text = ". ".join(meta_pairs) description = description if not random.uniform(0, 1) < drop_desc_p else None logger.debug(f"Applying text augmentation on MMI info. description: {description}, metadata: {metadata_text}") if description is None: description = metadata_text if len(metadata_text) > 1 else None else: description = ". 
class Paraphraser:
    """Randomly swap a description for one of its paraphrases.

    Args:
        paraphrase_source (str or Path): Path to the JSON file holding the paraphrases,
            read through gzip when its name ends with '.gz'. Entries are looked up by
            the audio path with its suffix replaced by '.json'.
        paraphrase_p (float): Probability of taking a paraphrase.
    """
    def __init__(self, paraphrase_source: tp.Union[str, Path], paraphrase_p: float = 0.):
        self.paraphrase_p = paraphrase_p
        open_fn = gzip.open if str(paraphrase_source).lower().endswith('.gz') else open
        with open_fn(paraphrase_source, 'rb') as f:  # type: ignore
            self.paraphrase_source = json.loads(f.read())
        logger.info(f"loaded paraphrasing source from: {paraphrase_source}")

    def sample_paraphrase(self, audio_path: str, description: str):
        """Return a random paraphrase for the given audio, or the description itself.

        The description is returned unchanged when the random draw is not below
        `paraphrase_p`, or when the source has no entry for this audio file.
        """
        if random.random() >= self.paraphrase_p:
            return description
        # Keys loaded from JSON are strings: looking up a Path object never matches.
        # NOTE(review): assumes keys use the platform's native path string - confirm.
        info_path = str(Path(audio_path).with_suffix('.json'))
        if info_path not in self.paraphrase_source:
            warn_once(logger, f"{info_path} not in paraphrase source!")
            return description
        new_desc = random.choice(self.paraphrase_source[info_path])
        logger.debug(f"{description} -> {new_desc}")
        return new_desc

    def sample(self, audio_path: str, description: str):
        """Alias of `sample_paraphrase`, the name used by MusicDataset.__getitem__."""
        return self.sample_paraphrase(audio_path, description)
""" def __init__(self, *args, info_fields_required: bool = True, merge_text_p: float = 0., drop_desc_p: float = 0., drop_other_p: float = 0., joint_embed_attributes: tp.List[str] = [], paraphrase_source: tp.Optional[str] = None, paraphrase_p: float = 0, **kwargs): kwargs['return_info'] = True # We require the info for each song of the dataset. super().__init__(*args, **kwargs) self.info_fields_required = info_fields_required self.merge_text_p = merge_text_p self.drop_desc_p = drop_desc_p self.drop_other_p = drop_other_p self.joint_embed_attributes = joint_embed_attributes self.paraphraser = None if paraphrase_source is not None: self.paraphraser = Paraphraser(paraphrase_source, paraphrase_p) def __getitem__(self, index): wav, info = super().__getitem__(index) info_data = info.to_dict() music_info_path = Path(info.meta.path).with_suffix('.json') if Path(music_info_path).exists(): with open(music_info_path, 'r') as json_file: music_data = json.load(json_file) music_data.update(info_data) music_info = MusicInfo.from_dict(music_data, fields_required=self.info_fields_required) if self.paraphraser is not None: music_info.description = self.paraphraser.sample(music_info.meta.path, music_info.description) if self.merge_text_p: music_info = augment_music_info_description( music_info, self.merge_text_p, self.drop_desc_p, self.drop_other_p) else: music_info = MusicInfo.from_dict(info_data, fields_required=False) music_info.self_wav = WavCondition( wav=wav[None], length=torch.tensor([info.n_frames]), sample_rate=[info.sample_rate], path=[info.meta.path], seek_time=[info.seek_time]) for att in self.joint_embed_attributes: att_value = getattr(music_info, att) joint_embed_cond = JointEmbedCondition( wav[None], [att_value], torch.tensor([info.n_frames]), sample_rate=[info.sample_rate], path=[info.meta.path], seek_time=[info.seek_time]) music_info.joint_embed[att] = joint_embed_cond return wav, music_info def get_musical_key(value: tp.Optional[str]) -> tp.Optional[str]: 
def get_bpm(value: tp.Optional[str]) -> tp.Optional[float]:
    """Preprocess to a float.

    Args:
        value (str, optional): Raw tempo value.
    Returns:
        float or None: The value converted with `float`, or None when it is None or
            cannot be converted.
    """
    if value is None:
        return None
    try:
        return float(value)
    except (TypeError, ValueError, OverflowError):
        # float() raises TypeError for values that are neither strings nor numbers
        # (e.g. a list) and OverflowError for oversized integers; such values are
        # unusable as a tempo, like unparsable strings.
        return None
class SoundDataset(InfoAudioDataset):
    """Sound audio dataset: Audio dataset with environmental sound-specific metadata.

    Args:
        info_fields_required (bool): Whether all the mandatory metadata fields should be in the loaded metadata.
        external_metadata_source (tp.Optional[str]): Folder containing JSON metadata for the corresponding dataset.
            The metadata files contained in this folder are expected to match the stem of the audio file with
            a json extension.
        aug_p (float): Probability of performing audio mixing augmentation on the batch.
        mix_p (float): Proportion of batch items that are mixed together when applying audio mixing augmentation.
        mix_snr_low (int): Lowerbound for SNR value sampled for mixing augmentation.
        mix_snr_high (int): Upperbound for SNR value sampled for mixing augmentation.
        mix_min_overlap (float): Minimum overlap between audio files when performing mixing augmentation.
        kwargs: Additional arguments for AudioDataset.

    See `audiocraft.data.info_audio_dataset.InfoAudioDataset` for full initialization arguments.
    """
    def __init__(
        self,
        *args,
        info_fields_required: bool = True,
        external_metadata_source: tp.Optional[str] = None,
        aug_p: float = 0.,
        mix_p: float = 0.,
        mix_snr_low: int = -5,
        mix_snr_high: int = 5,
        mix_min_overlap: float = 0.5,
        **kwargs
    ):
        # The info of each item is always needed, whatever the caller asked for.
        kwargs['return_info'] = True
        super().__init__(*args, **kwargs)
        self.info_fields_required = info_fields_required
        self.external_metadata_source = external_metadata_source
        self.aug_p = aug_p
        self.mix_p = mix_p
        mixing_enabled = self.aug_p > 0
        if mixing_enabled:
            assert self.mix_p > 0, "Expecting some mixing proportion mix_p if aug_p > 0"
            assert self.channels == 1, "SoundDataset with audio mixing considers only monophonic audio"
        self.mix_snr_low = mix_snr_low
        self.mix_snr_high = mix_snr_high
        self.mix_min_overlap = mix_min_overlap

    def _get_info_path(self, path: tp.Union[str, Path]) -> Path:
        """Locate the JSON file holding the metadata (description, etc.) of an audio file.

        The JSON sitting next to the audio file (same name, '.json' suffix) is used when
        it exists. Otherwise a file with that name is looked for in the external metadata
        folder, if one was configured.

        Raises:
            Exception: If no metadata JSON is found in either place.
        """
        sibling_json = Path(path).with_suffix('.json')
        if sibling_json.exists():
            return sibling_json
        if self.external_metadata_source:
            external_json = Path(self.external_metadata_source) / sibling_json.name
            if external_json.exists():
                return external_json
        raise Exception(f"Unable to find a metadata JSON for path: {path}")

    def __getitem__(self, index):
        """Return the waveform at `index` with a SoundInfo built from its metadata."""
        wav, info = super().__getitem__(index)
        segment_fields = info.to_dict()
        json_path = self._get_info_path(info.meta.path)
        if not Path(json_path).exists():
            sound_info = SoundInfo.from_dict(segment_fields, fields_required=False)
        else:
            with open(json_path, 'r') as json_file:
                sound_data = json.load(json_file)
                # Values of the segment info take precedence over those of the JSON.
                sound_data.update(segment_fields)
                sound_info = SoundInfo.from_dict(sound_data, fields_required=self.info_fields_required)
                # if there are multiple descriptions, sample one randomly
                if isinstance(sound_info.description, list):
                    sound_info.description = random.choice(sound_info.description)

        sound_info.self_wav = WavCondition(
            wav=wav[None], length=torch.tensor([info.n_frames]),
            sample_rate=[sound_info.sample_rate], path=[info.meta.path], seek_time=[info.seek_time])

        return wav, sound_info

    def collater(self, samples):
        """Collate the samples, then mix them within the batch when `aug_p` is positive."""
        # when training, audio mixing is performed in the collate function
        wav, sound_info = super().collater(samples)  # SoundDataset always returns infos
        if not self.aug_p > 0:
            return wav, sound_info
        wav, sound_info = mix_samples(
            wav, sound_info, self.aug_p, self.mix_p,
            snr_low=self.mix_snr_low, snr_high=self.mix_snr_high,
            min_overlap=self.mix_min_overlap)
        return wav, sound_info
def snr_mixer(clean: torch.Tensor, noise: torch.Tensor, snr: int, min_overlap: float,
              target_level: int = -25, clipping_threshold: float = 0.99) -> torch.Tensor:
    """Function to mix clean speech and noise at various SNR levels.

    Args:
        clean (torch.Tensor): Clean audio source to mix, of shape [B, T].
        noise (torch.Tensor): Noise audio source to mix, of shape [B, T].
        snr (int): SNR level when mixing.
        min_overlap (float): Minimum overlap between the two mixed sources.
        target_level (int): Gain level in dB.
        clipping_threshold (float): Threshold for clipping the audio.
    Returns:
        torch.Tensor: The mixed audio, of shape [B, T].
    """
    # Bring the noise to the length of the clean signal, padding or truncating it.
    n_frames = clean.shape[1]
    if n_frames > noise.shape[1]:
        noise = torch.nn.functional.pad(noise, (0, n_frames - noise.shape[1]))
    else:
        noise = noise[:, :n_frames]

    # Peak-normalize, then normalize to target_level. The peak is the largest absolute
    # value: taking abs() after max() would ignore a dominant negative peak.
    clean = clean / (clean.abs().max(1)[0].unsqueeze(1) + EPS)
    clean = normalize(clean, target_level)
    rmsclean = rms_f(clean)

    noise = noise / (noise.abs().max(1)[0].unsqueeze(1) + EPS)
    noise = normalize(noise, target_level)
    rmsnoise = rms_f(noise)

    # set the noise level for a given SNR
    noisescalar = (rmsclean / (10 ** (snr / 20)) / (rmsnoise + EPS)).unsqueeze(1)
    noisenewlevel = noise * noisescalar

    # mix noise and clean speech
    noisyspeech = mix_pair(clean, noisenewlevel, min_overlap)

    # Randomly select an RMS level in [TARGET_LEVEL_LOWER, TARGET_LEVEL_UPPER) and normalize
    # noisyspeech with that value. Clipping may still happen with a low probability.
    noisy_rms_level = np.random.randint(TARGET_LEVEL_LOWER, TARGET_LEVEL_UPPER)
    rmsnoisy = rms_f(noisyspeech)
    scalarnoisy = (10 ** (noisy_rms_level / 20) / (rmsnoisy + EPS)).unsqueeze(1)
    noisyspeech = noisyspeech * scalarnoisy

    # Final check for amplitudes beyond the clipping threshold: rescale the offending items
    # by their absolute peak so that they land just under the threshold.
    clipped = is_clipped(noisyspeech, clipping_threshold)
    if clipped.any():
        peak = noisyspeech[clipped].abs().max(1)[0].unsqueeze(1)
        noisyspeech_maxamplevel = peak / (clipping_threshold - EPS)
        noisyspeech[clipped] = noisyspeech[clipped] / noisyspeech_maxamplevel

    return noisyspeech
""" # no mixing to perform within the batch if mix_p == 0: return wavs, infos if random.uniform(0, 1) < aug_p: # perform all augmentations on waveforms as [B, T] # randomly picking pairs of audio to mix assert wavs.size(1) == 1, f"Mix samples requires monophonic audio but C={wavs.size(1)}" wavs = wavs.mean(dim=1, keepdim=False) B, T = wavs.shape k = int(mix_p * B) mixed_sources_idx = torch.randperm(B)[:k] mixed_targets_idx = torch.randperm(B)[:k] aug_wavs = snr_mix( wavs[mixed_sources_idx], wavs[mixed_targets_idx], snr_low, snr_high, min_overlap, ) # mixing textual descriptions in metadata descriptions = [info.description for info in infos] aug_infos = [] for i, j in zip(mixed_sources_idx, mixed_targets_idx): text = mix_text(descriptions[i], descriptions[j]) m = replace(infos[i]) m.description = text aug_infos.append(m) # back to [B, C, T] aug_wavs = aug_wavs.unsqueeze(1) assert aug_wavs.shape[0] > 0, "Samples mixing returned empty batch." assert aug_wavs.dim() == 3, f"Returned wav should be [B, C, T] but dim = {aug_wavs.dim()}" assert aug_wavs.shape[0] == len(aug_infos), "Mismatch between number of wavs and infos in the batch" return aug_wavs, aug_infos # [B, C, T] else: # randomly pick samples in the batch to match # the batch size when performing audio mixing B, C, T = wavs.shape k = int(mix_p * B) wav_idx = torch.randperm(B)[:k] wavs = wavs[wav_idx] infos = [infos[i] for i in wav_idx] assert wavs.shape[0] == len(infos), "Mismatch between number of wavs and infos in the batch" return wavs, infos # [B, C, T] ================================================ FILE: audiocraft/data/zip.py ================================================ # Copyright (c) Meta Platforms, Inc. and affiliates. # All rights reserved. # # This source code is licensed under the license found in the # LICENSE file in the root directory of this source tree. """Utility for reading some info from inside a zip file. 
DEFAULT_SIZE = 32
MODE = Literal['r', 'w', 'x', 'a']


@dataclass(order=True)
class PathInZip:
    """Location of a file stored inside a zip archive.

    Args:
        path (str): A single string following the convention
            `<path_to_zip>:<path_inside_zip>`. For a zip file
            /some/location/foo.zip containing a json file at /data/file1.json,
            the expected value is "/some/location/foo.zip:/data/file1.json".
    """

    # Separator between the archive path and the path inside the archive.
    INFO_PATH_SEP = ':'
    zip_path: str
    file_path: str

    def __init__(self, path: str) -> None:
        pieces = path.split(self.INFO_PATH_SEP)
        # Exactly one separator is allowed: archive path on the left, member path on the right.
        assert len(pieces) == 2
        self.zip_path = pieces[0]
        self.file_path = pieces[1]

    @classmethod
    def from_paths(cls, zip_path: str, file_path: str):
        """Alternate constructor taking the archive path and the member path separately."""
        combined = cls.INFO_PATH_SEP.join((zip_path, file_path))
        return cls(combined)

    def __str__(self) -> str:
        return self.INFO_PATH_SEP.join((self.zip_path, self.file_path))


def _open_zip(path: str, mode: MODE = 'r'):
    """Open the zip archive located at `path` with the given `mode`."""
    return zipfile.ZipFile(path, mode)


# LRU-cached opener: an archive that was opened recently is not opened again.
_cached_open_zip = lru_cache(maxsize=DEFAULT_SIZE)(_open_zip)


def set_zip_cache_size(max_size: int):
    """Sets the maximal LRU caching for zip file opening.

    The cached opener is rebuilt from scratch, so previously cached entries are dropped.

    Args:
        max_size (int): the maximal LRU cache.
    """
    global _cached_open_zip
    _cached_open_zip = lru_cache(maxsize=max_size)(_open_zip)


def open_file_in_zip(path_in_zip: PathInZip, mode: str = 'r') -> typing.IO:
    """Opens a file stored inside a zip and returns a file-like object.

    Args:
        path_in_zip (PathInZip): A PathInZip object representing the file to return a file-like object of.
        mode (str): Accepted for interface compatibility; it is not forwarded, neither to the
            archive opener nor to `ZipFile.open`, so their default modes are used.
    Returns:
        A file-like object for PathInZip.
    """
    archive = _cached_open_zip(path_in_zip.zip_path)
    member = archive.open(path_in_zip.file_path)
    return member
class AudioCraftEnvironment:
    """Environment configuration for teams and clusters.

    AudioCraftEnvironment picks compute cluster settings (slurm, dora) from the current running environment
    or declared variable and the loaded team configuration. Additionally, the AudioCraftEnvironment
    provides pointers to a reference folder resolved automatically across clusters that is shared across team members,
    allowing to share sigs or other files to run jobs. Finally, it provides dataset mappers to automatically
    map dataset file paths to new locations across clusters, allowing to use the same manifest of files across clusters.

    The cluster type is identified automatically and the base configuration file is read from
    config/teams/<team>.yaml.
    Use the following environment variables to specify the cluster, team or configuration:

        AUDIOCRAFT_CLUSTER (optional): Cluster type to enforce. Useful if the cluster type
            cannot be inferred automatically.
        AUDIOCRAFT_CONFIG (optional): Path to yaml config holding the teams configuration.
            If not set, configuration is read from config/teams/<team>.yaml.
        AUDIOCRAFT_TEAM (optional): Name of the team. Recommended to set to your own team.
            Cluster configuration are shared across teams to match compute allocation,
            specify your cluster configuration in the configuration file under a key mapping
            your team name.
    """
    # Lazily created singleton, see `instance()` and `reset()`.
    _instance = None
    DEFAULT_TEAM = "default"

    def __init__(self) -> None:
        """Loads configuration."""
        self.team: str = os.getenv("AUDIOCRAFT_TEAM", self.DEFAULT_TEAM)
        cluster_type = _guess_cluster_type()
        # The env var takes precedence over the automatically guessed cluster type.
        cluster = os.getenv(
            "AUDIOCRAFT_CLUSTER", cluster_type.value
        )
        logger.info("Detecting cluster type %s", cluster_type)

        self.cluster: str = cluster

        config_path = os.getenv(
            "AUDIOCRAFT_CONFIG",
            Path(__file__)
            .parent.parent.joinpath("config/teams", self.team)
            .with_suffix(".yaml"),
        )
        self.config = omegaconf.OmegaConf.load(config_path)
        # List of (compiled regex, replacement) pairs applied by `apply_dataset_mappers`.
        self._dataset_mappers = []
        cluster_config = self._get_cluster_config()
        if "dataset_mappers" in cluster_config:
            for pattern, repl in cluster_config["dataset_mappers"].items():
                regex = re.compile(pattern)
                self._dataset_mappers.append((regex, repl))

    def _get_cluster_config(self) -> omegaconf.DictConfig:
        """Returns the section of the loaded configuration stored under the current cluster name."""
        assert isinstance(self.config, omegaconf.DictConfig)
        return self.config[self.cluster]

    @classmethod
    def instance(cls):
        """Returns the shared environment, creating it on first use."""
        if cls._instance is None:
            cls._instance = cls()
        return cls._instance

    @classmethod
    def reset(cls):
        """Clears the environment and forces a reload on next invocation."""
        cls._instance = None

    @classmethod
    def get_team(cls) -> str:
        """Gets the selected team as dictated by the AUDIOCRAFT_TEAM env var.
        If not defined, defaults to `DEFAULT_TEAM` ("default").
        """
        return cls.instance().team

    @classmethod
    def get_cluster(cls) -> str:
        """Gets the detected cluster.
        This value can be overridden by the AUDIOCRAFT_CLUSTER env var.
        """
        return cls.instance().cluster

    @classmethod
    def get_dora_dir(cls) -> Path:
        """Gets the path to the dora directory for the current team and cluster.
        Value is overridden by the AUDIOCRAFT_DORA_DIR env var.
        """
        cluster_config = cls.instance()._get_cluster_config()
        dora_dir = os.getenv("AUDIOCRAFT_DORA_DIR", cluster_config["dora_dir"])
        logger.warning(f"Dora directory: {dora_dir}")
        return Path(dora_dir)

    @classmethod
    def get_reference_dir(cls) -> Path:
        """Gets the path to the reference directory for the current team and cluster.
        Value is overridden by the AUDIOCRAFT_REFERENCE_DIR env var.
        """
        cluster_config = cls.instance()._get_cluster_config()
        return Path(os.getenv("AUDIOCRAFT_REFERENCE_DIR", cluster_config["reference_dir"]))

    @classmethod
    def get_slurm_exclude(cls) -> tp.Optional[str]:
        """Get the list of nodes to exclude for that cluster."""
        cluster_config = cls.instance()._get_cluster_config()
        return cluster_config.get("slurm_exclude")

    @classmethod
    def get_slurm_partitions(cls, partition_types: tp.Optional[tp.List[str]] = None) -> str:
        """Gets the requested partitions for the current team and cluster as a comma-separated string.

        Args:
            partition_types (list[str], optional): partition types to retrieve. Values must be
                from ['global', 'team']. If not provided, the global partition is returned.
        Returns:
            str: The partitions, in the order of `partition_types`, joined with commas.
        """
        if not partition_types:
            partition_types = ["global"]

        cluster_config = cls.instance()._get_cluster_config()
        partitions = [
            cluster_config["partitions"][partition_type]
            for partition_type in partition_types
        ]
        return ",".join(partitions)

    @classmethod
    def resolve_reference_path(cls, path: tp.Union[str, Path]) -> Path:
        """Converts reference placeholder in path with configured reference dir to resolve paths.

        Args:
            path (str or Path): Path to resolve.
        Returns:
            Path: Resolved path.
        Raises:
            AssertionError: If the path uses the placeholder and the reference directory is not
                an existing directory.
        """
        reference_prefix = "//reference"
        path = str(path)

        if path.startswith(reference_prefix):
            reference_dir = cls.get_reference_dir()
            logger.warning(f"Reference directory: {reference_dir}")
            assert (
                reference_dir.exists() and reference_dir.is_dir()
            ), f"Reference directory does not exist: {reference_dir}."
            # Plain string splicing rather than re.sub: the directory must not be interpreted as a
            # regex replacement template (backslashes in it would be treated as escapes).
            path = str(reference_dir) + path[len(reference_prefix):]

        return Path(path)

    @classmethod
    def apply_dataset_mappers(cls, path: str) -> str:
        """Applies dataset mapping regex rules as defined in the configuration.
        If no rules are defined, the path is returned as-is.
        """
        instance = cls.instance()

        # Rules are applied one after the other, each on the output of the previous one.
        for pattern, repl in instance._dataset_mappers:
            path = pattern.sub(repl, path)

        return path
def get_sheep_ping(sheep) -> tp.Optional[str]:
    """Return the amount of time since the Sheep made some update to its log.
    Returns a str using the relevant time unit, or None when the Sheep has no existing log file.
    """
    if sheep.log is None or not sheep.log.exists():
        return None

    # Time elapsed, in seconds, since the last modification of the log file.
    delta = time.time() - sheep.log.stat().st_mtime
    # Largest unit first: the first unit strictly exceeded by delta is used.
    for unit_seconds, suffix in ((3600 * 24, 'd'), (3600, 'h'), (60, 'm')):
        if delta > unit_seconds:
            return f'{delta / unit_seconds:.1f}{suffix}'
    return f'{delta:.1f}s'


class BaseExplorer(ABC, Explorer):
    """Base explorer for AudioCraft grids.

    All task specific solvers are expected to implement the `get_grid_metrics`
    method to specify logic about metrics to display for a given task.

    If additional stages are used, the child explorer must define how to handle
    these new stages in the `process_history` and `process_sheep` methods.
    """
    def stages(self):
        """Returns the names of the stages reported for each XP."""
        return ["train", "valid", "evaluate"]

    def get_grid_meta(self):
        """Returns the list of Meta information to display for each XP/job.
        """
        return [
            tt.leaf("index", align=">"),
            tt.leaf("name", wrap=140),
            tt.leaf("state"),
            tt.leaf("sig", align=">"),
            tt.leaf("sid", align="<"),
        ]

    @abstractmethod
    def get_grid_metrics(self):
        """Return the metrics that should be displayed in the tracking table.
        """
        ...

    def process_sheep(self, sheep, history):
        """Builds the per-stage metrics of one XP from its history.

        Args:
            sheep: The Sheep being displayed, only used to compute its ping.
            history: Sequence with one entry per epoch, each mapping a stage name to its metrics.
                It is left unmodified.
        Returns:
            dict: Mapping from stage name to the most recent value of each of its metrics.
                'train' also holds the number of epochs, durations are expressed in minutes,
                and every stage of `stages()` gets a 'ping' entry when a ping is available.
        """
        train = {
            "epoch": len(history),
        }
        parts = {"train": train}
        for metrics in history:
            for key, sub in metrics.items():
                part = parts.get(key, {})
                if 'duration' in sub:
                    # Convert to minutes for readability. Work on a copy: rewriting the entry of
                    # `history` in place would divide again each time the same history is processed.
                    sub = {**sub, 'duration': sub['duration'] / 60.}
                part.update(sub)
                parts[key] = part
        ping = get_sheep_ping(sheep)
        if ping is not None:
            for name in self.stages():
                if name not in parts:
                    parts[name] = {}
                # Add the ping to each part for convenience.
                parts[name]['ping'] = ping
        return parts
@LMExplorer
def explorer(launcher):
    """Grid scheduling a single XP: the AudioGen base solver at the medium model scale."""
    launcher.slurm_(
        gpus=64,
        partition=AudioCraftEnvironment.get_slurm_partitions(['team', 'global']),
    )
    launcher.bind_(solver='audiogen/audiogen_base_16khz')
    # replace this by the desired environmental sound dataset
    launcher.bind_(dset='internal/sounds_16khz')

    # Options shared by every XP of the grid: FSDP enabled, autocast disabled.
    launcher.bind_({'autocast': False, 'fsdp.use': True})

    # Schedule the XP with the medium model scale.
    launcher({'model/lm/model_scale': 'medium'})
def eval(launcher, batch_size: int = 32):
    """Schedules one evaluation-only XP computing the objective metrics.

    Args:
        launcher: Launcher the evaluation options are bound on top of.
        batch_size (int): Batch size used for the evaluate dataset.
    """
    # binary for FAD computation: replace this path with your own path
    fad_binary = '/data/home/jadecopet/local/usr/opt/google-research'

    eval_launcher = launcher.bind({
        'dset': 'audio/audiocaps_16khz',
        'solver/audiogen/evaluation': 'objective_eval',
        'execute_only': 'evaluate',
        '+dataset.evaluate.batch_size': batch_size,
        '+metrics.fad.tf.batch_size': 32,
    })
    eval_launcher.bind_({'metrics.fad.tf.bin': fad_binary})

    sampling = {'generate.lm.use_sampling': True, 'generate.lm.top_k': 250, 'generate.lm.top_p': 0.}
    two_step_cfg = {'transformer_lm.two_step_cfg': True}

    # base objective metrics
    eval_launcher(sampling, two_step_cfg)


@GenerationEvalExplorer
def explorer(launcher):
    """Evaluation grid for the pretrained AudioGen medium model.

    When the REGEN env var is set, the evaluation XP is (re)generated. Otherwise the XPs
    symlinked in the grid folder are scheduled again from their signature.
    """
    launcher.slurm_(
        gpus=4,
        partition=AudioCraftEnvironment.get_slurm_partitions(['team', 'global']),
    )

    if 'REGEN' in os.environ:
        base = launcher.bind(solver="audiogen/audiogen_base_16khz")
        base.bind_({'autocast': False, 'fsdp.use': True})

        medium = base.bind({'continue_from': '//pretrained/facebook/audiogen-medium'})
        medium.bind_({'model/lm/model_scale': 'medium'})
        eval(medium, batch_size=128)
        return

    # Grid folder named after this module, relative to the `grids` package.
    grid_name = __name__.split('.', 2)[-1]
    folder = train.main.dora.dir / 'grids' / grid_name
    with launcher.job_array():
        for sig in folder.iterdir():
            # only symlinks are turned back into XPs
            if sig.is_symlink():
                xp = train.main.get_xp_from_sig(sig.name)
                launcher(xp.argv)
class CompressionExplorer(BaseExplorer):
    """Explorer displaying the metrics of the compression grids."""

    # Metrics shown in the "evaluate" group of the tracking table.
    eval_metrics = ["sisnr", "visqol"]

    def stages(self):
        """Returns the names of the stages reported for each XP."""
        stage_names = ["train", "valid", "evaluate"]
        return stage_names

    def get_grid_meta(self):
        """Returns the list of Meta information to display for each XP/job.
        """
        index = tt.leaf("index", align=">")
        name = tt.leaf("name", wrap=140)
        state = tt.leaf("state")
        sig = tt.leaf("sig", align=">")
        return [index, name, state, sig]

    def get_grid_metrics(self):
        """Return the metrics that should be displayed in the tracking table.
        """
        train_group = tt.group(
            "train",
            [
                tt.leaf("epoch"),
                tt.leaf("bandwidth", ".2f"),
                tt.leaf("adv", ".4f"),
                tt.leaf("d_loss", ".4f"),
            ],
            align=">",
        )
        valid_group = tt.group(
            "valid",
            [
                tt.leaf("bandwidth", ".2f"),
                tt.leaf("adv", ".4f"),
                tt.leaf("msspec", ".4f"),
                tt.leaf("sisnr", ".2f"),
            ],
            align=">",
        )
        # One column per evaluation metric, all with the same format.
        evaluate_leaves = [tt.leaf(metric, ".3f") for metric in self.eval_metrics]
        evaluate_group = tt.group("evaluate", evaluate_leaves, align=">")
        return [train_group, valid_group, evaluate_group]
# --- audiocraft/grids/compression/debug.py ---
@CompressionExplorer
def explorer(launcher):
    """Debug grid: the compression/debug solver as-is, plus one XP overriding the RVQ settings."""
    launcher.slurm_(
        gpus=2,
        partition=AudioCraftEnvironment.get_slurm_partitions(['team', 'global']),
    )
    launcher.bind_(solver='compression/debug')

    rvq_override = {'rvq.bins': 2048, 'rvq.n_q': 4}
    with launcher.job_array():
        # base debug task using config from solver=compression/debug
        launcher()
        # we can override parameters in the grid to launch additional xps
        launcher(rvq_override)


# --- audiocraft/grids/compression/encodec_audiogen_16khz.py ---
@CompressionExplorer
def explorer(launcher):
    """Grid training the AudioGen EnCodec model at 16 kHz."""
    launcher.slurm_(
        gpus=8,
        partition=AudioCraftEnvironment.get_slurm_partitions(['team', 'global']),
    )
    # use configuration for AudioGen's EnCodec model trained on monophonic audio sampled at 16 kHz
    # AudioGen's EnCodec is trained with a total stride of 320 leading to a frame rate of 50 hz
    launcher.bind_(solver='compression/encodec_audiogen_16khz')
    # replace this by the desired sound dataset
    launcher.bind_(dset='internal/sounds_16khz')
    # launch xp
    launcher()
@CompressionExplorer
def explorer(launcher):
    """Grid training a base causal EnCodec model at 24 kHz."""
    launcher.slurm_(
        gpus=8,
        partition=AudioCraftEnvironment.get_slurm_partitions(['team', 'global']),
    )
    # base causal EnCodec trained on monophonic audio sampled at 24 kHz
    launcher.bind_(solver='compression/encodec_base_24khz')
    # replace this by the desired dataset
    launcher.bind_(dset='audio/example')
    # launch xp
    launcher()
@CompressionExplorer
def explorer(launcher):
    """Grid training the MusicGen EnCodec model at 32 kHz, with and without ViSQOL evaluation."""
    launcher.slurm_(
        gpus=8,
        partition=AudioCraftEnvironment.get_slurm_partitions(['team', 'global']),
    )
    # use configuration for MusicGen's EnCodec model trained on monophonic audio sampled at 32 kHz
    # MusicGen's EnCodec is trained with a total stride of 640 leading to a frame rate of 50 hz
    launcher.bind_(solver='compression/encodec_musicgen_32khz')
    # replace this by the desired music dataset
    launcher.bind_(dset='internal/music_400k_32khz')

    # Second XP: same training with the ViSQOL metric turned on at evaluation.
    with_visqol = {
        'metrics.visqol.bin': '/data/home/jadecopet/local/usr/opt/visqol',
        'label': 'visqol',
        'evaluate.metrics.visqol': True
    }

    # launch xp
    launcher()
    launcher(with_visqol)
# --- audiocraft/grids/diffusion/4_bands_base_32khz.py ---
@DiffusionExplorer
def explorer(launcher):
    """Grid training one diffusion model per frequency band (4 bands)."""
    launcher.slurm_(gpus=4, partition='learnfair')

    launcher.bind_({'solver': 'diffusion/default',
                    'dset': 'internal/music_10k_32khz'})

    # (processor.use, processor.power_std) for the bands 0 to 3.
    band_settings = [
        (False, 0.4),
        (False, 0.4),
        (True, 0.4),
        (True, 0.75),
    ]
    with launcher.job_array():
        for idx_band, (use_processor, power_std) in enumerate(band_settings):
            launcher({
                'filter.use': True,
                'filter.idx_band': idx_band,
                "processor.use": use_processor,
                'processor.power_std': power_std,
            })


# --- audiocraft/grids/diffusion/_explorers.py ---
class DiffusionExplorer(BaseExplorer):
    """Explorer displaying the metrics of the diffusion grids."""

    eval_metrics = ["sisnr", "visqol"]

    def stages(self):
        """Returns the names of the stages reported for each XP."""
        stage_names = ["train", "valid", "valid_ema", "evaluate", "evaluate_ema"]
        return stage_names

    def get_grid_meta(self):
        """Returns the list of Meta information to display for each XP/job.
        """
        index = tt.leaf("index", align=">")
        name = tt.leaf("name", wrap=140)
        state = tt.leaf("state")
        sig = tt.leaf("sig", align=">")
        return [index, name, state, sig]

    def get_grid_metrics(self):
        """Return the metrics that should be displayed in the tracking table.
        """
        rvm_names = ["rvm", "rvm_0", "rvm_1", "rvm_2", "rvm_3"]

        def loss_group(stage):
            # tt.leaf("loss_0", ".3%") can be added to also display the first loss term
            return tt.group(stage, [tt.leaf("loss", ".3%")], align=">")

        def rvm_group(stage):
            # fresh leaves are built for each group
            return tt.group(stage, [tt.leaf(name, ".4f") for name in rvm_names], align=">")

        train_group = tt.group(
            "train",
            [
                tt.leaf("epoch"),
                tt.leaf("loss", ".3%"),
            ],
            align=">",
        )
        return [
            train_group,
            loss_group("valid"),
            loss_group("valid_ema"),
            rvm_group("evaluate"),
            rvm_group("evaluate_ema"),
        ]
class LMExplorer(BaseExplorer):
    """Explorer displaying the training and validation metrics of language model grids."""

    eval_metrics: tp.List[str] = []

    def stages(self) -> tp.List[str]:
        """Returns the names of the stages reported for each XP."""
        return ['train', 'valid']

    def get_grid_metrics(self):
        """Return the metrics that should be displayed in the tracking table."""
        train_leaves = [
            tt.leaf('epoch'),
            tt.leaf('duration', '.1f'),  # duration in minutes
            tt.leaf('ping'),
            tt.leaf('ce', '.4f'),  # cross entropy
            tt.leaf("ppl", '.3f'),  # perplexity
        ]
        valid_leaves = [
            tt.leaf('ce', '.4f'),
            tt.leaf('ppl', '.3f'),
            tt.leaf('best_ppl', '.3f'),
        ]
        return [
            tt.group('train', train_leaves, align='>'),
            tt.group('valid', valid_leaves, align='>'),
        ]

    def process_sheep(self, sheep, history):
        """Extends the base metrics with the best validation value of each tracked metric.

        The best values are stored in the 'valid' part under the `best_<metric>` keys,
        which makes it convenient to compare metrics between runs in the grid.
        """
        parts = super().process_sheep(sheep, history)

        track_by = {'ppl': 'lower'}  # values should be in ['lower', 'higher']

        def is_better(mode, candidate, reference):
            if mode == 'lower':
                return candidate < reference
            return candidate > reference

        # Start from the worst possible value for each tracked metric.
        best_metrics = {}
        for name, mode in track_by.items():
            best_metrics[name] = float('inf') if mode == 'lower' else -float('inf')

        # for the validation set, keep track of best metrics (ppl in this example)
        for metrics in history:
            if 'valid' not in metrics:
                continue
            valid_metrics = metrics['valid']
            for name, mode in track_by.items():
                if name in valid_metrics and is_better(mode, valid_metrics[name], best_metrics[name]):
                    best_metrics[name] = valid_metrics[name]

        if 'valid' in parts:
            for name, value in best_metrics.items():
                parts['valid'][f'best_{name}'] = value
        return parts


class GenerationEvalExplorer(BaseExplorer):
    """Explorer displaying the metrics of generation evaluation grids."""

    eval_metrics: tp.List[str] = []

    def stages(self) -> tp.List[str]:
        """Returns the names of the stages reported for each XP."""
        return ['evaluate']

    def get_grid_metrics(self):
        """Return the metrics that should be displayed in the tracking table."""
        # (column name, format); no format is given for the ping column.
        columns = [
            ('epoch', '.3f'),
            ('duration', '.1f'),
            ('ping', None),
            ('ce', '.4f'),
            ('ppl', '.3f'),
            ('fad', '.3f'),
            ('kld', '.3f'),
            ('text_consistency', '.3f'),
            ('chroma_cosine', '.3f'),
        ]
        leaves = [
            tt.leaf(name) if fmt is None else tt.leaf(name, fmt)
            for name, fmt in columns
        ]
        return [tt.group('evaluate', leaves, align='>')]
@LMExplorer
def explorer(launcher):
    """Grid training the MusicGen base models: default, medium and large model scales."""
    launcher.slurm_(
        gpus=32,
        partition=AudioCraftEnvironment.get_slurm_partitions(['team', 'global']),
    )
    launcher.bind_(solver='musicgen/musicgen_base_32khz')
    # replace this by the desired music dataset
    launcher.bind_(dset='internal/music_400k_32khz')

    adam = {'optim.optimizer': 'adamw', 'optim.lr': 1e-4}
    wd_low = {'conditioners.description.t5.word_dropout': 0.2}
    cfg_low = {'classifier_free_guidance.training_dropout': 0.2}
    large = {'model/lm/model_scale': 'large'}
    medium = {'model/lm/model_scale': 'medium'}
    fsdp = {'autocast': False, 'fsdp.use': True}

    launcher.bind_(fsdp)

    # (number of GPUs, label, overrides of the XP scheduled with that allocation)
    runs = [
        (32, '32gpus', ()),
        (64, '64gpus', (medium, adam)),
        (96, '96gpus', (large, cfg_low, wd_low, adam, {'optim.max_norm': 3})),
    ]
    for gpus, label, overrides in runs:
        launcher.slurm_(gpus=gpus).bind_(label=label)
        with launcher.job_array():
            launcher.bind()(*overrides)
@LMExplorer
def explorer(launcher):
    """Declare the MusicGen base 32 kHz grid reading batches from a cache.

    The first part schedules the jobs that write the cache and then returns.
    The training runs below the `return` are only scheduled once that
    `return` statement is removed by hand.
    """
    partitions = AudioCraftEnvironment.get_slurm_partitions(['team', 'global'])
    launcher.slurm_(gpus=32, partition=partitions)
    launcher.bind_(solver='musicgen/musicgen_base_32khz')
    # replace this by the desired music dataset
    launcher.bind_(dset='internal/music_400k_32khz')

    fsdp = {'autocast': False, 'fsdp.use': True}
    medium = {'model/lm/model_scale': 'medium'}
    large = {'model/lm/model_scale': 'large'}

    cfg_low = {'classifier_free_guidance.training_dropout': 0.2}
    wd_low = {'conditioners.description.t5.word_dropout': 0.2}

    adam = {'optim.optimizer': 'adamw', 'optim.lr': 1e-4}

    # BEGINNING OF CACHE WRITING JOBS.
    cache_write = {
        'cache.path': '/fsx-codegen/defossez/cache/interleave_stereo_nv_32k',
        'cache.write': True,
        'generate.every': 500,
        'evaluate.every': 500,
        'logging.log_updates': 50,
    }

    cache_sub = launcher.bind({'model/lm/model_scale': 'xsmall', 'conditioner': 'none'})
    cache_sub.bind_({'deadlock.use': True})
    cache_sub.slurm_(gpus=8)
    # Schedule the shards through `cache_sub` (not `launcher`), otherwise the
    # model scale, conditioner, deadlock and GPU settings bound just above are
    # never applied to the cache writing jobs.
    with cache_sub.job_array():
        num_shards = 10  # total number of jobs running in parallel.
        for shard in range(num_shards):
            cache_sub(cache_write, {'cache.write_num_shards': num_shards, 'cache.write_shard': shard})

    # REMOVE THE FOLLOWING RETURN STATEMENT ONCE THE ABOVE JOBS ARE DONE,
    # OR SUFFICIENTLY AHEAD.
    return

    cache = {
        'cache.path': '/fsx-codegen/defossez/cache/interleave_stereo_nv_32k',
    }
    launcher.bind_(fsdp, cache)

    launcher.slurm_(gpus=32).bind_(label='32gpus')
    with launcher.job_array():
        sub = launcher.bind()
        sub()

    launcher.slurm_(gpus=64).bind_(label='64gpus')
    with launcher.job_array():
        sub = launcher.bind()
        sub(medium, adam)

    launcher.slurm_(gpus=96).bind_(label='96gpus')
    with launcher.job_array():
        sub = launcher.bind()
        sub(large, cfg_low, wd_low, adam, {'optim.max_norm': 3})
@LMExplorer
def explorer(launcher):
    """Declare the MusicGen melody 32 kHz grid.

    First schedules the chroma cache generation jobs, then one job array of
    training runs per GPU budget.
    """
    partitions = AudioCraftEnvironment.get_slurm_partitions(['team', 'global'])
    launcher.slurm_(gpus=32, partition=partitions)
    launcher.bind_(solver='musicgen/musicgen_melody_32khz')
    # replace this by the desired music dataset
    launcher.bind_(dset='internal/music_400k_32khz')

    # Option overrides shared between the runs below.
    adam = {'optim.optimizer': 'adamw', 'optim.lr': 1e-4}
    wd_low = {'conditioners.description.t5.word_dropout': 0.2}
    cfg_low = {'classifier_free_guidance.training_dropout': 0.2}
    large = {'model/lm/model_scale': 'large'}
    medium = {'model/lm/model_scale': 'medium'}
    fsdp = {'autocast': False, 'fsdp.use': True}
    cache_path = {'conditioners.self_wav.chroma_stem.cache_path':
                  '/fsx-audio-craft-llm/jadecopet/experiments/audiocraft/caches/chroma_stem'}

    # CACHE GENERATION JOBS
    n_cache_gen_jobs = 4
    cache_gen_opts = {
        # the cache is always computed over the whole file, so duration doesn't matter here.
        'dataset.segment_duration': 2.,
        'dataset.batch_size': 8,
        'dataset.train.permutation_on_files': True,  # try to not repeat files.
        'optim.epochs': 10,
        'model/lm/model_scale': 'xsmall',
    }
    gen_sub = launcher.slurm(gpus=1)
    gen_sub.bind_(cache_path, cache_gen_opts)
    with gen_sub.job_array():
        for seed in range(n_cache_gen_jobs):
            gen_sub({'dataset.train.shuffle_seed': seed})

    # ACTUAL TRAINING JOBS.
    launcher.bind_(fsdp)

    # (number of GPUs, label, overrides of each run in that job array)
    schedule = [
        (32, '32gpus', [(), (cache_path,)]),
        (64, '64gpus', [(medium, adam)]),
        (96, '96gpus', [(large, cfg_low, wd_low, adam, {'optim.max_norm': 3})]),
    ]
    for gpus, label, runs in schedule:
        launcher.slurm_(gpus=gpus).bind_(label=label)
        with launcher.job_array():
            sub = launcher.bind()
            for overrides in runs:
                sub(*overrides)
def eval(launcher, batch_size: int = 32, eval_melody: bool = False):
    """Schedule the objective evaluation runs on top of `launcher`.

    One run with the base objective metrics is always scheduled. When
    `eval_melody` is True, a second run adds the chroma-specific overrides.

    Args:
        launcher: Object exposing `bind`; the launcher returned by
            `launcher.bind(...)` is the one used to schedule the runs.
        batch_size (int): Value bound to `+dataset.evaluate.batch_size`.
        eval_melody (bool): Whether to also schedule the chroma-specific run.
    """
    # Sampling and classifier free guidance settings used by every run.
    sampling_opts = {'generate.lm.use_sampling': True, 'generate.lm.top_k': 250, 'generate.lm.top_p': 0.}
    cfg_opts = {'transformer_lm.two_step_cfg': True}

    # binary for FAD computation: replace this path with your own path
    metrics_opts = {
        'metrics.fad.tf.bin': '/data/home/jadecopet/local/usr/opt/google-research'
    }

    # chroma-specific evaluation
    chroma_opts = {
        'dset': 'internal/music_400k_32khz',
        'dataset.evaluate.segment_duration': 30,
        'dataset.evaluate.num_samples': 1000,
        'evaluate.metrics.chroma_cosine': True,
        'evaluate.metrics.fad': False,
        'evaluate.metrics.kld': False,
        'evaluate.metrics.text_consistency': False,
    }

    base_opts = {
        'dset': 'audio/musiccaps_32khz',
        'solver/musicgen/evaluation': 'objective_eval',
        'execute_only': 'evaluate',
        '+dataset.evaluate.batch_size': batch_size,
        '+metrics.fad.tf.batch_size': 16,
    }

    evaluator = launcher.bind(base_opts)
    evaluator.bind_(metrics_opts)

    # base objective metrics
    evaluator(sampling_opts, cfg_opts)
    if not eval_melody:
        return
    # chroma-specific metrics
    evaluator(sampling_opts, cfg_opts, chroma_opts)
class Balancer:
    """Loss balancer.

    The loss balancer combines losses together to compute gradients for the backward.
    Given `y = f(...)`, and a number of losses `l1(y, ...)`, `l2(y, ...)`, with `...`
    not having any dependence on `f`, the balancer can efficiently normalize the partial gradients
    `d l1 / d y`, `d l2 / dy` before summing them in order to achieve a desired ratio between
    the losses. For instance if `weights = {'l1': 2, 'l2': 1}`, 66% of the gradient
    going into `f(...)` will come from `l1` on average, and 33% from `l2`. This allows for an easy
    interpretation of the weights even if the intrinsic scale of `l1`, `l2` ... is unknown.

    Noting `g1 = d l1 / dy`, etc., the balanced gradient `G` will be
    (with `avg` an exponential moving average over the updates),

        G = sum_i total_norm * g_i / avg(||g_i||) * w_i / sum(w_i)

    If `balance_grads` is False, this is deactivated, and instead the gradient will just be the
    standard sum of the partial gradients with the given weights.

    A call to the backward method of the balancer will compute the partial gradients,
    combining all the losses and potentially rescaling the gradients,
    which can help stabilize the training and reason about multiple losses with varying scales.
    The obtained gradient with respect to `y` is then back-propagated to `f(...)`.

    Expected usage:

        weights = {'loss_a': 1, 'loss_b': 4}
        balancer = Balancer(weights, ...)
        losses: dict = {}
        losses['loss_a'] = compute_loss_a(x, y)
        losses['loss_b'] = compute_loss_b(x, y)
        if model.training:
            effective_loss = balancer.backward(losses, x)

    Args:
        weights (dict[str, float]): Weight coefficient for each loss. The balancer expect the losses keys
            from the backward method to match the weights keys to assign weight to each of the provided loss.
        balance_grads (bool): Whether to rescale gradients so that weights reflect the fraction of the
            overall gradient, rather than a constant multiplier.
        total_norm (float): Reference norm when rescaling gradients, ignored otherwise.
        ema_decay (float): EMA decay for averaging the norms.
        per_batch_item (bool): Whether to compute the averaged norm per batch item or not. This only holds
            when rescaling the gradients.
        epsilon (float): Epsilon value for numerical stability.
        monitor (bool): If True, stores in `self.metrics` the relative ratio between the norm of the gradients
            coming from each loss, when calling `backward()`.
    """
    def __init__(self, weights: tp.Dict[str, float], balance_grads: bool = True, total_norm: float = 1.,
                 ema_decay: float = 0.999, per_batch_item: bool = True, epsilon: float = 1e-12,
                 monitor: bool = False):
        self.weights = weights
        self.per_batch_item = per_batch_item
        # A falsy `total_norm` / `ema_decay` (None or 0) falls back to 1.
        self.total_norm = total_norm or 1.
        self.averager = flashy.averager(ema_decay or 1.)
        self.epsilon = epsilon
        self.monitor = monitor
        self.balance_grads = balance_grads
        self._metrics: tp.Dict[str, tp.Any] = {}

    @property
    def metrics(self):
        """Metrics stored by the last call to `backward()` (empty unless `monitor` is True)."""
        return self._metrics

    def backward(self, losses: tp.Dict[str, torch.Tensor], input: torch.Tensor) -> torch.Tensor:
        """Compute the backward and return the effective train loss, e.g. the loss obtained from
        computing the effective weights. If `balance_grads` is True, the effective weights
        are the one that needs to be applied to each gradient to respect the desired relative
        scale of gradients coming from each loss.

        Args:
            losses (Dict[str, torch.Tensor]): dictionary with the same keys as `self.weights`.
            input (torch.Tensor): the input of the losses, typically the output of the model.
                This should be the single point of dependence between the losses
                and the model being trained.
        Returns:
            torch.Tensor: Sum of the detached losses, each multiplied by the scale applied to its gradient.
        Raises:
            ValueError: If `losses` is empty.
        """
        if not losses:
            # Without this check the failure would only surface later as an
            # unbound local variable or a failed assertion on the weights.
            raise ValueError("Balancer.backward() requires at least one loss.")

        norms = {}
        grads = {}
        for name, loss in losses.items():
            # Compute partial derivative of the loss with respect to the input.
            grad, = autograd.grad(loss, [input], retain_graph=True)
            if self.per_batch_item:
                # We do not average the gradient over the batch dimension.
                dims = tuple(range(1, grad.dim()))
                norm = grad.norm(dim=dims, p=2).mean()
            else:
                norm = grad.norm(p=2)
            norms[name] = norm
            grads[name] = grad

        # Each gradient has the shape of `input`, so the batch size is read from
        # `input` rather than from whichever gradient the loop computed last.
        count = len(input) if self.per_batch_item else 1
        # Average norms across workers. Theoretically we should average the
        # squared norm, then take the sqrt, but it worked fine like that.
        avg_norms = flashy.distrib.average_metrics(self.averager(norms), count)
        # We approximate the total norm of the gradient as the sums of the norms.
        # Obviously this can be very incorrect if all gradients are aligned, but it works fine.
        total = sum(avg_norms.values())

        self._metrics = {}
        if self.monitor:
            # Store the ratio of the total gradient represented by each loss.
            for k, v in avg_norms.items():
                self._metrics[f'ratio_{k}'] = v / total

        total_weights = sum(self.weights[k] for k in avg_norms)
        assert total_weights > 0.
        desired_ratios = {k: w / total_weights for k, w in self.weights.items()}

        out_grad = torch.zeros_like(input)
        effective_loss = torch.tensor(0., device=input.device, dtype=input.dtype)
        for name, avg_norm in avg_norms.items():
            if self.balance_grads:
                # g_balanced = g / avg(||g||) * total_norm * desired_ratio
                scale = desired_ratios[name] * self.total_norm / (self.epsilon + avg_norm)
            else:
                # We just do regular weighted sum of the gradients.
                scale = self.weights[name]
            out_grad.add_(grads[name], alpha=scale)
            effective_loss += scale * losses[name].detach()
        # Send the computed partial derivative with respect to the output of the model to the model.
        input.backward(out_grad)
        return effective_loss
def _unfold(a: torch.Tensor, kernel_size: int, stride: int) -> torch.Tensor:
    """Given input of size [*OT, T], output Tensor of size [*OT, F, K]
    with K the kernel size, by extracting frames with the given stride.
    This will pad the input with zeros on the right so that `F = ceil(T / stride)`.
    see https://github.com/pytorch/pytorch/issues/60466
    """
    *shape, length = a.shape
    n_frames = math.ceil(length / stride)
    # Length needed for the last frame to be complete.
    tgt_length = (n_frames - 1) * stride + kernel_size
    a = F.pad(a, (0, tgt_length - length))
    strides = list(a.stride())
    assert strides[-1] == 1, "data should be contiguous"
    # Frames are views over the padded signal: consecutive frames start `stride` samples apart.
    strides = strides[:-1] + [stride, 1]
    return a.as_strided([*shape, n_frames, kernel_size], strides)


def _center(x: torch.Tensor) -> torch.Tensor:
    """Remove the mean computed over the last dimension."""
    return x - x.mean(-1, True)


def _norm2(x: torch.Tensor) -> torch.Tensor:
    """Return the squared L2 norm over the last dimension, which is kept with size 1."""
    return x.pow(2).sum(-1, True)


class SISNR(nn.Module):
    """SISNR loss.

    Input should be [B, C, T], output is scalar.

    Args:
        sample_rate (int): Sample rate.
        segment (float or None): Evaluate on chunks of that many seconds. If None, evaluate on
            entire audio only.
        overlap (float): Overlap between chunks, i.e. 0.5 = 50 % overlap.
        epsilon (float): Epsilon value for numerical stability.
    """
    def __init__(
        self,
        sample_rate: int = 16000,
        segment: tp.Optional[float] = 20,
        overlap: float = 0.5,
        epsilon: float = torch.finfo(torch.float32).eps,
    ):
        super().__init__()
        self.sample_rate = sample_rate
        self.segment = segment
        self.overlap = overlap
        self.epsilon = epsilon

    def forward(self, out_sig: torch.Tensor, ref_sig: torch.Tensor) -> torch.Tensor:
        """Return the negated SISNR (in dB) averaged over batch, channels and frames.

        Args:
            out_sig (torch.Tensor): Estimated signal of shape [B, C, T].
            ref_sig (torch.Tensor): Reference signal, same shape as `out_sig`.
        Returns:
            torch.Tensor: Scalar loss, lower is better.
        """
        B, C, T = ref_sig.shape
        assert ref_sig.shape == out_sig.shape

        if self.segment is None:
            frame = T
            stride = T
        else:
            frame = int(self.segment * self.sample_rate)
            stride = int(frame * (1 - self.overlap))

        epsilon = self.epsilon * frame  # make epsilon prop to frame size.

        gt = _unfold(ref_sig, frame, stride)
        est = _unfold(out_sig, frame, stride)
        if self.segment is None:
            # Frames are indexed by the second to last dimension ([B, C, F, K]):
            # evaluating on the entire audio means there is exactly one frame.
            assert gt.shape[-2] == 1

        gt = _center(gt)
        est = _center(est)
        dot = torch.einsum("bcft,bcft->bcf", gt, est)

        # Projection of the estimate on the reference; the residual is the noise.
        proj = dot[:, :, :, None] * gt / (epsilon + _norm2(gt))
        noise = est - proj

        sisnr = 10 * (
            torch.log10(epsilon + _norm2(proj)) - torch.log10(epsilon + _norm2(noise))
        )
        return -1 * sisnr[..., 0].mean()
""" def __init__(self, n_fft: int = 1024, hop_length: int = 256, win_length: tp.Optional[int] = None, n_mels: int = 80, sample_rate: float = 22050, f_min: float = 0.0, f_max: tp.Optional[float] = None, log: bool = True, normalized: bool = False, floor_level: float = 1e-5): super().__init__() self.n_fft = n_fft hop_length = int(hop_length) self.hop_length = hop_length self.mel_transform = MelSpectrogram(n_mels=n_mels, sample_rate=sample_rate, n_fft=n_fft, hop_length=hop_length, win_length=win_length, f_min=f_min, f_max=f_max, normalized=normalized, window_fn=torch.hann_window, center=False) self.floor_level = floor_level self.log = log def forward(self, x): p = int((self.n_fft - self.hop_length) // 2) if len(x.shape) == 2: x = x.unsqueeze(1) x = F.pad(x, (p, p), "reflect") # Make sure that all the frames are full. # The combination of `pad_for_conv1d` and the above padding # will make the output of size ceil(T / hop). x = pad_for_conv1d(x, self.n_fft, self.hop_length) self.mel_transform.to(x.device) mel_spec = self.mel_transform(x) B, C, freqs, frame = mel_spec.shape if self.log: mel_spec = torch.log10(self.floor_level + mel_spec) return mel_spec.reshape(B, C * freqs, frame) class MelSpectrogramL1Loss(torch.nn.Module): """L1 Loss on MelSpectrogram. Args: sample_rate (int): Sample rate. n_fft (int): Number of fft. hop_length (int): Hop size. win_length (int): Window length. n_mels (int): Number of mel bins. f_min (float or None): Minimum frequency. f_max (float or None): Maximum frequency. log (bool): Whether to scale with log. normalized (bool): Whether to normalize the melspectrogram. floor_level (float): Floor level value based on human perception (default=1e-5). 
class MultiScaleMelSpectrogramLoss(nn.Module):
    """Multi-Scale spectrogram loss (msspec).

    For every exponent `i` in `range(range_start, range_end)`, two mel-spectrograms
    with `n_fft = 2 ** i` are compared: the non-log ones with an L1 loss and the
    log ones with an MSE loss weighted by the coefficient of that scale.

    Args:
        sample_rate (int): Sample rate.
        range_start (int): Power of 2 to use for the first scale.
        range_end (int): Power of 2 at which to stop, excluded.
        n_mels (int): Number of mel bins.
        f_min (float): Minimum frequency.
        f_max (float or None): Maximum frequency.
        normalized (bool): Whether to normalize the melspectrogram.
        alphas (bool): Whether to use alphas as coefficients or not.
        floor_level (float): Floor level value based on human perception (default=1e-5).
    """
    def __init__(self, sample_rate: int, range_start: int = 6, range_end: int = 11,
                 n_mels: int = 64, f_min: float = 0.0, f_max: tp.Optional[float] = None,
                 normalized: bool = False, alphas: bool = True, floor_level: float = 1e-5):
        super().__init__()
        self.alphas = list()
        self.total = 0
        self.normalized = normalized
        linear_specs = []
        log_specs = []
        for exponent in range(range_start, range_end):
            n_fft = 2 ** exponent
            # Settings shared by the two transforms of this scale; they only differ by `log`.
            shared = dict(n_fft=n_fft, hop_length=n_fft / 4, win_length=n_fft, n_mels=n_mels,
                          sample_rate=sample_rate, f_min=f_min, f_max=f_max,
                          normalized=normalized, floor_level=floor_level)
            linear_specs.append(MelSpectrogramWrapper(log=False, **shared))
            log_specs.append(MelSpectrogramWrapper(log=True, **shared))
            alpha = np.sqrt(n_fft - 1) if alphas else 1
            self.alphas.append(alpha)
            # Sum of all the coefficients (1 for the L1 term), used when normalizing.
            self.total += alpha + 1

        self.l1s = nn.ModuleList(linear_specs)
        self.l2s = nn.ModuleList(log_specs)

    def forward(self, x, y):
        """Return the sum over scales of the L1 and weighted MSE terms between `x` and `y`.

        When `normalized` is set, the sum is divided by the total of the coefficients.
        """
        self.l1s.to(x.device)
        self.l2s.to(x.device)
        loss = 0.0
        for alpha, linear_spec, log_spec in zip(self.alphas, self.l1s, self.l2s):
            lin_x = linear_spec(x)
            lin_y = linear_spec(y)
            log_x = log_spec(x)
            log_y = log_spec(y)
            loss += F.l1_loss(lin_x, lin_y) + alpha * F.mse_loss(log_x, log_y)
        if self.normalized:
            loss = loss / self.total
        return loss
def _stft(x: torch.Tensor, fft_size: int, hop_length: int, win_length: int,
          window: tp.Optional[torch.Tensor], normalized: bool) -> torch.Tensor:
    """Perform STFT and convert to magnitude spectrogram.

    Args:
        x: Input signal tensor (B, C, T).
        fft_size (int): FFT size.
        hop_length (int): Hop size.
        win_length (int): Window length.
        window (torch.Tensor or None): Window function type.
        normalized (bool): Whether to normalize the STFT or not.

    Returns:
        torch.Tensor: Magnitude spectrogram (B, C, #frames, fft_size // 2 + 1).
    """
    B, C, T = x.shape
    # Channels are folded in the batch dimension for the STFT.
    x_stft = torch.stft(
        x.reshape(-1, T), fft_size, hop_length, win_length, window,
        normalized=normalized, return_complex=True,
    )
    x_stft = x_stft.reshape(B, C, *x_stft.shape[1:])
    real = x_stft.real
    imag = x_stft.imag

    # NOTE(kan-bayashi): clamp is needed to avoid nan or inf
    magnitude = torch.sqrt(torch.clamp(real ** 2 + imag ** 2, min=1e-7))
    # Swap the last two dimensions (freq, frames) -> (frames, freq). Swapping dims 1 and 2
    # would exchange channels and frequencies now that the input has a channel dimension.
    return magnitude.transpose(-1, -2)


class SpectralConvergenceLoss(nn.Module):
    """Spectral convergence loss.

    Args:
        epsilon (float): Epsilon value for numerical stability.
    """
    def __init__(self, epsilon: float = torch.finfo(torch.float32).eps):
        super().__init__()
        self.epsilon = epsilon

    def forward(self, x_mag: torch.Tensor, y_mag: torch.Tensor):
        """Calculate forward propagation.

        Args:
            x_mag: Magnitude spectrogram of predicted signal (B, C, #frames, #freq_bins).
            y_mag: Magnitude spectrogram of groundtruth signal (B, C, #frames, #freq_bins).
        Returns:
            torch.Tensor: Spectral convergence loss value.
        """
        return torch.norm(y_mag - x_mag, p="fro") / (torch.norm(y_mag, p="fro") + self.epsilon)


class LogSTFTMagnitudeLoss(nn.Module):
    """Log STFT magnitude loss.

    Args:
        epsilon (float): Epsilon value for numerical stability.
    """
    def __init__(self, epsilon: float = torch.finfo(torch.float32).eps):
        super().__init__()
        self.epsilon = epsilon

    def forward(self, x_mag: torch.Tensor, y_mag: torch.Tensor):
        """Calculate forward propagation.

        Args:
            x_mag (torch.Tensor): Magnitude spectrogram of predicted signal (B, C, #frames, #freq_bins).
            y_mag (torch.Tensor): Magnitude spectrogram of groundtruth signal (B, C, #frames, #freq_bins).
        Returns:
            torch.Tensor: Log STFT magnitude loss value.
        """
        return F.l1_loss(torch.log(self.epsilon + y_mag), torch.log(self.epsilon + x_mag))


class STFTLosses(nn.Module):
    """STFT losses.

    Args:
        n_fft (int): Size of FFT.
        hop_length (int): Hop length.
        win_length (int): Window length.
        window (str): Window function type.
        normalized (bool): Whether to use normalized STFT or not.
        epsilon (float): Epsilon for numerical stability.
    """
    def __init__(self, n_fft: int = 1024, hop_length: int = 120, win_length: int = 600,
                 window: str = "hann_window", normalized: bool = False,
                 epsilon: float = torch.finfo(torch.float32).eps):
        super().__init__()
        self.n_fft = n_fft
        self.hop_length = hop_length
        self.win_length = win_length
        self.normalized = normalized
        # `window` is the name of a torch window function, e.g. `torch.hann_window`.
        self.register_buffer("window", getattr(torch, window)(win_length))
        self.spectral_convergenge_loss = SpectralConvergenceLoss(epsilon)
        self.log_stft_magnitude_loss = LogSTFTMagnitudeLoss(epsilon)

    def forward(self, x: torch.Tensor, y: torch.Tensor) -> tp.Tuple[torch.Tensor, torch.Tensor]:
        """Calculate forward propagation.

        Args:
            x (torch.Tensor): Predicted signal (B, C, T).
            y (torch.Tensor): Groundtruth signal (B, C, T).
        Returns:
            torch.Tensor: Spectral convergence loss value.
            torch.Tensor: Log STFT magnitude loss value.
        """
        x_mag = _stft(x, self.n_fft, self.hop_length,
                      self.win_length, self.window, self.normalized)  # type: ignore
        y_mag = _stft(y, self.n_fft, self.hop_length,
                      self.win_length, self.window, self.normalized)  # type: ignore
        sc_loss = self.spectral_convergenge_loss(x_mag, y_mag)
        mag_loss = self.log_stft_magnitude_loss(x_mag, y_mag)

        return sc_loss, mag_loss


class STFTLoss(nn.Module):
    """Single Resolution STFT loss.

    Args:
        n_fft (int): Nb of FFT.
        hop_length (int): Hop length.
        win_length (int): Window length.
        window (str): Window function type.
        normalized (bool): Whether to use normalized STFT or not.
        epsilon (float): Epsilon for numerical stability.
        factor_sc (float): Coefficient for the spectral loss.
        factor_mag (float): Coefficient for the magnitude loss.
    """
    def __init__(self, n_fft: int = 1024, hop_length: int = 120, win_length: int = 600,
                 window: str = "hann_window", normalized: bool = False,
                 factor_sc: float = 0.1, factor_mag: float = 0.1,
                 epsilon: float = torch.finfo(torch.float32).eps):
        super().__init__()
        self.loss = STFTLosses(n_fft, hop_length, win_length, window, normalized, epsilon)
        self.factor_sc = factor_sc
        self.factor_mag = factor_mag

    def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
        """Calculate forward propagation.

        Args:
            x (torch.Tensor): Predicted signal (B, C, T).
            y (torch.Tensor): Groundtruth signal (B, C, T).
        Returns:
            torch.Tensor: Single resolution STFT loss.
        """
        sc_loss, mag_loss = self.loss(x, y)
        return self.factor_sc * sc_loss + self.factor_mag * mag_loss


class MRSTFTLoss(nn.Module):
    """Multi resolution STFT loss.

    Args:
        n_ffts (Sequence[int]): Sequence of FFT sizes.
        hop_lengths (Sequence[int]): Sequence of hop sizes.
        win_lengths (Sequence[int]): Sequence of window lengths.
        window (str): Window function type.
        factor_sc (float): Coefficient for the spectral loss.
        factor_mag (float): Coefficient for the magnitude loss.
        normalized (bool): Whether to use normalized STFT or not.
        epsilon (float): Epsilon for numerical stability.
    """
    def __init__(self, n_ffts: tp.Sequence[int] = (1024, 2048, 512),
                 hop_lengths: tp.Sequence[int] = (120, 240, 50),
                 win_lengths: tp.Sequence[int] = (600, 1200, 240), window: str = "hann_window",
                 factor_sc: float = 0.1, factor_mag: float = 0.1,
                 normalized: bool = False, epsilon: float = torch.finfo(torch.float32).eps):
        super().__init__()
        assert len(n_ffts) == len(hop_lengths) == len(win_lengths)
        self.stft_losses = torch.nn.ModuleList()
        for fs, ss, wl in zip(n_ffts, hop_lengths, win_lengths):
            self.stft_losses.append(STFTLosses(fs, ss, wl, window, normalized, epsilon))
        self.factor_sc = factor_sc
        self.factor_mag = factor_mag

    def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
        """Calculate forward propagation.

        Args:
            x (torch.Tensor): Predicted signal (B, C, T).
            y (torch.Tensor): Groundtruth signal (B, C, T).
        Returns:
            torch.Tensor: Multi resolution STFT loss, of shape (1,).
        """
        # Accumulate on the device of the input: a CPU accumulator cannot be
        # combined with losses computed on another device.
        sc_loss = torch.zeros(1, device=x.device)
        mag_loss = torch.zeros(1, device=x.device)
        for f in self.stft_losses:
            sc_l, mag_l = f(x, y)
            sc_loss = sc_loss + sc_l
            mag_loss = mag_loss + mag_l
        # Average over the resolutions.
        sc_loss = sc_loss / len(self.stft_losses)
        mag_loss = mag_loss / len(self.stft_losses)

        return self.factor_sc * sc_loss + self.factor_mag * mag_loss
class ChromaCosineSimilarityMetric(torchmetrics.Metric):
    """Chroma cosine similarity metric.

    This metric extracts a chromagram for a reference waveform and
    a generated waveform and compares each frame using the cosine similarity
    function. The output is the mean cosine similarity.

    Args:
        sample_rate (int): Sample rate used by the chroma extractor.
        n_chroma (int): Number of chroma used by the chroma extractor.
        radix2_exp (int): Exponent for the chroma extractor.
        argmax (bool): Whether the chroma extractor uses argmax.
        eps (float): Epsilon for cosine similarity computation.
    """
    def __init__(self, sample_rate: int, n_chroma: int, radix2_exp: int, argmax: bool, eps: float = 1e-8):
        super().__init__()
        self.chroma_sample_rate = sample_rate
        self.n_chroma = n_chroma
        self.eps = eps
        self.chroma_extractor = ChromaExtractor(sample_rate=self.chroma_sample_rate, n_chroma=self.n_chroma,
                                                radix2_exp=radix2_exp, argmax=argmax)
        # Running sum of the per-frame similarities and number of compared frames.
        self.add_state("cosine_sum", default=torch.tensor(0.), dist_reduce_fx="sum")
        self.add_state("weight", default=torch.tensor(0.), dist_reduce_fx="sum")

    def update(self, preds: torch.Tensor, targets: torch.Tensor,
               sizes: torch.Tensor, sample_rates: torch.Tensor) -> None:
        """Compute cosine similarity between chromagrams and accumulate scores over the dataset.

        Args:
            preds (torch.Tensor): Generated waveforms, same shape as `targets`.
            targets (torch.Tensor): Reference waveforms.
            sizes (torch.Tensor): One length per item of the batch.
            sample_rates (torch.Tensor): One sample rate per item, all expected to be equal.
        """
        if preds.size(0) == 0:
            return

        assert preds.shape == targets.shape, (
            f"Preds and target shapes mismatch: preds={preds.shape}, targets={targets.shape}")
        # The messages are built as single strings: a comma between the two parts
        # would turn the assertion message into a tuple.
        assert preds.size(0) == sizes.size(0), (
            f"Number of items in preds ({preds.shape}) mismatch "
            f"with sizes ({sizes.shape})")
        assert preds.size(0) == sample_rates.size(0), (
            f"Number of items in preds ({preds.shape}) mismatch "
            f"with sample_rates ({sample_rates.shape})")
        assert torch.all(sample_rates == sample_rates[0].item()), "All sample rates are not the same in the batch"

        device = self.weight.device
        preds, targets = preds.to(device), targets.to(device)  # type: ignore
        sample_rate = sample_rates[0].item()
        preds = convert_audio(preds, from_rate=sample_rate, to_rate=self.chroma_sample_rate, to_channels=1)
        targets = convert_audio(targets, from_rate=sample_rate, to_rate=self.chroma_sample_rate, to_channels=1)
        gt_chroma = self.chroma_extractor(targets)
        gen_chroma = self.chroma_extractor(preds)
        # Number of chroma frames to compare for each item.
        # NOTE(review): `sizes` is divided by the extractor hop without rescaling to
        # `chroma_sample_rate`; presumably lengths are expected at that rate - confirm.
        chroma_lens = (sizes / self.chroma_extractor.winhop).ceil().int()
        for i in range(len(gt_chroma)):
            t = int(chroma_lens[i].item())
            cosine_sim = torch.nn.functional.cosine_similarity(
                gt_chroma[i, :t], gen_chroma[i, :t], dim=1, eps=self.eps)
            self.cosine_sum += cosine_sim.sum(dim=0)  # type: ignore
            self.weight += torch.tensor(t)  # type: ignore

    def compute(self) -> float:
        """Computes the average cosine similarity across all generated/target chromagrams pairs."""
        assert self.weight.item() > 0, "Unable to compute with total number of comparisons <= 0"  # type: ignore
        return (self.cosine_sum / self.weight).item()  # type: ignore
def __init__(self, model_path: tp.Union[str, Path], model_arch: str = 'HTSAT-tiny', enable_fusion: bool = False):
    """Register the metric states and load the CLAP model.

    Args:
        model_path (str or Path): Reference to the CLAP checkpoint.
        model_arch (str): Audio model architecture passed to CLAP.
        enable_fusion (bool): Fusion flag passed to CLAP.
    Raises:
        ImportError: If `laion_clap` could not be imported.
    """
    super().__init__()
    if laion_clap is None:
        raise ImportError("Please install CLAP to compute text consistency: 'pip install laion_clap'")
    # Both accumulators start at zero and are summed across processes.
    for state_name in ("cosine_sum", "weight"):
        self.add_state(state_name, default=torch.tensor(0.), dist_reduce_fx="sum")
    self._initialize_model(model_path, model_arch, enable_fusion)


def _initialize_model(self, model_path: tp.Union[str, Path], model_arch: str, enable_fusion: bool):
    """Instantiate the tokenizer and the CLAP module, then load the checkpoint weights."""
    resolved_path = AudioCraftEnvironment.resolve_reference_path(model_path)
    self.tokenize = RobertaTokenizer.from_pretrained('roberta-base')
    self.model = laion_clap.CLAP_Module(enable_fusion=enable_fusion, amodel=model_arch)
    # Audio is resampled to this rate in `update` before being embedded.
    self.model_sample_rate = 48_000
    load_clap_state_dict(self.model, resolved_path)
    self.model.eval()


def _tokenizer(self, texts: tp.Union[str, tp.List[str]]) -> dict:
    """Tokenize `texts` with the same default parameters as the CLAP module."""
    tokenizer_kwargs = {
        "padding": "max_length",
        "truncation": True,
        "max_length": 77,
        "return_tensors": "pt",
    }
    return self.tokenize(texts, **tokenizer_kwargs)


def update(self, audio: torch.Tensor, text: tp.List[str], sizes: torch.Tensor, sample_rates: torch.Tensor) -> None:
    """Compute cosine similarity between audio and text pairs and accumulate scores over the dataset.

    Args:
        audio (torch.Tensor): Audio batch of shape [B, C, T].
        text (list of str): One description per audio item.
        sizes (torch.Tensor): Not read by this implementation.
        sample_rates (torch.Tensor): One sample rate per item; all entries must be equal.
    """
    assert audio.size(0) == len(text), "Number of audio and text samples should match"
    assert torch.all(sample_rates == sample_rates[0].item()), "All items in batch should have the same sample rate"
    batch_rate = int(sample_rates[0].item())
    # convert audio batch to 48kHz monophonic audio with no channel dimension: [B, C, T] -> [B, T]
    mono = convert_audio(audio, from_rate=batch_rate, to_rate=self.model_sample_rate, to_channels=1)
    mono = mono.mean(dim=1)
    audio_emb = self.model.get_audio_embedding_from_data(mono, use_tensor=True)
    text_emb = self.model.get_text_embedding(text, tokenizer=self._tokenizer, use_tensor=True)
    # cosine similarity between the text and the audio embedding
    similarities = torch.nn.functional.cosine_similarity(audio_emb, text_emb, dim=1, eps=1e-8)
    self.cosine_sum += similarities.sum(dim=0)
    self.weight += torch.tensor(similarities.size(0))


def compute(self):
    """Computes the average cosine similarity across all audio/text pairs."""
    total = self.weight.item()  # type: ignore
    assert total > 0, "Unable to compute with total number of comparisons <= 0"
    mean_similarity = self.cosine_sum / self.weight  # type: ignore
    return mean_similarity.item()
d^2 = (mu_x - mu_y)^2 + Tr(sigma_x + sigma_y - 2 * sqrt(sigma_x*sigma_y)) = (mu_x - mu_y)^2 + Tr(sigma_x) + Tr(sigma_y) - 2 * Tr(sqrt(sigma_x*sigma_y))) To use this FAD computation metric, you need to have the proper Frechet Audio Distance tool setup from: https://github.com/google-research/google-research/tree/master/frechet_audio_distance We provide the below instructions as reference but we do not guarantee for further support in frechet_audio_distance installation. This was tested with python 3.10, cuda 11.8, tensorflow 2.12.0. We recommend installing the frechet_audio_distance library in a dedicated env (e.g. conda). 1. Get the code and models following the repository instructions. We used the steps below: git clone git@github.com:google-research/google-research.git git clone git@github.com:tensorflow/models.git mkdir google-research/tensorflow_models touch google-research/tensorflow_models/__init__.py cp -r models/research/audioset google-research/tensorflow_models/ touch google-research/tensorflow_models/audioset/__init__.py echo "from .vggish import mel_features, vggish_params, vggish_slim" > \ google-research/tensorflow_models/audioset/__init__.py # we can now remove the tensorflow models repository # rm -r models cd google-research Follow the instructions to download the vggish checkpoint. AudioCraft base configuration assumes it is placed in the AudioCraft reference dir. Note that we operate the following changes for the code to work with TensorFlow 2.X and python 3: - Update xrange for range in: https://github.com/google-research/google-research/blob/master/frechet_audio_distance/audioset_model.py - Update `tf_record = tf.python_io.tf_record_iterator(filename).next()` to `tf_record = tf.python_io.tf_record_iterator(filename).__next__()` in https://github.com/google-research/google-research/blob/master/frechet_audio_distance/fad_utils.py - Update `import vggish_params as params` to `from . 
import vggish_params as params` in: https://github.com/tensorflow/models/blob/master/research/audioset/vggish/vggish_slim.py - Add flag to provide a given batch size for running the AudioSet model in: https://github.com/google-research/google-research/blob/master/frechet_audio_distance/create_embeddings_main.py ``` flags.DEFINE_integer('batch_size', 64, 'Number of samples in the batch for AudioSet model.') ``` Ensure you pass the flag to the create_embeddings_beam.create_pipeline function, adding: `batch_size=FLAGS.batch_size` to the provided parameters. 2. Follow instructions for the library installation and a valid TensorFlow installation ``` # e.g. instructions from: https://www.tensorflow.org/install/pip conda install -c conda-forge cudatoolkit=11.8.0 python3 -m pip install nvidia-cudnn-cu11==8.6.0.163 tensorflow==2.12.* mkdir -p $CONDA_PREFIX/etc/conda/activate.d echo 'CUDNN_PATH=$(dirname $(python -c "import nvidia.cudnn;print(nvidia.cudnn.__file__)"))' \ >> $CONDA_PREFIX/etc/conda/activate.d/env_vars.sh echo 'export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:$CONDA_PREFIX/lib/:$CUDNN_PATH/lib' \ >> $CONDA_PREFIX/etc/conda/activate.d/env_vars.sh source $CONDA_PREFIX/etc/conda/activate.d/env_vars.sh # Verify install: on a machine with GPU device python3 -c "import tensorflow as tf; print(tf.config.list_physical_devices('GPU'))" ``` Now install frechet_audio_distance required dependencies: ``` # We assume we already have TensorFlow installed from the above steps pip install apache-beam numpy scipy tf_slim ``` Finally, follow remaining library instructions to ensure you have a working frechet_audio_distance setup (you may want to specify --model_ckpt flag pointing to the model's path). 3. AudioCraft's FrechetAudioDistanceMetric requires 2 environment variables pointing to the python executable and Tensorflow library path from the above installation steps: export TF_PYTHON_EXE="" export TF_LIBRARY_PATH="" e.g. 
assuming we have installed everything in a dedicated conda env with python 3.10 that is currently active: export TF_PYTHON_EXE="$CONDA_PREFIX/bin/python" export TF_LIBRARY_PATH="$CONDA_PREFIX/lib/python3.10/site-packages/nvidia/cudnn/lib" Finally you may want to export the following variable: export TF_FORCE_GPU_ALLOW_GROWTH=true See: https://www.tensorflow.org/guide/gpu#limiting_gpu_memory_growth You can save those environment variables in your training conda env, when currently active: `$CONDA_PREFIX/etc/conda/activate.d/env_vars.sh` e.g. assuming the env with TensorFlow and frechet_audio_distance install is named ac_eval, and the training conda env is named audiocraft: ``` # activate training env conda activate audiocraft # get path to all envs CONDA_ENV_DIR=$(dirname $CONDA_PREFIX) # export pointers to evaluation env for using TensorFlow in FrechetAudioDistanceMetric touch $CONDA_PREFIX/etc/conda/activate.d/env_vars.sh echo 'export TF_PYTHON_EXE="$CONDA_ENV_DIR/ac_eval/bin/python"' >> \ $CONDA_PREFIX/etc/conda/activate.d/env_vars.sh echo 'export TF_LIBRARY_PATH="$CONDA_ENV_DIR/ac_eval/lib/python3.10/site-packages/nvidia/cudnn/lib"' >> \ $CONDA_PREFIX/etc/conda/activate.d/env_vars.sh # optionally: echo 'export TF_FORCE_GPU_ALLOW_GROWTH=true' >> $CONDA_PREFIX/etc/conda/activate.d/env_vars.sh # you may need to reactivate the audiocraft env for this to take effect ``` Args: bin (Path or str): Path to installed frechet audio distance code. model_path (Path or str): Path to Tensorflow checkpoint for the model used to compute statistics over the embedding beams. format (str): Audio format used to save files. log_folder (Path or str, optional): Path where to write process logs. 
""" def __init__(self, bin: tp.Union[Path, str], model_path: tp.Union[Path, str], format: str = "wav", batch_size: tp.Optional[int] = None, log_folder: tp.Optional[tp.Union[Path, str]] = None): super().__init__() self.model_sample_rate = VGGISH_SAMPLE_RATE self.model_channels = VGGISH_CHANNELS self.model_path = AudioCraftEnvironment.resolve_reference_path(model_path) assert Path(self.model_path).exists(), f"Could not find provided model checkpoint path at: {self.model_path}" self.format = format self.batch_size = batch_size self.bin = bin self.tf_env = {"PYTHONPATH": str(self.bin)} self.python_path = os.environ.get('TF_PYTHON_EXE') or 'python' logger.info("Python exe for TF is %s", self.python_path) if 'TF_LIBRARY_PATH' in os.environ: self.tf_env['LD_LIBRARY_PATH'] = os.environ['TF_LIBRARY_PATH'] if 'TF_FORCE_GPU_ALLOW_GROWTH' in os.environ: self.tf_env['TF_FORCE_GPU_ALLOW_GROWTH'] = os.environ['TF_FORCE_GPU_ALLOW_GROWTH'] logger.info("Env for TF is %r", self.tf_env) self.reset(log_folder) self.add_state("total_files", default=torch.tensor(0.), dist_reduce_fx="sum") def reset(self, log_folder: tp.Optional[tp.Union[Path, str]] = None): """Reset torchmetrics.Metrics state.""" log_folder = Path(log_folder or tempfile.mkdtemp()) self.tmp_dir = log_folder / 'fad' self.tmp_dir.mkdir(exist_ok=True) self.samples_tests_dir = self.tmp_dir / 'tests' self.samples_tests_dir.mkdir(exist_ok=True) self.samples_background_dir = self.tmp_dir / 'background' self.samples_background_dir.mkdir(exist_ok=True) self.manifest_tests = self.tmp_dir / 'files_tests.cvs' self.manifest_background = self.tmp_dir / 'files_background.cvs' self.stats_tests_dir = self.tmp_dir / 'stats_tests' self.stats_background_dir = self.tmp_dir / 'stats_background' self.counter = 0 def update(self, preds: torch.Tensor, targets: torch.Tensor, sizes: torch.Tensor, sample_rates: torch.Tensor, stems: tp.Optional[tp.List[str]] = None): """Update torchmetrics.Metrics by saving the audio and updating the manifest 
file.""" assert preds.shape == targets.shape, f"preds={preds.shape} != targets={targets.shape}" num_samples = preds.shape[0] assert num_samples == sizes.size(0) and num_samples == sample_rates.size(0) assert stems is None or num_samples == len(set(stems)) for i in range(num_samples): self.total_files += 1 # type: ignore self.counter += 1 wav_len = int(sizes[i].item()) sample_rate = int(sample_rates[i].item()) pred_wav = preds[i] target_wav = targets[i] pred_wav = pred_wav[..., :wav_len] target_wav = target_wav[..., :wav_len] stem_name = stems[i] if stems is not None else f'sample_{self.counter}_{flashy.distrib.rank()}' # dump audio files try: pred_wav = convert_audio( pred_wav.unsqueeze(0), from_rate=sample_rate, to_rate=self.model_sample_rate, to_channels=1).squeeze(0) audio_write( self.samples_tests_dir / stem_name, pred_wav, sample_rate=self.model_sample_rate, format=self.format, strategy="peak") except Exception as e: logger.error(f"Exception occured when saving tests files for FAD computation: {repr(e)} - {e}") try: # for the ground truth audio, we enforce the 'peak' strategy to avoid modifying # the original audio when writing it target_wav = convert_audio( target_wav.unsqueeze(0), from_rate=sample_rate, to_rate=self.model_sample_rate, to_channels=1).squeeze(0) audio_write( self.samples_background_dir / stem_name, target_wav, sample_rate=self.model_sample_rate, format=self.format, strategy="peak") except Exception as e: logger.error(f"Exception occured when saving background files for FAD computation: {repr(e)} - {e}") def _get_samples_name(self, is_background: bool): return 'background' if is_background else 'tests' def _create_embedding_beams(self, is_background: bool, gpu_index: tp.Optional[int] = None): if is_background: input_samples_dir = self.samples_background_dir input_filename = self.manifest_background stats_name = self.stats_background_dir else: input_samples_dir = self.samples_tests_dir input_filename = self.manifest_tests stats_name = 
self.stats_tests_dir beams_name = self._get_samples_name(is_background) log_file = self.tmp_dir / f'fad_logs_create_beams_{beams_name}.log' logger.info(f"Scanning samples folder to fetch list of files: {input_samples_dir}") with open(input_filename, "w") as fout: for path in Path(input_samples_dir).glob(f"*.{self.format}"): fout.write(f"{str(path)}\n") cmd = [ self.python_path, "-m", "frechet_audio_distance.create_embeddings_main", "--model_ckpt", f"{self.model_path}", "--input_files", f"{str(input_filename)}", "--stats", f"{str(stats_name)}", ] if self.batch_size is not None: cmd += ["--batch_size", str(self.batch_size)] logger.info(f"Launching frechet_audio_distance embeddings main method: {' '.join(cmd)} on {beams_name}") env = os.environ if gpu_index is not None: env["CUDA_VISIBLE_DEVICES"] = str(gpu_index) process = subprocess.Popen( cmd, stdout=open(log_file, "w"), env={**env, **self.tf_env}, stderr=subprocess.STDOUT) return process, log_file def _compute_fad_score(self, gpu_index: tp.Optional[int] = None): cmd = [ self.python_path, "-m", "frechet_audio_distance.compute_fad", "--test_stats", f"{str(self.stats_tests_dir)}", "--background_stats", f"{str(self.stats_background_dir)}", ] logger.info(f"Launching frechet_audio_distance compute fad method: {' '.join(cmd)}") env = os.environ if gpu_index is not None: env["CUDA_VISIBLE_DEVICES"] = str(gpu_index) result = subprocess.run(cmd, env={**env, **self.tf_env}, capture_output=True) if result.returncode: logger.error( "Error with FAD computation from stats: \n %s \n %s", result.stdout.decode(), result.stderr.decode() ) raise RuntimeError("Error while executing FAD computation from stats") try: # result is "FAD: (d+).(d+)" hence we remove the prefix with (d+) being one digit or more fad_score = float(result.stdout[4:]) return fad_score except Exception as e: raise RuntimeError(f"Error parsing FAD score from command stdout: {e}") def _log_process_result(self, returncode: int, log_file: tp.Union[Path, str], 
is_background: bool) -> None: beams_name = self._get_samples_name(is_background) if returncode: with open(log_file, "r") as f: error_log = f.read() logger.error(error_log) os._exit(1) else: logger.info(f"Successfully computed embedding beams on {beams_name} samples.") def _parallel_create_embedding_beams(self, num_of_gpus: int): assert num_of_gpus > 0 logger.info("Creating embeddings beams in a parallel manner on different GPUs") tests_beams_process, tests_beams_log_file = self._create_embedding_beams(is_background=False, gpu_index=0) bg_beams_process, bg_beams_log_file = self._create_embedding_beams(is_background=True, gpu_index=1) tests_beams_code = tests_beams_process.wait() bg_beams_code = bg_beams_process.wait() self._log_process_result(tests_beams_code, tests_beams_log_file, is_background=False) self._log_process_result(bg_beams_code, bg_beams_log_file, is_background=True) def _sequential_create_embedding_beams(self): logger.info("Creating embeddings beams in a sequential manner") tests_beams_process, tests_beams_log_file = self._create_embedding_beams(is_background=False) tests_beams_code = tests_beams_process.wait() self._log_process_result(tests_beams_code, tests_beams_log_file, is_background=False) bg_beams_process, bg_beams_log_file = self._create_embedding_beams(is_background=True) bg_beams_code = bg_beams_process.wait() self._log_process_result(bg_beams_code, bg_beams_log_file, is_background=True) @flashy.distrib.rank_zero_only def _local_compute_frechet_audio_distance(self): """Compute Frechet Audio Distance score calling TensorFlow API.""" num_of_gpus = torch.cuda.device_count() if torch.cuda.is_available() else 0 if num_of_gpus > 1: self._parallel_create_embedding_beams(num_of_gpus) else: self._sequential_create_embedding_beams() fad_score = self._compute_fad_score(gpu_index=0) return fad_score def compute(self) -> float: """Compute metrics.""" assert self.total_files.item() > 0, "No files dumped for FAD computation!" 
class _patch_passt_stft:
    """Context manager that patches torch.stft for PaSST.

    While active, `torch.stft` is replaced by a partial that passes `return_complex=False`;
    the function captured at construction time is put back on exit.
    """
    def __init__(self):
        self.old_stft = torch.stft

    def __enter__(self):
        # return_complex is a mandatory parameter in latest torch versions
        # torch is throwing RuntimeErrors when not set
        patched_stft = partial(torch.stft, return_complex=False)
        torch.stft = patched_stft

    def __exit__(self, *exc):
        torch.stft = self.old_stft


def kl_divergence(pred_probs: torch.Tensor, target_probs: torch.Tensor, epsilon: float = 1e-6) -> torch.Tensor:
    """Computes the elementwise KL-Divergence loss between probability distributions
    from generated samples and target samples.

    Args:
        pred_probs (torch.Tensor): Probabilities for each label obtained
            from a classifier on generated audio. Expected shape is [B, num_classes].
        target_probs (torch.Tensor): Probabilities for each label obtained
            from a classifier on target audio. Expected shape is [B, num_classes].
        epsilon (float): Epsilon value.
    Returns:
        kld (torch.Tensor): KLD loss between each generated sample and target pair.
    """
    # `kl_div` expects its first argument in log-space; epsilon keeps the log finite at zero.
    log_preds = torch.log(pred_probs + epsilon)
    per_class = torch.nn.functional.kl_div(log_preds, target_probs, reduction="none")
    # Sum over the class dimension to obtain one value per pair.
    return per_class.sum(-1)
def __init__(self):
    """Register the running sums and the sample counter as metric states."""
    super().__init__()
    for sum_name in ("kld_pq_sum", "kld_qp_sum", "kld_all_sum"):
        self.add_state(sum_name, default=torch.tensor(0.), dist_reduce_fx="sum")
    # The counter is registered with an integer default, unlike the float sums above.
    self.add_state("weight", default=torch.tensor(0), dist_reduce_fx="sum")


def _get_label_distribution(self, x: torch.Tensor, sizes: torch.Tensor,
                            sample_rates: torch.Tensor) -> tp.Optional[torch.Tensor]:
    """Get model output given provided input tensor.

    Subclasses must override this method; the base implementation always raises.

    Args:
        x (torch.Tensor): Input audio tensor of shape [B, C, T].
        sizes (torch.Tensor): Actual audio sample length, of shape [B].
        sample_rates (torch.Tensor): Actual audio sample rate, of shape [B].
    Returns:
        probs (torch.Tensor): Probabilities over labels, of shape [B, num_classes].
    """
    raise NotImplementedError("implement method to extract label distributions from the model.")


def update(self, preds: torch.Tensor, targets: torch.Tensor,
           sizes: torch.Tensor, sample_rates: torch.Tensor) -> None:
    """Calculates running KL-Divergence loss between batches of audio
    preds (generated) and target (ground-truth)

    Args:
        preds (torch.Tensor): Audio samples to evaluate, of shape [B, C, T].
        targets (torch.Tensor): Target samples to compare against, of shape [B, C, T].
        sizes (torch.Tensor): Actual audio sample length, of shape [B].
        sample_rates (torch.Tensor): Actual audio sample rate, of shape [B].
    """
    assert preds.shape == targets.shape
    assert preds.size(0) > 0, "Cannot update the loss with empty tensors"
    preds_probs = self._get_label_distribution(preds, sizes, sample_rates)
    targets_probs = self._get_label_distribution(targets, sizes, sample_rates)
    # Nothing to accumulate when either side produced no label distribution.
    if preds_probs is None or targets_probs is None:
        return
    assert preds_probs.shape == targets_probs.shape
    pq_scores = kl_divergence(preds_probs, targets_probs)
    assert not torch.isnan(pq_scores).any(), "kld_scores contains NaN value(s)!"
    self.kld_pq_sum += torch.sum(pq_scores)
    # Same divergence with the two arguments swapped.
    qp_scores = kl_divergence(targets_probs, preds_probs)
    self.kld_qp_sum += torch.sum(qp_scores)
    self.weight += torch.tensor(pq_scores.size(0))


def compute(self) -> dict:
    """Computes KL-Divergence across all evaluated pred/target pairs.

    Returns:
        dict: Averages under the keys 'kld' and 'kld_pq' (same value), 'kld_qp', and their sum 'kld_both'.
    """
    weight: float = float(self.weight.item())  # type: ignore
    assert weight > 0, "Unable to compute with total number of comparisons <= 0"
    logger.info(f"Computing KL divergence on a total of {weight} samples")
    kld_pq = self.kld_pq_sum.item() / weight  # type: ignore
    kld_qp = self.kld_qp_sum.item() / weight  # type: ignore
    scores = {'kld': kld_pq, 'kld_pq': kld_pq, 'kld_qp': kld_qp}
    scores['kld_both'] = kld_pq + kld_qp
    return scores
""" def __init__(self, pretrained_length: tp.Optional[float] = None): super().__init__() self._initialize_model(pretrained_length) def _initialize_model(self, pretrained_length: tp.Optional[float] = None): """Initialize underlying PaSST audio classifier.""" model, sr, max_frames, min_frames = self._load_base_model(pretrained_length) self.min_input_frames = min_frames self.max_input_frames = max_frames self.model_sample_rate = sr self.model = model self.model.eval() self.model.to(self.device) def _load_base_model(self, pretrained_length: tp.Optional[float]): """Load pretrained model from PaSST.""" try: if pretrained_length == 30: from hear21passt.base30sec import get_basic_model # type: ignore max_duration = 30 elif pretrained_length == 20: from hear21passt.base20sec import get_basic_model # type: ignore max_duration = 20 else: from hear21passt.base import get_basic_model # type: ignore # Original PASST was trained on AudioSet with 10s-long audio samples max_duration = 10 min_duration = 0.15 min_duration = 0.15 except ModuleNotFoundError: raise ModuleNotFoundError( "Please install hear21passt to compute KL divergence: ", "pip install 'git+https://github.com/kkoutini/passt_hear21@0.0.19#egg=hear21passt'" ) model_sample_rate = 32_000 max_input_frames = int(max_duration * model_sample_rate) min_input_frames = int(min_duration * model_sample_rate) with open(os.devnull, 'w') as f, contextlib.redirect_stdout(f): model = get_basic_model(mode='logits') return model, model_sample_rate, max_input_frames, min_input_frames def _process_audio(self, wav: torch.Tensor, sample_rate: int, wav_len: int) -> tp.List[torch.Tensor]: """Process audio to feed to the pretrained model.""" wav = wav.unsqueeze(0) wav = wav[..., :wav_len] wav = convert_audio(wav, from_rate=sample_rate, to_rate=self.model_sample_rate, to_channels=1) wav = wav.squeeze(0) # we don't pad but return a list of audio segments as this otherwise affects the KLD computation segments = torch.split(wav, 
self.max_input_frames, dim=-1) valid_segments = [] for s in segments: # ignoring too small segments that are breaking the model inference if s.size(-1) > self.min_input_frames: valid_segments.append(s) return [s[None] for s in valid_segments] def _get_model_preds(self, wav: torch.Tensor) -> torch.Tensor: """Run the pretrained model and get the predictions.""" assert wav.dim() == 3, f"Unexpected number of dims for preprocessed wav: {wav.shape}" wav = wav.mean(dim=1) # PaSST is printing a lot of garbage that we are not interested in with open(os.devnull, "w") as f, contextlib.redirect_stdout(f): with torch.no_grad(), _patch_passt_stft(): logits = self.model(wav.to(self.device)) probs = torch.softmax(logits, dim=-1) return probs def _get_label_distribution(self, x: torch.Tensor, sizes: torch.Tensor, sample_rates: torch.Tensor) -> tp.Optional[torch.Tensor]: """Get model output given provided input tensor. Args: x (torch.Tensor): Input audio tensor of shape [B, C, T]. sizes (torch.Tensor): Actual audio sample length, of shape [B]. sample_rates (torch.Tensor): Actual audio sample rate, of shape [B]. Returns: probs (torch.Tensor, optional): Probabilities over labels, of shape [B, num_classes]. """ all_probs: tp.List[torch.Tensor] = [] for i, wav in enumerate(x): sample_rate = int(sample_rates[i].item()) wav_len = int(sizes[i].item()) wav_segments = self._process_audio(wav, sample_rate, wav_len) for segment in wav_segments: probs = self._get_model_preds(segment).mean(dim=0) all_probs.append(probs) if len(all_probs) > 0: return torch.stack(all_probs, dim=0) else: return None ================================================ FILE: audiocraft/metrics/rvm.py ================================================ # Copyright (c) Meta Platforms, Inc. and affiliates. # All rights reserved. # # This source code is licensed under the license found in the # LICENSE file in the root directory of this source tree. 
def db_to_scale(volume: tp.Union[float, torch.Tensor]):
    """Convert a volume given in decibels into a linear amplitude scale.

    Args:
        volume (float or torch.Tensor): Volume in dB.
    Returns:
        The linear scale `10 ** (volume / 20)`, with the same container type as `volume`.
    """
    exponent = volume / 20
    return 10 ** exponent
The metric can be aggregated over a given frequency band in order have different insights for different region of the spectrum. `num_aggregated_bands` controls the number of bands. ..Warning:: While this function is optimized for interpretability, nothing was done to ensure it is numerically stable when computing its gradient. We thus advise against using it as a training loss. Args: sample_rate (int): Sample rate of the input audio. n_mels (int): Number of mel bands to use. n_fft (int): Number of frequency bins for the STFT. hop_length (int): Hop length of the STFT and the mel-spectrogram. min_relative_volume (float): The error `z_ref - z_est` volume is given relative to the volume of `z_ref`. If error is smaller than -25 dB of `z_ref`, then it is clamped. max_relative_volume (float): Same as `min_relative_volume` but clamping if the error is larger than that. max_initial_gain (float): When rescaling the audio at the very beginning, we will limit the gain to that amount, to avoid rescaling near silence. Given in dB. min_activity_volume (float): When computing the reference level from `z_ref`, will clamp low volume bins to that amount. This is effectively our "zero" level for the reference mel-spectrogram, and anything below that will be considered equally. num_aggregated_bands (int): Number of bands to keep when computing the average RVM value. For instance, a value of 3 would give 3 scores, roughly for low, mid and high freqs. 
""" def __init__(self, sample_rate: int = 24000, n_mels: int = 80, n_fft: int = 512, hop_length: int = 128, min_relative_volume: float = -25, max_relative_volume: float = 25, max_initial_gain: float = 25, min_activity_volume: float = -25, num_aggregated_bands: int = 4) -> None: super().__init__() self.melspec = torchaudio.transforms.MelSpectrogram( n_mels=n_mels, n_fft=n_fft, hop_length=hop_length, normalized=True, sample_rate=sample_rate, power=2) self.min_relative_volume = min_relative_volume self.max_relative_volume = max_relative_volume self.max_initial_gain = max_initial_gain self.min_activity_volume = min_activity_volume self.num_aggregated_bands = num_aggregated_bands def forward(self, estimate: torch.Tensor, ground_truth: torch.Tensor) -> tp.Dict[str, torch.Tensor]: """Compute RVM metric between estimate and reference samples. Args: estimate (torch.Tensor): Estimate sample. ground_truth (torch.Tensor): Reference sample. Returns: dict[str, torch.Tensor]: Metrics with keys `rvm` for the overall average, and `rvm_{k}` for the RVM over the k-th band (k=0..num_aggregated_bands - 1). 
""" min_scale = db_to_scale(-self.max_initial_gain) std = ground_truth.pow(2).mean().sqrt().clamp(min=min_scale) z_gt = self.melspec(ground_truth / std).sqrt() z_est = self.melspec(estimate / std).sqrt() delta = z_gt - z_est ref_db = scale_to_db(z_gt, self.min_activity_volume) delta_db = scale_to_db(delta.abs(), min_volume=-120) relative_db = (delta_db - ref_db).clamp(self.min_relative_volume, self.max_relative_volume) dims = list(range(relative_db.dim())) dims.remove(dims[-2]) losses_per_band = relative_db.mean(dim=dims) aggregated = [chunk.mean() for chunk in losses_per_band.chunk(self.num_aggregated_bands, dim=0)] metrics = {f'rvm_{index}': value for index, value in enumerate(aggregated)} metrics['rvm'] = losses_per_band.mean() return metrics ================================================ FILE: audiocraft/metrics/visqol.py ================================================ # Copyright (c) Meta Platforms, Inc. and affiliates. # All rights reserved. # # This source code is licensed under the license found in the # LICENSE file in the root directory of this source tree. import csv import json import logging from pathlib import Path import tempfile import typing as tp import subprocess import shutil import torch import torchaudio logger = logging.getLogger(__name__) class ViSQOL: """ViSQOL wrapper to run ViSQOL from Python using a pre-installed binary. To learn more about ViSQOL and how to build ViSQOL binary using bazel, please refer to the instructions available in the open source repository: https://github.com/google/visqol ViSQOL is capable of running in two modes: Audio Mode: When running in audio mode, input signals must have a 48kHz sample rate. Input should be resampled to 48kHz. Input signals can be multi-channel, but they will be down-mixed to mono for performing the comparison. Audio mode uses support vector regression, with the maximum range at ~4.75. Speech Mode: When running in speech mode, ViSQOL uses a wideband model. 
def _get_target_sr(self, mode: str) -> int:
    """Return the sample rate registered for the given ViSQOL mode.

    Args:
        mode (str): ViSQOL computation mode, a key of `ViSQOL.SAMPLE_RATES_MODES`.
    Returns:
        int: Target sample rate for that mode.
    Raises:
        ValueError: If `mode` is not a supported mode.
    """
    supported = ViSQOL.SAMPLE_RATES_MODES
    if mode in supported:
        return supported[mode]
    raise ValueError(
        f"Unsupported mode! Allowed are: {', '.join(ViSQOL.SAMPLE_RATES_MODES.keys())}"
    )
def _run_visqol(
    self,
    input_csv_path: tp.Union[Path, str],
    results_csv_path: tp.Union[Path, str],
    debug_csv_path: tp.Optional[tp.Union[Path, str]],
):
    """Run the ViSQOL binary in batch mode.

    Args:
        input_csv_path (Path or str): CSV file listing the reference/degraded pairs.
        results_csv_path (Path or str): Where the binary is asked to write its results.
        debug_csv_path (Path or str, optional): Where to write debug output. When None,
            the `--output_debug` flag is not passed at all.
    Raises:
        RuntimeError: If the binary exits with a non-zero return code.
    """
    cmd = [
        f'{self.visqol_bin}/bazel-bin/visqol',
        '--batch_input_csv', str(input_csv_path),
        '--results_csv', str(results_csv_path),
    ]
    # Test for None BEFORE stringifying: str(None) is the truthy string "None", which
    # would make the binary receive `--output_debug None` on every run.
    if debug_csv_path is not None:
        cmd += ['--output_debug', str(debug_csv_path)]
    if self.visqol_mode == "speech":
        cmd += ['--use_speech_mode']
    cmd += ['--similarity_to_quality_model', f'{self.visqol_model}']
    result = subprocess.run(cmd, capture_output=True)
    if result.returncode:
        logger.error("Error with visqol: \n %s \n %s", result.stdout.decode(), result.stderr.decode())
        raise RuntimeError("Error while executing visqol")
sr (int): Sample rate of the two audio signals. pad_with_silence (bool): Whether to pad the file with silences as recommended in visqol guidelines (see: https://github.com/google/visqol#general-guidelines-for-input). Returns: float: The ViSQOL score or mean score for the batch. """ logger.debug(f"Calculating visqol with mode={self.visqol_mode} on {len(ref_sig)} samples") tmp_dir, input_csv, results_csv, debug_json = self._prepare_files( ref_sig, deg_sig, sr, self.target_sr, pad_with_silence ) try: if input_csv and results_csv: self._run_visqol( input_csv, results_csv, debug_json if self.debug else None, ) mosqol = self._collect_moslqo_score(results_csv) return mosqol else: raise RuntimeError("Something unexpected happened when running VISQOL!") except Exception as e: logger.error("Exception occurred when running ViSQOL: %s", e) finally: self._flush_files(tmp_dir) ================================================ FILE: audiocraft/models/__init__.py ================================================ # Copyright (c) Meta Platforms, Inc. and affiliates. # All rights reserved. # # This source code is licensed under the license found in the # LICENSE file in the root directory of this source tree. """ Models for EnCodec, AudioGen, MusicGen, as well as the generic LMModel. """ # flake8: noqa from . import builders, loaders from .encodec import ( CompressionModel, EncodecModel, DAC, HFEncodecModel, HFEncodecCompressionModel) from .audiogen import AudioGen from .lm import LMModel from .multibanddiffusion import MultiBandDiffusion from .musicgen import MusicGen from .unet import DiffusionUnet ================================================ FILE: audiocraft/models/audiogen.py ================================================ # Copyright (c) Meta Platforms, Inc. and affiliates. # All rights reserved. # # This source code is licensed under the license found in the # LICENSE file in the root directory of this source tree. """ Main model for using AudioGen. 
def __init__(self, name: str, compression_model: CompressionModel, lm: LMModel,
             max_duration: tp.Optional[float] = None):
    """Assemble an AudioGen model from its compression model and language model.

    Args:
        name (str): Name of the model.
        compression_model (CompressionModel): Compression model mapping audio to
            discrete representations.
        lm (LMModel): Language model over the discrete representations.
        max_duration (float, optional): Maximum duration the model can produce. When
            omitted it is read from `lm.cfg.dataset.segment_duration`.
    Raises:
        ValueError: If `max_duration` is omitted and `lm` has no `cfg` attribute.
    """
    self.name = name
    self.compression_model = compression_model
    self.lm = lm

    if max_duration is None:
        if not hasattr(lm, 'cfg'):
            raise ValueError("You must provide max_duration when building directly AudioGen")
        max_duration = lm.cfg.dataset.segment_duration  # type: ignore
    assert max_duration is not None
    self.max_duration: float = max_duration

    # The device is taken from the first parameter of the language model.
    self.device = next(iter(lm.parameters())).device
    self.generation_params: dict = {}
    self.set_generation_params(duration=5)  # 5 seconds by default
    self._progress_callback: tp.Optional[tp.Callable[[int, int], None]] = None

    # Autocast is enabled (float16) on every device type except CPU.
    if self.device.type != 'cpu':
        self.autocast = TorchAutocast(
            enabled=True, device_type=self.device.type, dtype=torch.float16)
    else:
        self.autocast = TorchAutocast(enabled=False)
@staticmethod
def get_pretrained(name: str = 'facebook/audiogen-medium', device=None):
    """Return a pretrained AudioGen model; a single model is provided for now:
    - facebook/audiogen-medium (1.5B), text to sound,
      # see: https://huggingface.co/facebook/audiogen-medium

    Args:
        name (str): Name of the model to load, or 'debug' for the small models
            used only by unit tests.
        device: Device to load the models on. Defaults to 'cuda' when at least one
            CUDA device is visible, 'cpu' otherwise.
    """
    if device is None:
        device = 'cuda' if torch.cuda.device_count() else 'cpu'

    if name != 'debug':
        compression_model = load_compression_model(name, device=device)
        lm = load_lm_model(name, device=device)
        assert 'self_wav' not in lm.condition_provider.conditioners, \
            "AudioGen do not support waveform conditioning for now"
        return AudioGen(name, compression_model, lm)

    # used only for unit tests
    debug_compression = get_debug_compression_model(device, sample_rate=16000)
    debug_lm = get_debug_lm_model(device)
    return AudioGen(name, debug_compression, debug_lm, max_duration=10)
def generate(self, descriptions: tp.List[str], progress: bool = False) -> torch.Tensor:
    """Generate samples conditioned on text only (no audio prompt).

    Args:
        descriptions (list of str): A list of strings used as text conditioning.
        progress (bool, optional): Flag to display progress of the generation process.
            Defaults to False.
    Returns:
        torch.Tensor: Whatever `_generate_tokens` returns for these conditions.
    """
    conditions, no_prompt = self._prepare_tokens_and_attributes(descriptions, None)
    # Without an audio prompt there must be no prompt tokens.
    assert no_prompt is None
    return self._generate_tokens(conditions, no_prompt, progress)
""" if prompt.dim() == 2: prompt = prompt[None] if prompt.dim() != 3: raise ValueError("prompt should have 3 dimensions: [B, C, T] (C = 1).") prompt = convert_audio(prompt, prompt_sample_rate, self.sample_rate, self.audio_channels) if descriptions is None: descriptions = [None] * len(prompt) attributes, prompt_tokens = self._prepare_tokens_and_attributes(descriptions, prompt) assert prompt_tokens is not None return self._generate_tokens(attributes, prompt_tokens, progress) @torch.no_grad() def _prepare_tokens_and_attributes( self, descriptions: tp.Sequence[tp.Optional[str]], prompt: tp.Optional[torch.Tensor], ) -> tp.Tuple[tp.List[ConditioningAttributes], tp.Optional[torch.Tensor]]: """Prepare model inputs. Args: descriptions (list of str): A list of strings used as text conditioning. prompt (torch.Tensor): A batch of waveforms used for continuation. """ attributes = [ ConditioningAttributes(text={'description': description}) for description in descriptions] if prompt is not None: if descriptions is not None: assert len(descriptions) == len(prompt), "Prompt and nb. descriptions doesn't match" prompt = prompt.to(self.device) prompt_tokens, scale = self.compression_model.encode(prompt) assert scale is None else: prompt_tokens = None return attributes, prompt_tokens def _generate_tokens(self, attributes: tp.List[ConditioningAttributes], prompt_tokens: tp.Optional[torch.Tensor], progress: bool = False) -> torch.Tensor: """Generate discrete audio tokens given audio prompt and/or conditions. Args: attributes (list of ConditioningAttributes): Conditions used for generation (here text). prompt_tokens (torch.Tensor, optional): Audio prompt used for continuation. progress (bool, optional): Flag to display progress of the generation process. Defaults to False. Returns: torch.Tensor: Generated audio, of shape [B, C, T], T is defined by the generation params. 
""" i = 0 prompt_list = attributes[0].text['description'] total_gen_len = int(self.duration * self.frame_rate) max_prompt_len = int(min(self.duration, self.max_duration) * self.frame_rate) current_gen_offset: int = 0 def _progress_callback(generated_tokens: int, tokens_to_generate: int): generated_tokens += current_gen_offset if self._progress_callback is not None: # Note that total_gen_len might be quite wrong depending on the # codebook pattern used, but with delay it is almost accurate. self._progress_callback(generated_tokens, total_gen_len) else: print(f'{generated_tokens: 6d} / {total_gen_len: 6d}', end='\r') if prompt_tokens is not None: assert max_prompt_len >= prompt_tokens.shape[-1], \ "Prompt is longer than audio to generate" callback = None if progress: callback = _progress_callback if self.duration <= self.max_duration: # generate by sampling from LM, simple case. with self.autocast: attributes[0].text['description'] = prompt_list[0] gen_tokens = self.lm.generate( prompt_tokens, attributes, callback=callback, max_gen_len=total_gen_len, **self.generation_params) else: all_tokens = [] if prompt_tokens is None: prompt_length = 0 else: all_tokens.append(prompt_tokens) prompt_length = prompt_tokens.shape[-1] stride_tokens = int(self.frame_rate * self.extend_stride) while current_gen_offset + prompt_length < total_gen_len: time_offset = current_gen_offset / self.frame_rate chunk_duration = min(self.duration - time_offset, self.max_duration) max_gen_len = int(chunk_duration * self.frame_rate) with self.autocast: if i >= len(prompt_list): i = len(prompt_list) - 1 attributes[0].text['description'] = prompt_list[i] gen_tokens = self.lm.generate( prompt_tokens, attributes, callback=callback, max_gen_len=max_gen_len, **self.generation_params) i = i + 1 if prompt_tokens is None: all_tokens.append(gen_tokens) else: all_tokens.append(gen_tokens[:, :, prompt_tokens.shape[-1]:]) prompt_tokens = gen_tokens[:, :, stride_tokens:] prompt_length = prompt_tokens.shape[-1] 
def to(self, device: str):
    """Move the compression model and the language model to `device`.

    `self.device` is refreshed afterwards so that code reading it (for example when
    moving an audio prompt before encoding) targets the device the models now live on.
    The autocast context created at construction time is left untouched.

    Args:
        device (str): Target device, as accepted by `torch.nn.Module.to`.
    Returns:
        The model itself, to allow chaining.
    """
    self.compression_model.to(device)
    self.lm.to(device)
    # Same source of truth as the constructor: the first LM parameter. Fall back to
    # the requested device if the LM exposes no parameter at all.
    first_param = next(iter(self.lm.parameters()), None)
    self.device = first_param.device if first_param is not None else torch.device(device)
    return self
def get_quantizer(quantizer: str, cfg: omegaconf.DictConfig, dimension: int) -> qt.BaseQuantizer:
    """Instantiate the quantizer named `quantizer` from its config section.

    Args:
        quantizer (str): Quantizer name, either 'no_quant' or 'rvq'.
        cfg (omegaconf.DictConfig): Config exposing an attribute named after the quantizer,
            whose content is used as constructor keyword arguments.
        dimension (int): Forwarded as the `dimension` keyword to every quantizer except 'no_quant'.
    Returns:
        qt.BaseQuantizer: The instantiated quantizer.
    Raises:
        KeyError: If `quantizer` is not one of the supported names.
    """
    klass = {
        'no_quant': qt.DummyQuantizer,
        'rvq': qt.ResidualVectorQuantizer
    }[quantizer]
    kwargs = dict_from_config(getattr(cfg, quantizer))
    if quantizer != 'no_quant':
        kwargs['dimension'] = dimension
    return klass(**kwargs)


def get_encodec_autoencoder(encoder_name: str, cfg: omegaconf.DictConfig):
    """Instantiate the encoder/decoder pair of an Encodec model.

    Args:
        encoder_name (str): Name of the autoencoder architecture. Only 'seanet' is supported.
        cfg (omegaconf.DictConfig): Config exposing a `seanet` section. Its `encoder` and
            `decoder` sub-sections override the shared options for the respective network.
    Returns:
        tuple: The `(encoder, decoder)` modules.
    Raises:
        KeyError: If `encoder_name` is not a supported autoencoder.
    """
    if encoder_name != 'seanet':
        # Report the value that was actually rejected. The message used to print
        # `cfg.compression_model`, which is not the name being checked here.
        raise KeyError(f"Unexpected autoencoder {encoder_name}")
    kwargs = dict_from_config(getattr(cfg, 'seanet'))
    encoder_override_kwargs = kwargs.pop('encoder')
    decoder_override_kwargs = kwargs.pop('decoder')
    # Shared options first, so that the per-network overrides win on conflicts.
    encoder_kwargs = {**kwargs, **encoder_override_kwargs}
    decoder_kwargs = {**kwargs, **decoder_override_kwargs}
    encoder = audiocraft.modules.SEANetEncoder(**encoder_kwargs)
    decoder = audiocraft.modules.SEANetDecoder(**decoder_kwargs)
    return encoder, decoder


def get_compression_model(cfg: omegaconf.DictConfig) -> CompressionModel:
    """Instantiate a compression model.

    Args:
        cfg (omegaconf.DictConfig): Config whose `compression_model` entry selects the model
            (only 'encodec' is supported) and whose `encodec` section holds its options.
    Returns:
        CompressionModel: The model, moved to `cfg.device`.
    Raises:
        KeyError: If `cfg.compression_model` is not a supported model name.
    """
    if cfg.compression_model != 'encodec':
        raise KeyError(f"Unexpected compression model {cfg.compression_model}")
    kwargs = dict_from_config(getattr(cfg, 'encodec'))
    encoder_name = kwargs.pop('autoencoder')
    quantizer_name = kwargs.pop('quantizer')
    encoder, decoder = get_encodec_autoencoder(encoder_name, cfg)
    quantizer = get_quantizer(quantizer_name, cfg, encoder.dimension)
    # `sample_rate` stays in kwargs on purpose: it is also forwarded to EncodecModel below.
    frame_rate = kwargs['sample_rate'] // encoder.hop_length
    renormalize = kwargs.pop('renormalize', False)
    # deprecated params
    kwargs.pop('renorm', None)
    return EncodecModel(encoder, decoder, quantizer,
                        frame_rate=frame_rate, renormalize=renormalize, **kwargs).to(cfg.device)
def get_conditioner_provider(output_dim: int, cfg: omegaconf.DictConfig) -> ConditioningProvider:
    """Build the conditioning provider described by the `conditioners` config section.

    Args:
        output_dim (int): Output dimension handed to every conditioner.
        cfg (omegaconf.DictConfig): Full config. `device`, `dataset.segment_duration`
            and `conditioners` are read from it.
    Returns:
        ConditioningProvider: Provider wrapping one conditioner per configured entry.
    Raises:
        ValueError: If an entry names a conditioning model that is not handled here.
    """
    device = cfg.device
    duration = cfg.dataset.segment_duration
    conditioners_cfg = getattr(cfg, 'conditioners')
    if conditioners_cfg is None:
        dict_cfg = {}
    else:
        dict_cfg = dict_from_config(conditioners_cfg)
    conditioners: tp.Dict[str, BaseConditioner] = {}
    provider_kwargs = dict_cfg.pop('args', {})
    # These two options are stripped before the remaining args reach the provider.
    for dropped_key in ('merge_text_conditions_p', 'drop_desc_p'):
        provider_kwargs.pop(dropped_key, None)
    for cond, cond_cfg in dict_cfg.items():
        model_type = cond_cfg['model']
        # The constructor arguments live under a key named after the model type.
        model_args = cond_cfg[model_type]
        if model_type == 't5':
            built = T5Conditioner(output_dim=output_dim, device=device, **model_args)
        elif model_type == 'lut':
            built = LUTConditioner(output_dim=output_dim, **model_args)
        elif model_type == 'chroma_stem':
            built = ChromaStemConditioner(
                output_dim=output_dim,
                duration=duration,
                device=device,
                **model_args
            )
        elif model_type == 'clap':
            built = CLAPEmbeddingConditioner(
                output_dim=output_dim,
                device=device,
                **model_args
            )
        else:
            raise ValueError(f"Unrecognized conditioning model: {model_type}")
        conditioners[str(cond)] = built
    return ConditioningProvider(conditioners, device=device, **provider_kwargs)


def get_condition_fuser(cfg: omegaconf.DictConfig) -> ConditionFuser:
    """Build the condition fuser described by the `fuser` config section.

    The four fusing-method entries are gathered into the `fuse2cond` mapping;
    every other entry of the section is forwarded as a keyword argument.
    """
    fuser_cfg = getattr(cfg, 'fuser')
    fuser_methods = ['sum', 'cross', 'prepend', 'input_interpolate']
    fuse2cond = {}
    for method in fuser_methods:
        fuse2cond[method] = fuser_cfg[method]
    extra_kwargs = {}
    for key, value in fuser_cfg.items():
        if key in fuser_methods:
            continue
        extra_kwargs[key] = value
    return ConditionFuser(fuse2cond=fuse2cond, **extra_kwargs)


def get_codebooks_pattern_provider(n_q: int, cfg: omegaconf.DictConfig) -> CodebooksPatternProvider:
    """Build the codebooks pattern provider selected by `cfg.modeling`.

    Args:
        n_q (int): Number of codebooks, passed positionally to the provider.
        cfg (omegaconf.DictConfig): Pattern config. When it has an attribute named after
            the selected pattern, that section supplies the constructor keyword arguments.
    Returns:
        CodebooksPatternProvider: The instantiated provider.
    Raises:
        KeyError: If `cfg.modeling` is not a known pattern name.
    """
    name = cfg.modeling
    if hasattr(cfg, name):
        kwargs = dict_from_config(cfg.get(name))
    else:
        kwargs = {}
    pattern_providers = {
        'parallel': ParallelPatternProvider,
        'delay': DelayedPatternProvider,
        'unroll': UnrolledPatternProvider,
        'valle': VALLEPattern,
        'musiclm': MusicLMPattern,
    }
    return pattern_providers[name](n_q, **kwargs)
def get_diffusion_model(cfg: omegaconf.DictConfig):
    """Instantiate the diffusion U-Net from `cfg.channels`, `cfg.schedule.num_steps`
    and the `cfg.diffusion_unet` options.
    """
    # TODO Find a way to infer the channels from dset
    return DiffusionUnet(
        chin=cfg.channels,
        num_steps=cfg.schedule.num_steps,
        **cfg.diffusion_unet)


def get_processor(cfg, sample_rate: int = 24000):
    """Return the sample processor selected by `cfg`.

    A plain `SampleProcessor` is returned unless `cfg.use` is truthy and `cfg.name`
    is "multi_band_processor". In that case a `MultiBandProcessor` is built from the
    remaining config entries (everything except `use` and `name`).

    Args:
        cfg: Processor config, convertible with `dict(cfg)`.
        sample_rate (int): Sample rate forwarded to the multi-band processor.
    """
    processor = SampleProcessor()
    if not cfg.use:
        return processor
    options = dict(cfg)
    # Both keys are selectors, not constructor arguments.
    del options['use']
    del options['name']
    if cfg.name == "multi_band_processor":
        processor = MultiBandProcessor(sample_rate=sample_rate, **options)
    return processor


def get_debug_lm_model(device='cpu'):
    """Instantiate a debug LM to be used for unit tests."""
    width = 16
    pattern_provider = DelayedPatternProvider(n_q=4)
    description_conditioner = LUTConditioner(
        n_bins=128, dim=width, output_dim=width, tokenizer="whitespace")
    conditioning = ConditioningProvider({'description': description_conditioner})
    # Only cross-attention is used to fuse the single 'description' condition.
    fuse2cond = {'cross': ['description'], 'prepend': [],
                 'sum': [], 'input_interpolate': []}
    debug_lm = LMModel(
        pattern_provider, conditioning, ConditionFuser(fuse2cond),
        n_q=4, card=400, dim=width, num_heads=4, custom=True, num_layers=2,
        cross_attention=True, causal=True)
    debug_lm = debug_lm.to(device)
    return debug_lm.eval()


def get_wrapped_compression_model(
        compression_model: CompressionModel,
        cfg: omegaconf.DictConfig) -> CompressionModel:
    """Return `compression_model` unchanged. `cfg` is currently unused."""
    # more to come.
    return compression_model
# # This source code is licensed under the license found in the # LICENSE file in the root directory of this source tree. """Compression models or wrapper around existing models. Also defines the main interface that a model must follow to be usable as an audio tokenizer. """ from abc import ABC, abstractmethod import logging import math from pathlib import Path import typing as tp import numpy as np import torch from torch import nn from transformers import EncodecModel as HFEncodecModel from .. import quantization as qt logger = logging.getLogger() class CompressionModel(ABC, nn.Module): """Base API for all compression model that aim at being used as audio tokenizers with a language model. """ @abstractmethod def forward(self, x: torch.Tensor) -> qt.QuantizedResult: ... @abstractmethod def encode(self, x: torch.Tensor) -> tp.Tuple[torch.Tensor, tp.Optional[torch.Tensor]]: """See `EncodecModel.encode`.""" ... @abstractmethod def decode(self, codes: torch.Tensor, scale: tp.Optional[torch.Tensor] = None): """See `EncodecModel.decode`.""" ... @abstractmethod def decode_latent(self, codes: torch.Tensor): """Decode from the discrete codes to continuous latent space.""" ... @property @abstractmethod def channels(self) -> int: ... @property @abstractmethod def frame_rate(self) -> float: ... @property @abstractmethod def sample_rate(self) -> int: ... @property @abstractmethod def cardinality(self) -> int: ... @property @abstractmethod def num_codebooks(self) -> int: ... @property @abstractmethod def total_codebooks(self) -> int: ... @abstractmethod def set_num_codebooks(self, n: int): """Set the active number of codebooks used by the quantizer.""" ... @staticmethod def get_pretrained( name: str, device: tp.Union[torch.device, str] = 'cpu' ) -> 'CompressionModel': """Instantiate a CompressionModel from a given pretrained model. Args: name (Path or str): name of the pretrained model. See after. device (torch.device or str): Device on which the model is loaded. 
class EncodecModel(CompressionModel):
    """Encodec model operating on the raw waveform.

    Args:
        encoder (nn.Module): Encoder network.
        decoder (nn.Module): Decoder network.
        quantizer (qt.BaseQuantizer): Quantizer network.
        frame_rate (int): Frame rate for the latent representation.
        sample_rate (int): Audio sample rate.
        channels (int): Number of audio channels.
        causal (bool): Whether to use a causal version of the model.
        renormalize (bool): Whether to renormalize the audio before running the model.
    """
    # Plain class-level assignments override the abstract properties of the same
    # name declared on the base class, so instances can simply store the values.
    frame_rate: float = 0
    sample_rate: int = 0
    channels: int = 0

    def __init__(self,
                 encoder: nn.Module,
                 decoder: nn.Module,
                 quantizer: qt.BaseQuantizer,
                 frame_rate: int,
                 sample_rate: int,
                 channels: int,
                 causal: bool = False,
                 renormalize: bool = False):
        super().__init__()
        self.encoder = encoder
        self.decoder = decoder
        self.quantizer = quantizer
        self.frame_rate = frame_rate
        self.sample_rate = sample_rate
        self.channels = channels
        self.renormalize = renormalize
        self.causal = causal
        # Renormalization is refused for causal models, to avoid handling linear
        # overlap of segments as supported in the original EnCodec codebase.
        assert not (self.causal and self.renormalize), 'Causal model does not support renormalize'

    @property
    def total_codebooks(self):
        """Total number of quantizer codebooks available."""
        return self.quantizer.total_codebooks

    @property
    def num_codebooks(self):
        """Active number of codebooks used by the quantizer."""
        return self.quantizer.num_codebooks

    def set_num_codebooks(self, n: int):
        """Set the active number of codebooks used by the quantizer."""
        self.quantizer.set_num_codebooks(n)

    @property
    def cardinality(self):
        """Cardinality of each codebook."""
        return self.quantizer.bins

    def preprocess(self, x: torch.Tensor) -> tp.Tuple[torch.Tensor, tp.Optional[torch.Tensor]]:
        """Optionally renormalize `x` by its volume.

        Returns:
            The (possibly rescaled) input and the scale that was applied,
            or `None` for the scale when renormalization is disabled.
        """
        if not self.renormalize:
            return x, None
        # Volume is the RMS over time of the channel-averaged signal.
        mono = x.mean(dim=1, keepdim=True)
        volume = mono.pow(2).mean(dim=2, keepdim=True).sqrt()
        scale = 1e-8 + volume
        return x / scale, scale.view(-1, 1)

    def postprocess(self,
                    x: torch.Tensor,
                    scale: tp.Optional[torch.Tensor] = None) -> torch.Tensor:
        """Undo the renormalization of `preprocess` when a scale is given."""
        if scale is None:
            return x
        assert self.renormalize
        return x * scale.view(-1, 1, 1)

    def forward(self, x: torch.Tensor) -> qt.QuantizedResult:
        """Encode, quantize and decode `x`, returning the quantizer result whose
        `x` field is replaced by the reconstruction trimmed to the input length.
        """
        assert x.dim() == 3
        n_samples = x.shape[-1]
        x, scale = self.preprocess(x)

        latent = self.encoder(x)
        q_res = self.quantizer(latent, self.frame_rate)
        recon = self.decoder(q_res.x)

        # remove extra padding added by the encoder and decoder
        assert recon.shape[-1] >= n_samples, (recon.shape[-1], n_samples)
        recon = recon[..., :n_samples]

        q_res.x = self.postprocess(recon, scale)
        return q_res

    def encode(self, x: torch.Tensor) -> tp.Tuple[torch.Tensor, tp.Optional[torch.Tensor]]:
        """Encode the given input tensor to quantized representation along with scale parameter.

        Args:
            x (torch.Tensor): Float tensor of shape [B, C, T]

        Returns:
            codes, scale (tuple of torch.Tensor, torch.Tensor): Tuple composed of:
                codes, a tensor of shape [B, K, T] with K the number of codebooks used and T the timestep.
                scale, the scale used for audio renormalization (None when renormalization is disabled).
        """
        assert x.dim() == 3
        x, scale = self.preprocess(x)
        codes = self.quantizer.encode(self.encoder(x))
        return codes, scale

    def decode(self, codes: torch.Tensor, scale: tp.Optional[torch.Tensor] = None):
        """Decode the given codes to a reconstructed representation, using the scale to perform
        audio denormalization if needed.

        Args:
            codes (torch.Tensor): Int tensor of shape [B, K, T]
            scale (torch.Tensor, optional): Float tensor containing the scale value.

        Returns:
            out (torch.Tensor): Float tensor of shape [B, C, T], the reconstructed audio.
        """
        latent = self.decode_latent(codes)
        # The result still contains the extra padding added by the encoder and decoder.
        return self.postprocess(self.decoder(latent), scale)

    def decode_latent(self, codes: torch.Tensor):
        """Decode from the discrete codes to continuous latent space."""
        return self.quantizer.decode(codes)
class HFEncodecCompressionModel(CompressionModel):
    """Wrapper around HuggingFace Encodec.

    Args:
        model (HFEncodecModel): The HuggingFace Encodec model to wrap.
    """
    def __init__(self, model: HFEncodecModel):
        super().__init__()
        self.model = model
        bws = self.model.config.target_bandwidths
        # Convert every target bandwidth into the number of codebooks it corresponds to.
        num_codebooks = [
            bw * 1000 / (self.frame_rate * math.log2(self.cardinality))
            for bw in bws
        ]
        deltas = [nc - int(nc) for nc in num_codebooks]
        # Checking we didn't do some bad maths and we indeed have integers!
        # Each delta must be checked individually: `all(deltas) <= 1e-3` compared a
        # bool against the tolerance and never looked at the actual values.
        assert all(delta <= 1e-3 for delta in deltas), deltas
        self.possible_num_codebooks = [int(nc) for nc in num_codebooks]
        self.set_num_codebooks(max(self.possible_num_codebooks))

    def forward(self, x: torch.Tensor) -> qt.QuantizedResult:
        """Not supported, always raises `NotImplementedError`."""
        # We don't support training with this.
        raise NotImplementedError("Forward and training with HF EncodecModel not supported.")

    def encode(self, x: torch.Tensor) -> tp.Tuple[torch.Tensor, tp.Optional[torch.Tensor]]:
        """Encode `x` with the bandwidth matching the active number of codebooks.

        Returns:
            The first (and only) entry of the codes and of the scales returned by the wrapped model.
        """
        bandwidth_index = self.possible_num_codebooks.index(self.num_codebooks)
        bandwidth = self.model.config.target_bandwidths[bandwidth_index]
        res = self.model.encode(x, None, bandwidth)
        # Exactly one entry is expected for both the codes and the scales.
        assert len(res[0]) == 1
        assert len(res[1]) == 1
        return res[0][0], res[1][0]

    def decode(self, codes: torch.Tensor, scale: tp.Optional[torch.Tensor] = None):
        """Decode `codes` with the wrapped model and return the first element of its result.

        Args:
            codes (torch.Tensor): Codes to decode. A leading dimension is added before
                they are handed to the wrapped model.
            scale (torch.Tensor, optional): Scale forwarded as-is when given.
        """
        if scale is None:
            scales = [None]  # type: ignore
        else:
            scales = scale  # type: ignore
        res = self.model.decode(codes[None], scales)
        return res[0]

    def decode_latent(self, codes: torch.Tensor):
        """Decode from the discrete codes to continuous latent space."""
        return self.model.quantizer.decode(codes.transpose(0, 1))

    @property
    def channels(self) -> int:
        """Number of audio channels, read from the wrapped model config."""
        return self.model.config.audio_channels

    @property
    def frame_rate(self) -> float:
        """Sample rate divided by the product of the upsampling ratios."""
        hop_length = int(np.prod(self.model.config.upsampling_ratios))
        return self.sample_rate / hop_length

    @property
    def sample_rate(self) -> int:
        """Audio sample rate, read from the wrapped model config."""
        return self.model.config.sampling_rate

    @property
    def cardinality(self) -> int:
        """Codebook size, read from the wrapped model config."""
        return self.model.config.codebook_size

    @property
    def num_codebooks(self) -> int:
        """Active number of codebooks."""
        return self._num_codebooks

    @property
    def total_codebooks(self) -> int:
        """Largest number of codebooks among the supported bandwidths."""
        return max(self.possible_num_codebooks)

    def set_num_codebooks(self, n: int):
        """Set the active number of codebooks used by the quantizer.

        Raises:
            ValueError: If `n` does not correspond to one of the target bandwidths.
        """
        if n not in self.possible_num_codebooks:
            raise ValueError(f"Allowed values for num codebooks: {self.possible_num_codebooks}")
        self._num_codebooks = n
def get_init_fn(method: str, input_dim: int, init_depth: tp.Optional[int] = None):
    """LM layer initialization.
    Inspired from xlformers: https://github.com/fairinternal/xlformers

    Args:
        method (str): Method name for init function. Valid options are:
            'gaussian', 'uniform'.
        input_dim (int): Input dimension of the initialized module.
        init_depth (int, optional): Optional init depth value used to rescale
            the standard deviation if defined.
    Returns:
        A partial of the matching `torch.nn.init` function with its range pre-bound.
    Raises:
        ValueError: If `method` is not a supported initialization method.
    """
    # Base standard deviation, optionally shrunk with the depth.
    scale = 1 / math.sqrt(input_dim)
    if init_depth is not None:
        scale = scale / math.sqrt(2 * init_depth)

    if method == 'uniform':
        # A uniform law on [-limit, limit] has a standard deviation of `scale`.
        limit = math.sqrt(3) * scale
        return partial(torch.nn.init.uniform_, a=-limit, b=limit)
    if method == 'gaussian':
        # Truncated at three standard deviations on each side.
        return partial(
            torch.nn.init.trunc_normal_, mean=0.0, std=scale, a=-3 * scale, b=3 * scale
        )
    raise ValueError("Unsupported layer initialization method")
class ScaledEmbedding(nn.Embedding):
    """Embedding table that can carry its own learning rate (`lr`)."""

    def __init__(self, *args, lr=None, **kwargs):
        super().__init__(*args, **kwargs)
        self.lr = lr

    def make_optim_group(self):
        """Return an optimizer parameter group for this embedding.

        The group only contains an "lr" entry when a specific learning rate was given.
        """
        params = list(self.parameters())
        if self.lr is None:
            return {"params": params}
        return {"params": params, "lr": self.lr}


@dataclass
class LMOutput:
    # Logits come already re-aligned with the input codes, so no extra shift
    # is needed, e.g. when computing the cross entropy.
    logits: torch.Tensor  # [B, K, T, card]
    mask: torch.Tensor  # [B, K, T]
def __init__(self, pattern_provider: CodebooksPatternProvider, condition_provider: ConditioningProvider,
             fuser: ConditionFuser, n_q: int = 8, card: int = 1024, dim: int = 128, num_heads: int = 8,
             hidden_scale: int = 4, norm: str = 'layer_norm', norm_first: bool = False,
             emb_lr: tp.Optional[float] = None, bias_proj: bool = True,
             weight_init: tp.Optional[str] = None, depthwise_init: tp.Optional[str] = None,
             zero_bias_init: bool = False, cfg_dropout: float = 0, cfg_coef: float = 1.0,
             attribute_dropout: tp.Dict[str, tp.Dict[str, float]] = {}, two_step_cfg: bool = False,
             **kwargs):
    """Build the embeddings, the transformer and the output projections of the LM.

    One embedding table and one output linear layer are created per stream (`n_q` of each).
    Extra keyword arguments are forwarded to `StreamingTransformer`. See the class
    docstring for the meaning of the named arguments.
    """
    # NOTE(review): `attribute_dropout` defaults to a shared `{}`. This is only safe as
    # long as `AttributeDropout` does not mutate the dict it is given (TODO confirm).
    super().__init__()
    self.cfg_coef = cfg_coef
    self.cfg_dropout = ClassifierFreeGuidanceDropout(p=cfg_dropout)
    self.att_dropout = AttributeDropout(p=attribute_dropout)
    self.condition_provider = condition_provider
    self.fuser = fuser
    self.card = card
    # One extra embedding entry on top of the `card` regular codes.
    embed_dim = self.card + 1
    self.n_q = n_q
    self.dim = dim
    self.pattern_provider = pattern_provider
    self.two_step_cfg = two_step_cfg
    self.emb = nn.ModuleList([ScaledEmbedding(embed_dim, dim, lr=emb_lr) for _ in range(n_q)])
    # When an 'activation' entry is given, it is resolved through `get_activation_fn`
    # before being forwarded to the transformer.
    if 'activation' in kwargs:
        kwargs['activation'] = get_activation_fn(kwargs['activation'])
    self.transformer = StreamingTransformer(
        d_model=dim, num_heads=num_heads, dim_feedforward=int(hidden_scale * dim),
        norm=norm, norm_first=norm_first, **kwargs)
    # A final normalization layer is only created when `norm_first` is set.
    self.out_norm: tp.Optional[nn.Module] = None
    if norm_first:
        self.out_norm = create_norm_fn(norm, dim)
    # The output projections map to `card` logits, one layer per stream.
    self.linears = nn.ModuleList([nn.Linear(dim, self.card, bias=bias_proj) for _ in range(n_q)])
    self._init_weights(weight_init, depthwise_init, zero_bias_init)
    self._fsdp: tp.Optional[nn.Module]
    # Written straight into `__dict__`, bypassing the regular attribute assignment.
    # Presumably so that a module stored here later is not registered as a
    # sub-module (TODO confirm).
    self.__dict__['_fsdp'] = None
@property
def special_token_id(self) -> int:
    """Id of the special token, equal to the cardinality `card`."""
    return self.card

@property
def num_codebooks(self) -> int:
    """Number of parallel streams (`n_q`) handled by the model."""
    return self.n_q

def forward(self, sequence: torch.Tensor,
            conditions: tp.List[ConditioningAttributes],
            condition_tensors: tp.Optional[ConditionTensors] = None) -> torch.Tensor:
    """Apply language model on sequence and conditions.
    Given a tensor of sequence of shape [B, K, S] with K the number of codebooks and
    S the sequence steps, return the logits with shape [B, K, S, card].

    Args:
        sequence (torch.Tensor): Indices of the codes to model, of shape [B, K, S].
        conditions (list of ConditioningAttributes): Conditions to use when modeling
            the given codes. Note that when evaluating multiple time with the same conditioning
            you should pre-compute those and pass them as `condition_tensors`.
        condition_tensors (dict[str, ConditionType], optional): Pre-computed conditioning
            tensors, see `conditions`. Must not be combined with a non-empty `conditions`.
    Returns:
        torch.Tensor: Logits of shape [B, K, S, card].
    """
    _, n_streams, n_steps = sequence.shape
    assert n_streams == self.num_codebooks, "Sequence shape must match the specified number of codebooks"
    # Sum the per-stream embeddings, accumulated stream after stream starting from 0.
    embedded = 0
    for k in range(n_streams):
        embedded = embedded + self.emb[k](sequence[:, k])
    if condition_tensors is not None:
        assert not conditions, "Shouldn't pass both conditions and condition_tensors."
    else:
        assert not self._is_streaming, "Conditions tensors should be precomputed when streaming."
        # apply dropout modules
        conditions = self.att_dropout(self.cfg_dropout(conditions))
        tokenized = self.condition_provider.tokenize(conditions)
        # encode conditions and fuse, both have a streaming cache to not recompute when generating.
        condition_tensors = self.condition_provider(tokenized)

    fused, cross_attention_input = self.fuser(embedded, condition_tensors)

    hidden = self.transformer(fused, cross_attention_src=cross_attention_input)
    if self.out_norm:
        hidden = self.out_norm(hidden)
    per_stream_logits = []
    for k in range(n_streams):
        per_stream_logits.append(self.linears[k](hidden))
    logits = torch.stack(per_stream_logits, dim=1)  # [B, K, S, card]

    # remove the prefix from the model outputs
    if len(self.fuser.fuse2cond['prepend']) > 0:
        logits = logits[:, :, -n_steps:]

    return logits  # [B, K, S, card]
def _sample_next_token(self,
                       sequence: torch.Tensor,
                       cfg_conditions: CFGConditions,
                       unconditional_state: State,
                       use_sampling: bool = False,
                       temp: float = 1.0,
                       top_k: int = 0,
                       top_p: float = 0.0,
                       cfg_coef: tp.Optional[float] = None) -> torch.Tensor:
    """Sample next token from the model given a sequence and a set of conditions. The model supports
    multiple sampling strategies (greedy sampling, softmax, top-k, top-p...).

    Args:
        sequence (torch.Tensor): Current sequence of shape [B, K, S]
            with K corresponding to the number of codebooks and S the number of sequence steps.
            S = 1 in streaming mode, except for the first step that contains a bigger prompt.
        cfg_conditions (dict[str, ConditionType] or tuple of two such dicts): Set of conditions.
            With two-step CFG, a `(conditional, unconditional)` tuple of condition tensors.
            Otherwise a single dict which, if CFG is used, should be twice the batch size,
            being the concatenation of the conditions + null conditions.
        unconditional_state (State): Streaming state used for, and updated by, the unconditional
            pass of two-step CFG. Not used otherwise.
        use_sampling (bool): Whether to use a sampling strategy or not.
        temp (float): Sampling temperature.
        top_k (int): K for "top-k" sampling.
        top_p (float): P for "top-p" sampling.
        cfg_coef (float, optional): Classifier free guidance coefficient. Defaults to
            the coefficient the model was built with.
    Returns:
        next_token (torch.Tensor): Next token tensor of shape [B, K, 1].
    """
    B = sequence.shape[0]
    cfg_coef = self.cfg_coef if cfg_coef is None else cfg_coef
    model = self if self._fsdp is None else self._fsdp
    if self.two_step_cfg and cfg_conditions != {}:
        assert isinstance(cfg_conditions, tuple), type(cfg_conditions)
        condition_tensors, null_condition_tensors = cfg_conditions
        cond_logits = model(sequence, conditions=[], condition_tensors=condition_tensors)
        # Run the unconditional pass with its own streaming state, then restore
        # the conditional one.
        state = self.get_streaming_state()
        self.set_streaming_state(unconditional_state)
        uncond_logits = model(sequence, conditions=[], condition_tensors=null_condition_tensors)
        unconditional_state.update(self.get_streaming_state())
        self.set_streaming_state(state)
        # Use the resolved coefficient, so that an explicit `cfg_coef` argument is honoured
        # here exactly as it is in the single-pass branch below.
        logits = uncond_logits + (cond_logits - uncond_logits) * cfg_coef
    else:
        assert isinstance(cfg_conditions, dict)
        condition_tensors = cfg_conditions
        if condition_tensors:
            # Preparing for CFG, predicting both conditional and unconditional logits.
            sequence = torch.cat([sequence, sequence], dim=0)
        all_logits = model(
            sequence,
            conditions=[], condition_tensors=condition_tensors)
        if condition_tensors:
            cond_logits, uncond_logits = all_logits.split(B, dim=0)  # [B, K, T, card]
            logits = uncond_logits + (cond_logits - uncond_logits) * cfg_coef
        else:
            logits = all_logits

    logits = logits.permute(0, 1, 3, 2)  # [B, K, card, T]
    logits = logits[..., -1]  # [B x K x card]

    # Apply softmax for sampling if temp > 0. Else, do greedy sampling to avoid zero division error.
    if use_sampling and temp > 0.0:
        probs = torch.softmax(logits / temp, dim=-1)
        # top-p takes precedence over top-k when both are set.
        if top_p > 0.0:
            next_token = utils.sample_top_p(probs, p=top_p)
        elif top_k > 0:
            next_token = utils.sample_top_k(probs, k=top_k)
        else:
            next_token = utils.multinomial(probs, num_samples=1)
    else:
        next_token = torch.argmax(logits, dim=-1, keepdim=True)

    return next_token
first_param = next(iter(self.parameters())) device = first_param.device # Checking all input shapes are consistent. possible_num_samples = [] if num_samples is not None: possible_num_samples.append(num_samples) elif prompt is not None: possible_num_samples.append(prompt.shape[0]) elif conditions: possible_num_samples.append(len(conditions)) else: possible_num_samples.append(1) assert [x == possible_num_samples[0] for x in possible_num_samples], "Inconsistent inputs shapes" num_samples = possible_num_samples[0] # below we create set of conditions: one conditional and one unconditional # to do that we merge the regular condition together with the null condition # we then do 1 forward pass instead of 2. # the reason for that is two-fold: # 1. it is about x2 faster than doing 2 forward passes # 2. avoid the streaming API treating the 2 passes as part of different time steps # We also support doing two different passes, in particular to ensure that # the padding structure is exactly the same between train and test. # With a batch size of 1, this can be slower though. 
cfg_conditions: CFGConditions two_step_cfg = self.two_step_cfg if two_step_cfg is None else two_step_cfg if conditions: null_conditions = ClassifierFreeGuidanceDropout(p=1.0)(conditions) if two_step_cfg: cfg_conditions = ( self.condition_provider(self.condition_provider.tokenize(conditions)), self.condition_provider(self.condition_provider.tokenize(null_conditions)), ) else: conditions = conditions + null_conditions tokenized = self.condition_provider.tokenize(conditions) cfg_conditions = self.condition_provider(tokenized) else: cfg_conditions = {} if prompt is None: assert num_samples > 0 prompt = torch.zeros((num_samples, self.num_codebooks, 0), dtype=torch.long, device=device) B, K, T = prompt.shape start_offset = T assert start_offset < max_gen_len pattern = self.pattern_provider.get_pattern(max_gen_len) # this token is used as default value for codes that are not generated yet unknown_token = -1 # we generate codes up to the max_gen_len that will be mapped to the pattern sequence gen_codes = torch.full((B, K, max_gen_len), unknown_token, dtype=torch.long, device=device) # filling the gen_codes with the prompt if needed gen_codes[..., :start_offset] = prompt # create the gen_sequence with proper interleaving from the pattern: [B, K, S] gen_sequence, indexes, mask = pattern.build_pattern_sequence(gen_codes, self.special_token_id) # retrieve the start_offset in the sequence: # it is the first sequence step that contains the `start_offset` timestep start_offset_sequence = pattern.get_first_step_with_timesteps(start_offset) assert start_offset_sequence is not None with self.streaming(): unconditional_state = self.get_streaming_state() prev_offset = 0 gen_sequence_len = gen_sequence.shape[-1] # gen_sequence shape is [B, K, S] for offset in range(start_offset_sequence, gen_sequence_len): # get current sequence (note that the streaming API is providing the caching over previous offsets) curr_sequence = gen_sequence[..., prev_offset:offset] curr_mask = mask[None, ..., 
prev_offset:offset].expand(B, -1, -1) if check: # check coherence between mask and sequence assert (curr_sequence == torch.where(curr_mask, curr_sequence, self.special_token_id)).all() # should never happen as gen_sequence is filled progressively assert not (curr_sequence == unknown_token).any() # sample next token from the model, next token shape is [B, K, 1] next_token = self._sample_next_token( curr_sequence, cfg_conditions, unconditional_state, use_sampling, temp, top_k, top_p, cfg_coef=cfg_coef) # ensure the tokens that should be masked are properly set to special_token_id # as the model never output special_token_id valid_mask = mask[..., offset:offset+1].expand(B, -1, -1) next_token[~valid_mask] = self.special_token_id # ensure we don't overwrite prompt tokens, we only write over unknown tokens # (then mask tokens should be left as is as well, which is correct) gen_sequence[..., offset:offset+1] = torch.where( gen_sequence[..., offset:offset+1] == unknown_token, next_token, gen_sequence[..., offset:offset+1] ) prev_offset = offset if callback is not None: callback(1 + offset - start_offset_sequence, gen_sequence_len - start_offset_sequence) unconditional_state.clear() # ensure sequence has been entirely filled assert not (gen_sequence == unknown_token).any() # ensure gen_sequence pattern and mask are matching # which means the gen_sequence is valid according to the pattern assert ( gen_sequence == torch.where(mask[None, ...].expand(B, -1, -1), gen_sequence, self.special_token_id) ).all() # get back the codes, trimming the prompt if needed and cutting potentially incomplete timesteps out_codes, out_indexes, out_mask = pattern.revert_pattern_sequence(gen_sequence, special_token=unknown_token) # sanity checks over the returned codes and corresponding masks assert (out_codes[..., :max_gen_len] != unknown_token).all() assert (out_mask[..., :max_gen_len] == 1).all() out_start_offset = start_offset if remove_prompts else 0 out_codes = out_codes[..., 
out_start_offset:max_gen_len] # ensure the returned codes are all valid assert (out_codes >= 0).all() and (out_codes <= self.card).all() return out_codes ================================================ FILE: audiocraft/models/loaders.py ================================================ # Copyright (c) Meta Platforms, Inc. and affiliates. # All rights reserved. # # This source code is licensed under the license found in the # LICENSE file in the root directory of this source tree. """ Utility functions to load from the checkpoints. Each checkpoint is a torch.saved dict with the following keys: - 'xp.cfg': the hydra config as dumped during training. This should be used to rebuild the object using the audiocraft.models.builders functions, - 'model_best_state': a readily loadable best state for the model, including the conditioner. The model obtained from `xp.cfg` should be compatible with this state dict. In the case of a LM, the encodec model would not be bundled along but instead provided separately. Those functions also support loading from a remote location with the Torch Hub API. They also support overriding some parameters, in particular the device and dtype of the returned model. """ from pathlib import Path from huggingface_hub import hf_hub_download import typing as tp import os from omegaconf import OmegaConf, DictConfig import torch from . 
def get_audiocraft_cache_dir() -> tp.Optional[str]:
    """Return the value of the `AUDIOCRAFT_CACHE_DIR` environment variable, or None when unset."""
    return os.environ.get('AUDIOCRAFT_CACHE_DIR', None)


def _get_state_dict(
    file_or_url_or_id: tp.Union[Path, str],
    filename: tp.Optional[str] = None,
    device='cpu',
    cache_dir: tp.Optional[str] = None,
):
    """Load a checkpoint package from a local file, a local directory, an https URL or a HF repo id.

    Args:
        file_or_url_or_id (Path or str): Path to a checkpoint file, path to a directory containing
            `filename`, an `https://` URL, or otherwise a Hugging Face Hub repository id.
        filename (str, optional): File to read inside the directory / HF repository.
            Required for those two cases.
        device: Passed as `map_location` when deserializing.
        cache_dir (str, optional): Cache directory for HF downloads, defaults to
            `get_audiocraft_cache_dir()`.
    Returns:
        The object stored in the checkpoint, as returned by `torch.load`.
    """
    # NOTE: `torch.load` relies on pickle; only load checkpoints coming from a trusted source.
    if cache_dir is None:
        cache_dir = get_audiocraft_cache_dir()
    file_or_url_or_id = str(file_or_url_or_id)

    if os.path.isfile(file_or_url_or_id):
        return torch.load(file_or_url_or_id, map_location=device)

    if os.path.isdir(file_or_url_or_id):
        # Same validation as for HF checkpoints, instead of silently trying to open "<dir>/None".
        assert filename is not None, "filename needs to be defined if using a checkpoint directory"
        file = os.path.join(file_or_url_or_id, filename)
        return torch.load(file, map_location=device)

    if file_or_url_or_id.startswith('https://'):
        return torch.hub.load_state_dict_from_url(file_or_url_or_id, map_location=device, check_hash=True)

    assert filename is not None, "filename needs to be defined if using HF checkpoints"
    file = hf_hub_download(repo_id=file_or_url_or_id, filename=filename, cache_dir=cache_dir)
    return torch.load(file, map_location=device)


def load_compression_model_ckpt(file_or_url_or_id: tp.Union[Path, str], cache_dir: tp.Optional[str] = None):
    """Load the compression model checkpoint package (`compression_state_dict.bin`)."""
    return _get_state_dict(file_or_url_or_id, filename="compression_state_dict.bin", cache_dir=cache_dir)


def load_compression_model(file_or_url_or_id: tp.Union[Path, str], device='cpu', cache_dir: tp.Optional[str] = None):
    """Build a compression model in eval mode from a checkpoint.

    When the package has a `'pretrained'` entry, loading is delegated to
    `CompressionModel.get_pretrained`; otherwise the model is rebuilt from `'xp.cfg'`
    and `'best_state'` is loaded into it.
    """
    pkg = load_compression_model_ckpt(file_or_url_or_id, cache_dir=cache_dir)
    if 'pretrained' in pkg:
        return CompressionModel.get_pretrained(pkg['pretrained'], device=device)
    cfg = OmegaConf.create(pkg['xp.cfg'])
    cfg.device = str(device)
    model = builders.get_compression_model(cfg)
    model.load_state_dict(pkg['best_state'])
    model.eval()
    return model


def load_lm_model_ckpt(file_or_url_or_id: tp.Union[Path, str], cache_dir: tp.Optional[str] = None):
    """Load the language model checkpoint package (`state_dict.bin`)."""
    return _get_state_dict(file_or_url_or_id, filename="state_dict.bin", cache_dir=cache_dir)


def _delete_param(cfg: DictConfig, full_name: str):
    """Delete the entry designated by the dotted path `full_name` from `cfg`, if it exists.

    Nothing happens when an intermediate key is missing. Struct mode is disabled on the
    parent node for the deletion and enabled on it afterwards.
    """
    parts = full_name.split('.')
    for part in parts[:-1]:
        if part in cfg:
            cfg = cfg[part]
        else:
            return
    OmegaConf.set_struct(cfg, False)
    if parts[-1] in cfg:
        del cfg[parts[-1]]
    OmegaConf.set_struct(cfg, True)


def load_lm_model(file_or_url_or_id: tp.Union[Path, str], device='cpu', cache_dir: tp.Optional[str] = None):
    """Build a language model in eval mode from a checkpoint.

    The config stored under `'xp.cfg'` is adjusted before building: the device is overridden,
    dtype is float32 on `'cpu'` and float16 otherwise, and a few conditioner entries are removed.
    The final config is attached to the model as `model.cfg`.
    """
    pkg = load_lm_model_ckpt(file_or_url_or_id, cache_dir=cache_dir)
    cfg = OmegaConf.create(pkg['xp.cfg'])
    cfg.device = str(device)
    if cfg.device == 'cpu':
        cfg.dtype = 'float32'
    else:
        cfg.dtype = 'float16'
    _delete_param(cfg, 'conditioners.self_wav.chroma_stem.cache_path')
    _delete_param(cfg, 'conditioners.args.merge_text_conditions_p')
    _delete_param(cfg, 'conditioners.args.drop_desc_p')
    model = builders.get_lm_model(cfg)
    model.load_state_dict(pkg['best_state'])
    model.eval()
    model.cfg = cfg
    return model


def load_mbd_ckpt(file_or_url_or_id: tp.Union[Path, str], cache_dir: tp.Optional[str] = None):
    """Load the Multi-Band Diffusion checkpoint package (`all_in_one.pt`)."""
    return _get_state_dict(file_or_url_or_id, filename="all_in_one.pt", cache_dir=cache_dir)


def load_diffusion_models(file_or_url_or_id: tp.Union[Path, str], device='cpu', cache_dir: tp.Optional[str] = None):
    """Build every band's diffusion model and sample processor from a checkpoint.

    Returns:
        tuple: `(models, processors, cfgs)`, three lists with one entry per band
        (`pkg['n_bands']` entries each); models and processors are moved to `device`.
    """
    pkg = load_mbd_ckpt(file_or_url_or_id, cache_dir=cache_dir)
    models = []
    processors = []
    cfgs = []
    sample_rate = pkg['sample_rate']
    # Bands are stored in the package under their integer index.
    for i in range(pkg['n_bands']):
        band = pkg[i]
        cfg = band['cfg']
        model = builders.get_diffusion_model(cfg)
        model.load_state_dict(band['model_state'])
        model.to(device)
        processor = builders.get_processor(cfg=cfg.processor, sample_rate=sample_rate)
        processor.load_state_dict(band['processor_state'])
        processor.to(device)
        models.append(model)
        processors.append(processor)
        cfgs.append(cfg)
    return models, processors, cfgs
class DiffusionProcess:
    """Sampling for a diffusion Model.

    Args:
        model (DiffusionUnet): Diffusion U-Net model.
        noise_schedule (NoiseSchedule): Noise schedule for diffusion process.
    """
    def __init__(self, model: DiffusionUnet, noise_schedule: NoiseSchedule) -> None:
        """Store the model and the schedule used to sample from it."""
        self.model = model
        self.schedule = noise_schedule

    def generate(self, condition: torch.Tensor, initial_noise: torch.Tensor,
                 step_list: tp.Optional[tp.List[int]] = None):
        """Perform one diffusion process to generate one of the bands.

        Args:
            condition (torch.Tensor): The embeddings from the compression model.
            initial_noise (torch.Tensor): The initial noise to start the process.
            step_list (list of int, optional): Steps forwarded to the schedule's `generate_subsampled`.
        """
        return self.schedule.generate_subsampled(model=self.model, initial=initial_noise, step_list=step_list,
                                                 condition=condition)


class MultiBandDiffusion:
    """Sample from multiple diffusion models.

    Args:
        DPs (list of DiffusionProcess): Diffusion processes.
        codec_model (CompressionModel): Underlying compression model used to obtain discrete tokens.
    """
    def __init__(self, DPs: tp.List[DiffusionProcess], codec_model: CompressionModel) -> None:
        self.DPs = DPs
        self.codec_model = codec_model
        # The device is the one of the first parameter of the compression model.
        self.device = next(self.codec_model.parameters()).device

    @property
    def sample_rate(self) -> int:
        """Sample rate of the underlying compression model."""
        return self.codec_model.sample_rate

    @staticmethod
    def get_mbd_musicgen(device=None):
        """Load our diffusion models trained for MusicGen.

        Args:
            device (optional): Device on which the models are loaded, defaults to 'cuda'
                when available, 'cpu' otherwise.
        """
        if device is None:
            device = 'cuda' if torch.cuda.is_available() else 'cpu'
        path = 'https://dl.fbaipublicfiles.com/encodec/Diffusion/mbd_musicgen_32khz.th'
        name = 'facebook/musicgen-small'
        codec_model = load_compression_model(name, device=device)
        models, processors, cfgs = load_diffusion_models(path, device=device)
        DPs = []
        for model, processor, cfg in zip(models, processors, cfgs):
            schedule = NoiseSchedule(**cfg.schedule, sample_processor=processor, device=device)
            DPs.append(DiffusionProcess(model=model, noise_schedule=schedule))
        return MultiBandDiffusion(DPs=DPs, codec_model=codec_model)

    @staticmethod
    def get_mbd_24khz(bw: float = 3.0, pretrained: bool = True,
                      device: tp.Optional[tp.Union[torch.device, str]] = None,
                      n_q: tp.Optional[int] = None):
        """Get the pretrained Models for MultibandDiffusion.

        Args:
            bw (float): Bandwidth of the compression model, one of 1.5, 3.0 or 6.0.
            pretrained (bool): Whether to use / download if necessary the models.
                Currently not read by this method.
            device (torch.device or str, optional): Device on which the models are loaded.
            n_q (int, optional): Number of quantizers to use within the compression model.
                When given it must match the bandwidth (2, 4 and 8 respectively).
        """
        if device is None:
            device = 'cuda' if torch.cuda.is_available() else 'cpu'
        assert bw in [1.5, 3.0, 6.0], f"bandwidth {bw} not available"
        codebooks_per_bandwidth = {1.5: 2, 3.0: 4, 6.0: 8}
        if n_q is not None:
            assert n_q in [2, 4, 8]
            assert codebooks_per_bandwidth[bw] == n_q, \
                f"bandwidth and number of codebooks missmatch to use n_q = {n_q} bw should be {n_q * (1.5 / 2)}"
        n_q = codebooks_per_bandwidth[bw]
        codec_model = CompressionSolver.model_from_checkpoint(
            '//pretrained/facebook/encodec_24khz', device=device)
        codec_model.set_num_codebooks(n_q)
        codec_model = codec_model.to(device)
        path = f'https://dl.fbaipublicfiles.com/encodec/Diffusion/mbd_comp_{n_q}.pt'
        models, processors, cfgs = load_diffusion_models(path, device=device)
        DPs = []
        for model, processor, cfg in zip(models, processors, cfgs):
            schedule = NoiseSchedule(**cfg.schedule, sample_processor=processor, device=device)
            DPs.append(DiffusionProcess(model=model, noise_schedule=schedule))
        return MultiBandDiffusion(DPs=DPs, codec_model=codec_model)

    @torch.no_grad()
    def get_condition(self, wav: torch.Tensor, sample_rate: int) -> torch.Tensor:
        """Get the conditioning (i.e. latent representations of the compression model) from a waveform.

        Args:
            wav (torch.Tensor): The audio that we want to extract the conditioning from.
            sample_rate (int): Sample rate of the audio; it is resampled when it differs
                from the sample rate of the compression model.
        """
        if sample_rate != self.sample_rate:
            wav = julius.resample_frac(wav, sample_rate, self.sample_rate)
        codes, scale = self.codec_model.encode(wav)
        assert scale is None, "Scaled compression models not supported."
        emb = self.get_emb(codes)
        return emb

    @torch.no_grad()
    def get_emb(self, codes: torch.Tensor):
        """Get latent representation from the discrete codes.

        Args:
            codes (torch.Tensor): Discrete tokens.
        """
        emb = self.codec_model.decode_latent(codes)
        return emb

    def generate(self, emb: torch.Tensor, size: tp.Optional[torch.Size] = None,
                 step_list: tp.Optional[tp.List[int]] = None):
        """Generate waveform audio from the latent embeddings of the compression model.

        Args:
            emb (torch.Tensor): Conditioning embeddings.
            size (torch.Size, optional): Size of the output; if None this is computed from
                the typical upsampling of the model.
            step_list (list of int, optional): List of Markov chain steps, forwarded to each
                diffusion process.
        """
        if size is None:
            upsampling = int(self.codec_model.sample_rate / self.codec_model.frame_rate)
            size = torch.Size([emb.size(0), self.codec_model.channels, emb.size(-1) * upsampling])
        assert size[0] == emb.size(0)
        out = torch.zeros(size).to(self.device)
        # The final waveform is the sum of the outputs of every band.
        for DP in self.DPs:
            out += DP.generate(condition=emb, step_list=step_list, initial_noise=torch.randn_like(out))
        return out

    def re_eq(self, wav: torch.Tensor, ref: torch.Tensor, n_bands: int = 32, strictness: float = 1):
        """Match the eq to the encodec output by matching the standard deviation of some frequency bands.

        Args:
            wav (torch.Tensor): Audio to equalize.
            ref (torch.Tensor): Reference audio from which we match the spectrogram.
            n_bands (int): Number of bands of the eq.
            strictness (float): How strict the matching is. 0 is no matching, 1 is exact matching.
        """
        split = julius.SplitBands(n_bands=n_bands, sample_rate=self.codec_model.sample_rate).to(wav.device)
        bands = split(wav)
        bands_ref = split(ref)
        out = torch.zeros_like(ref)
        for band, band_ref in zip(bands, bands_ref):
            out += band * (band_ref.std() / band.std()) ** strictness
        return out

    def regenerate(self, wav: torch.Tensor, sample_rate: int):
        """Regenerate a waveform through compression and diffusion regeneration.

        Args:
            wav (torch.Tensor): Original 'ground truth' audio.
            sample_rate (int): Sample rate of the input (and output) wav.
        """
        if sample_rate != self.codec_model.sample_rate:
            wav = julius.resample_frac(wav, sample_rate, self.codec_model.sample_rate)
        emb = self.get_condition(wav, sample_rate=self.codec_model.sample_rate)
        size = wav.size()
        out = self.generate(emb, size=size)
        if sample_rate != self.codec_model.sample_rate:
            # Go back to the sample rate of the input.
            out = julius.resample_frac(out, self.codec_model.sample_rate, sample_rate)
        return out

    def tokens_to_wav(self, tokens: torch.Tensor, n_bands: int = 32):
        """Generate waveform audio with diffusion from the discrete codes.

        Args:
            tokens (torch.Tensor): Discrete codes.
            n_bands (int): Bands for the eq matching.
        """
        wav_encodec = self.codec_model.decode(tokens)
        condition = self.get_emb(tokens)
        wav_diffusion = self.generate(emb=condition, size=wav_encodec.size())
        return self.re_eq(wav=wav_diffusion, ref=wav_encodec, n_bands=n_bands)
""" import typing as tp import warnings import torch from .encodec import CompressionModel from .lm import LMModel from .builders import get_debug_compression_model, get_debug_lm_model from .loaders import load_compression_model, load_lm_model from ..data.audio_utils import convert_audio from ..modules.conditioners import ConditioningAttributes, WavCondition from ..utils.autocast import TorchAutocast MelodyList = tp.List[tp.Optional[torch.Tensor]] MelodyType = tp.Union[torch.Tensor, MelodyList] # backward compatible names mapping _HF_MODEL_CHECKPOINTS_MAP = { "small": "GrandaddyShmax/musicgen-small", "medium": "GrandaddyShmax/musicgen-medium", "large": "GrandaddyShmax/musicgen-large", "melody": "GrandaddyShmax/musicgen-melody", } class MusicGen: """MusicGen main model with convenient generation API. Args: name (str): name of the model. compression_model (CompressionModel): Compression model used to map audio to invertible discrete representations. lm (LMModel): Language model over discrete representations. max_duration (float, optional): maximum duration the model can produce, otherwise, inferred from the training params. 
def __init__(self, name: str, compression_model: CompressionModel, lm: LMModel,
             max_duration: tp.Optional[float] = None):
    """Wire the compression model and the LM together and set default generation parameters.

    When `max_duration` is not given it is read from `lm.cfg.dataset.segment_duration`;
    a ValueError is raised if the LM carries no `cfg`.
    """
    self.name = name
    self.compression_model = compression_model
    self.lm = lm

    if max_duration is None:
        if not hasattr(lm, 'cfg'):
            raise ValueError("You must provide max_duration when building directly MusicGen")
        max_duration = lm.cfg.dataset.segment_duration  # type: ignore
    assert max_duration is not None
    self.max_duration: float = max_duration

    self.device = next(iter(lm.parameters())).device
    self.generation_params: dict = {}
    # max_duration must already be set: set_generation_params validates the stride against it.
    self.set_generation_params(duration=15)  # 15 seconds by default
    self._progress_callback: tp.Optional[tp.Callable[[int, int], None]] = None

    # Autocast is disabled on CPU and uses float16 on any other device type.
    device_type = self.device.type
    if device_type != 'cpu':
        self.autocast = TorchAutocast(
            enabled=True, device_type=device_type, dtype=torch.float16)
    else:
        self.autocast = TorchAutocast(enabled=False)


@property
def frame_rate(self) -> float:
    """Frame rate of the compression model, roughly the number of AR steps per second."""
    codec = self.compression_model
    return codec.frame_rate


@property
def sample_rate(self) -> int:
    """Sample rate of the compression model, i.e. of the generated audio."""
    codec = self.compression_model
    return codec.sample_rate


@property
def audio_channels(self) -> int:
    """Number of channels of the compression model, i.e. of the generated audio."""
    codec = self.compression_model
    return codec.channels
def set_generation_params(self, use_sampling: bool = True, top_k: int = 250,
                          top_p: float = 0.0, temperature: float = 1.0,
                          duration: float = 30.0, cfg_coef: float = 3.0,
                          two_step_cfg: bool = False, extend_stride: float = 18):
    """Store the parameters used by the subsequent generation calls.

    Args:
        use_sampling (bool, optional): Sample when True, argmax decoding otherwise.
        top_k (int, optional): top_k used for sampling.
        top_p (float, optional): top_p used for sampling, when set to 0 top_k is used.
        temperature (float, optional): Softmax temperature, stored under the `temp` key.
        duration (float, optional): Duration of the generated waveform.
        cfg_coef (float, optional): Coefficient used for classifier free guidance.
        two_step_cfg (bool, optional): If True, performs 2 forward for Classifier Free Guidance,
            instead of batching together the two.
        extend_stride: when doing extended generation, by how much should we extend the audio
            each time. Must be strictly smaller than `self.max_duration`.
    """
    assert extend_stride < self.max_duration, "Cannot stride by more than max generation duration."
    self.extend_stride = extend_stride
    self.duration = duration
    self.generation_params = dict(
        use_sampling=use_sampling,
        temp=temperature,
        top_k=top_k,
        top_p=top_p,
        cfg_coef=cfg_coef,
        two_step_cfg=two_step_cfg,
    )


def set_custom_progress_callback(self, progress_callback: tp.Optional[tp.Callable[[int, int], None]] = None):
    """Replace the progress callback; passing None removes the custom one."""
    self._progress_callback = progress_callback


def generate_unconditional(self, num_samples: int, progress: bool = False,
                           return_tokens: bool = False) -> tp.Union[torch.Tensor,
                                                                    tp.Tuple[torch.Tensor, torch.Tensor]]:
    """Generate samples without any text conditioning.

    Args:
        num_samples (int): Number of samples to be generated.
        progress (bool, optional): Flag to display progress of the generation process.
        return_tokens (bool, optional): When True, return `(audio, tokens)` instead of the audio alone.
    """
    # One empty description per requested sample.
    empty_descriptions: tp.List[tp.Optional[str]] = [None for _ in range(num_samples)]
    attributes, prompt_tokens = self._prepare_tokens_and_attributes(empty_descriptions, None)
    tokens = self._generate_tokens(attributes, prompt_tokens, progress)
    audio = self.generate_audio(tokens)
    if not return_tokens:
        return audio
    return audio, tokens
def generate_with_chroma(self, descriptions: tp.List[str], melody_wavs: MelodyType,
                         melody_sample_rate: int, progress: bool = False,
                         return_tokens: bool = False) -> tp.Union[torch.Tensor,
                                                                  tp.Tuple[torch.Tensor, torch.Tensor]]:
    """Generate samples conditioned on text and melody.

    Args:
        descriptions (list of str): A list of strings used as text conditioning.
        melody_wavs: (torch.Tensor or list of Tensor): A batch of waveforms used as
            melody conditioning. Should have shape [B, C, T] with B matching the description length,
            C=1 or 2. It can be [C, T] if there is a single description. It can also be
            a list of [C, T] tensors (entries may be None).
        melody_sample_rate: (int): Sample rate of the melody waveforms.
        progress (bool, optional): Flag to display progress of the generation process.
        return_tokens (bool, optional): When True, return `(audio, tokens)` instead of the audio alone.
    """
    if not isinstance(melody_wavs, torch.Tensor):
        # List input: every provided melody must be [C, T].
        assert all(melody is None or melody.dim() == 2 for melody in melody_wavs), \
            "One melody in the list has the wrong number of dims."
    else:
        if melody_wavs.dim() == 2:
            melody_wavs = melody_wavs[None]
        if melody_wavs.dim() != 3:
            raise ValueError("Melody wavs should have a shape [B, C, T].")
        # Split the batch into a list of [C, T] tensors.
        melody_wavs = list(melody_wavs)

    # Bring every melody to the sample rate and channel count of the model.
    converted: MelodyList = []
    for wav in melody_wavs:
        if wav is None:
            converted.append(None)
        else:
            converted.append(convert_audio(wav, melody_sample_rate, self.sample_rate, self.audio_channels))

    attributes, prompt_tokens = self._prepare_tokens_and_attributes(
        descriptions=descriptions, prompt=None, melody_wavs=converted)
    assert prompt_tokens is None
    tokens = self._generate_tokens(attributes, prompt_tokens, progress)
    audio = self.generate_audio(tokens)
    if not return_tokens:
        return audio
    return audio, tokens
""" if prompt.dim() == 2: prompt = prompt[None] if prompt.dim() != 3: raise ValueError("prompt should have 3 dimensions: [B, C, T] (C = 1).") prompt = convert_audio(prompt, prompt_sample_rate, self.sample_rate, self.audio_channels) if descriptions is None: descriptions = [None] * len(prompt) attributes, prompt_tokens = self._prepare_tokens_and_attributes(descriptions, prompt) assert prompt_tokens is not None tokens = self._generate_tokens(attributes, prompt_tokens, progress) if return_tokens: return self.generate_audio(tokens), tokens return self.generate_audio(tokens) @torch.no_grad() def _prepare_tokens_and_attributes( self, descriptions: tp.Sequence[tp.Optional[str]], prompt: tp.Optional[torch.Tensor], melody_wavs: tp.Optional[MelodyList] = None, ) -> tp.Tuple[tp.List[ConditioningAttributes], tp.Optional[torch.Tensor]]: """Prepare model inputs. Args: descriptions (list of str): A list of strings used as text conditioning. prompt (torch.Tensor): A batch of waveforms used for continuation. melody_wavs (torch.Tensor, optional): A batch of waveforms used as melody conditioning. Defaults to None. """ attributes = [ ConditioningAttributes(text={'description': description}) for description in descriptions] if melody_wavs is None: for attr in attributes: attr.wav['self_wav'] = WavCondition( torch.zeros((1, 1, 1), device=self.device), torch.tensor([0], device=self.device), sample_rate=[self.sample_rate], path=[None]) else: if 'self_wav' not in self.lm.condition_provider.conditioners: raise RuntimeError("This model doesn't support melody conditioning. " "Use the `melody` model.") assert len(melody_wavs) == len(descriptions), \ f"number of melody wavs must match number of descriptions! 
def _generate_tokens(self, attributes: tp.List[ConditioningAttributes],
                     prompt_tokens: tp.Optional[torch.Tensor], progress: bool = False) -> torch.Tensor:
    """Generate discrete audio tokens given audio prompt and/or conditions.

    The text description of the first attribute is treated as a list of
    prompts: the short path uses its first entry, the extended path uses one
    entry per generated chunk and keeps reusing the last entry once the list
    is exhausted. ``attributes[0].text['description']`` is overwritten in place
    with the entry used for the current chunk. A ``None`` or plain-string
    description is treated as a single-entry list.

    Args:
        attributes (list of ConditioningAttributes): Conditions used for generation (text/melody).
        prompt_tokens (torch.Tensor, optional): Audio prompt tokens used for continuation.
        progress (bool, optional): Flag to display progress of the generation process. Defaults to False.
    Returns:
        torch.Tensor: Generated tokens as produced by ``self.lm.generate``; for long
            generations the chunks are concatenated along the last dimension.
    """
    i = 0
    prompt_list = attributes[0].text['description']
    if prompt_list is None or isinstance(prompt_list, str):
        # No text (e.g. continuation without descriptions) or a single plain string:
        # wrap it so indexing below selects the whole description, not a character.
        prompt_list = [prompt_list]
    elif len(prompt_list) == 0:
        prompt_list = [None]

    total_gen_len = int(self.duration * self.frame_rate)
    max_prompt_len = int(min(self.duration, self.max_duration) * self.frame_rate)
    current_gen_offset: int = 0

    def _progress_callback(generated_tokens: int, tokens_to_generate: int):
        generated_tokens += current_gen_offset
        if current_gen_offset > 0:
            # Account for the overlap carried over from the previous chunk.
            # Cast to int: the product is a float and the `6d` format below
            # (as well as integer-expecting callbacks) would reject it.
            generated_tokens += int((self.max_duration - self.extend_stride) * self.frame_rate)
        if self._progress_callback is not None:
            # Note that total_gen_len might be quite wrong depending on the
            # codebook pattern used, but with delay it is almost accurate.
            self._progress_callback(generated_tokens, total_gen_len)
        else:
            print(f'{generated_tokens: 6d} / {total_gen_len: 6d}', end='\r')

    if prompt_tokens is not None:
        assert max_prompt_len >= prompt_tokens.shape[-1], \
            "Prompt is longer than audio to generate"

    callback = None
    if progress:
        callback = _progress_callback

    if self.duration <= self.max_duration:
        # generate by sampling from LM, simple case.
        with self.autocast:
            attributes[0].text['description'] = prompt_list[0]
            gen_tokens = self.lm.generate(
                prompt_tokens, attributes,
                callback=callback, max_gen_len=total_gen_len, **self.generation_params)

    else:
        # now this gets a bit messier, we need to handle prompts,
        # melody conditioning etc.
        ref_wavs = [attr.wav['self_wav'] for attr in attributes]
        all_tokens = []
        if prompt_tokens is None:
            prompt_length = 0
        else:
            all_tokens.append(prompt_tokens)
            prompt_length = prompt_tokens.shape[-1]

        stride_tokens = int(self.frame_rate * self.extend_stride)

        while current_gen_offset + prompt_length < total_gen_len:
            time_offset = current_gen_offset / self.frame_rate
            chunk_duration = min(self.duration - time_offset, self.max_duration)
            max_gen_len = int(chunk_duration * self.frame_rate)
            for attr, ref_wav in zip(attributes, ref_wavs):
                wav_length = ref_wav.length.item()
                if wav_length == 0:
                    continue
                # We will extend the wav periodically if it not long enough.
                # we have to do it here rather than in conditioners.py as otherwise
                # we wouldn't have the full wav.
                initial_position = int(time_offset * self.sample_rate)
                wav_target_length = int(self.max_duration * self.sample_rate)
                positions = torch.arange(initial_position,
                                         initial_position + wav_target_length, device=self.device)
                attr.wav['self_wav'] = WavCondition(
                    ref_wav[0][..., positions % wav_length],
                    torch.full_like(ref_wav[1], wav_target_length),
                    [self.sample_rate] * ref_wav[0].size(0),
                    [None], [0.])
            with self.autocast:
                # One prompt per chunk; stick to the last prompt once the list runs out.
                if i >= len(prompt_list):
                    i = len(prompt_list) - 1
                attributes[0].text['description'] = prompt_list[i]
                gen_tokens = self.lm.generate(
                    prompt_tokens, attributes,
                    callback=callback, max_gen_len=max_gen_len, **self.generation_params)
                i = i + 1
            if prompt_tokens is None:
                all_tokens.append(gen_tokens)
            else:
                # Drop the part of the chunk that merely repeats the prompt.
                all_tokens.append(gen_tokens[:, :, prompt_tokens.shape[-1]:])
            # The tail after the stride becomes the prompt of the next chunk.
            prompt_tokens = gen_tokens[:, :, stride_tokens:]
            prompt_length = prompt_tokens.shape[-1]
            current_gen_offset += stride_tokens

        gen_tokens = torch.cat(all_tokens, dim=-1)
    return gen_tokens
class ResBlock(nn.Module):
    """Residual block made of two (GroupNorm -> activation -> Conv1d -> Dropout1d) stages.

    Both convolutions keep the channel count, use stride 1 and the same
    dilation; the block output is ``x + stages(x)``.

    Args:
        channels (int): Number of input and output channels.
        kernel (int): Convolution kernel size.
        norm_groups (int): Number of groups for the group normalizations.
        dilation (int): Dilation of both convolutions.
        activation (type of nn.Module): Activation class, instantiated once per stage.
        dropout (float): Dropout probability of both ``nn.Dropout1d`` layers.
    """

    def __init__(self, channels: int, kernel: int = 3, norm_groups: int = 4,
                 dilation: int = 1, activation: tp.Type[nn.Module] = nn.ReLU,
                 dropout: float = 0.):
        super().__init__()
        # Stride-1 padding; the output length equals the input length whenever
        # dilation * (kernel - 1) is even, which the residual sum relies on.
        pad = (dilation * (kernel - 1)) // 2
        conv_kwargs = dict(kernel_size=kernel, stride=1, padding=pad, dilation=dilation)

        # First stage.
        self.norm1 = nn.GroupNorm(norm_groups, channels)
        self.conv1 = nn.Conv1d(channels, channels, **conv_kwargs)
        self.activation1 = activation()
        self.dropout1 = nn.Dropout1d(dropout)

        # Second stage.
        self.norm2 = nn.GroupNorm(norm_groups, channels)
        self.conv2 = nn.Conv1d(channels, channels, **conv_kwargs)
        self.activation2 = activation()
        self.dropout2 = nn.Dropout1d(dropout)

    def forward(self, x):
        stages = (
            (self.norm1, self.activation1, self.conv1, self.dropout1),
            (self.norm2, self.activation2, self.conv2, self.dropout2),
        )
        h = x
        for norm, act, conv, drop in stages:
            h = drop(conv(act(norm(h))))
        return x + h
class DiffusionUnet(nn.Module):
    """U-Net style encoder/decoder with a step embedding, used for diffusion.

    The network stacks ``depth`` encoder layers and the mirrored decoder layers
    with additive skip connections. The bottleneck is processed by a
    transformer, a BiLSTM, or replaced by zeros when neither is enabled.

    Args:
        chin (int): Number of input (and output) channels.
        hidden (int): Number of channels after the first encoder layer.
        depth (int): Number of encoder/decoder layer pairs.
        growth (float): Channel multiplier applied from one level to the next.
        max_channels (int): Upper bound on the number of channels of a level.
        num_steps (int): Number of entries of the step embedding tables.
        emb_all_layers (bool): Add a step embedding after every encoder layer
            instead of only the first one.
        cross_attention (bool): Feed the condition to the transformer through
            cross-attention. Only taken into account when ``transformer`` is True.
        bilstm (bool): Use a BiLSTM in the bottleneck (when no transformer is used).
        transformer (bool): Use a transformer in the bottleneck.
        codec_dim (int, optional): Channel count of the conditioning signal. When
            given, ``forward`` requires a ``condition`` tensor.
        **kwargs: Forwarded to every ``EncoderLayer`` and ``DecoderLayer``.
    """

    def __init__(self, chin: int = 3, hidden: int = 24, depth: int = 3, growth: float = 2.,
                 max_channels: int = 10_000, num_steps: int = 1000, emb_all_layers=False,
                 cross_attention: bool = False, bilstm: bool = False, transformer: bool = False,
                 codec_dim: tp.Optional[int] = None, **kwargs):
        super().__init__()
        self.encoders = nn.ModuleList()
        self.decoders = nn.ModuleList()
        self.embeddings: tp.Optional[nn.ModuleList] = None
        self.embedding = nn.Embedding(num_steps, hidden)
        if emb_all_layers:
            self.embeddings = nn.ModuleList()
        self.condition_embedding: tp.Optional[nn.Module] = None
        for d in range(depth):
            encoder = EncoderLayer(chin, hidden, **kwargs)
            decoder = DecoderLayer(hidden, chin, **kwargs)
            self.encoders.append(encoder)
            # Decoders are stored in reverse so they can be iterated in order in forward.
            self.decoders.insert(0, decoder)
            if emb_all_layers and d > 0:
                assert self.embeddings is not None
                self.embeddings.append(nn.Embedding(num_steps, hidden))
            chin = hidden
            hidden = min(int(chin * growth), max_channels)
        self.bilstm: tp.Optional[nn.Module]
        if bilstm:
            self.bilstm = BLSTM(chin)
        else:
            self.bilstm = None
        self.use_transformer = transformer
        self.cross_attention = False
        if transformer:
            self.cross_attention = cross_attention
            self.transformer = StreamingTransformer(chin, 8, 6, bias_ff=False, bias_attn=False,
                                                    cross_attention=cross_attention)

        self.use_codec = False
        if codec_dim is not None:
            # 1x1 convolution projecting the condition to the bottleneck channel count.
            self.conv_codec = nn.Conv1d(codec_dim, chin, 1)
            self.use_codec = True

    def forward(self, x: torch.Tensor, step: tp.Union[int, torch.Tensor], condition: tp.Optional[torch.Tensor] = None):
        """Run the network on ``x`` for the given diffusion step.

        Args:
            x (torch.Tensor): Input whose first dimension is the batch and whose
                third dimension is time.
            step (int or torch.Tensor): Diffusion step. An int is broadcast to the
                whole batch; a tensor is used as the embedding indices directly.
            condition (torch.Tensor, optional): Conditioning signal, required when the
                model was built with ``codec_dim``.

        Returns:
            Output: Wrapper whose ``sample`` is trimmed to the time length of ``x``.
        """
        skips = []
        bs = x.size(0)
        z = x
        view_args = [1]
        # isinstance also accepts Tensor subclasses, unlike an exact type comparison.
        if isinstance(step, torch.Tensor):
            step_tensor = step
        else:
            step_tensor = torch.tensor([step], device=x.device, dtype=torch.long).expand(bs)

        for idx, encoder in enumerate(self.encoders):
            z = encoder(z)
            if idx == 0:
                z = z + self.embedding(step_tensor).view(bs, -1, *view_args).expand_as(z)
            elif self.embeddings is not None:
                z = z + self.embeddings[idx - 1](step_tensor).view(bs, -1, *view_args).expand_as(z)

            skips.append(z)

        # Defined up front so the transformer call below never sees an unbound name
        # when the model has no codec conditioning.
        cross_attention_src = None
        if self.use_codec:  # insert condition in the bottleneck
            assert condition is not None, "Model defined for conditionnal generation"
            condition_emb = self.conv_codec(condition)  # reshape to the bottleneck dim
            assert condition_emb.size(-1) <= 2 * z.size(-1), \
                f"You are downsampling the conditionning with factor >=2 : {condition_emb.size(-1)=} and {z.size(-1)=}"
            if not self.cross_attention:
                condition_emb = torch.nn.functional.interpolate(condition_emb, z.size(-1))
                assert z.size() == condition_emb.size()
                # NOTE: in-place on purpose. `z` is the same tensor as `skips[-1]`, so the
                # condition is also part of the deepest skip connection; kept as is.
                z += condition_emb
            else:
                cross_attention_src = condition_emb.permute(0, 2, 1)  # B, T, C
                _, T, C = cross_attention_src.shape
                positions = torch.arange(T, device=x.device).view(1, -1, 1)
                pos_emb = create_sin_embedding(positions, C, max_period=10_000, dtype=cross_attention_src.dtype)
                cross_attention_src = cross_attention_src + pos_emb
        if self.use_transformer:
            z = self.transformer(z.permute(0, 2, 1), cross_attention_src=cross_attention_src).permute(0, 2, 1)
        else:
            if self.bilstm is None:
                # No bottleneck module: only the skip connections feed the decoders.
                z = torch.zeros_like(z)
            else:
                z = self.bilstm(z)

        for decoder in self.decoders:
            s = skips.pop(-1)
            # Trim to the skip length before the additive skip connection.
            z = z[:, :, :s.shape[2]]
            z = z + s
            z = decoder(z)

        z = z[:, :, :x.shape[2]]
        return Output(z)
class SwiGLU(CustomGLU):
    """Gated Linear Unit whose gate activation is SiLU.

    Computes :math:`a * SiLU(b)`, where :math:`a` and :math:`b` are the first
    and second halves of the input along ``dim``.

    Args:
        dim (int): the dimension on which to split the input. Default: -1
    """

    def __init__(self, dim: int = -1):
        super().__init__(activation=nn.SiLU(), dim=dim)
class ChromaExtractor(nn.Module):
    """Chroma extraction and quantization.

    Args:
        sample_rate (int): Sample rate for the chroma extraction.
        n_chroma (int): Number of chroma bins for the chroma extraction.
        radix2_exp (int): Size of stft window for the chroma extraction (power of 2, e.g. 12 -> 2^12).
        nfft (int, optional): Number of FFT. Falls back to the window length.
        winlen (int, optional): Window length. Falls back to ``2 ** radix2_exp``.
        winhop (int, optional): Window hop size. Falls back to a quarter of the window length.
        argmax (bool, optional): Whether to use argmax. Defaults to False.
        norm (float, optional): Norm for chroma normalization. Defaults to inf.
    """

    def __init__(self, sample_rate: int, n_chroma: int = 12, radix2_exp: int = 12, nfft: tp.Optional[int] = None,
                 winlen: tp.Optional[int] = None, winhop: tp.Optional[int] = None, argmax: bool = False,
                 norm: float = torch.inf):
        super().__init__()
        # Falsy values (None or 0) select the derived defaults.
        self.winlen = winlen if winlen else 2 ** radix2_exp
        self.nfft = nfft if nfft else self.winlen
        self.winhop = winhop if winhop else (self.winlen // 4)
        self.sample_rate = sample_rate
        self.n_chroma = n_chroma
        self.norm = norm
        self.argmax = argmax

        chroma_filters = filters.chroma(sr=sample_rate, n_fft=self.nfft, tuning=0, n_chroma=self.n_chroma)
        # Non-persistent: the filter bank is not stored in the state dict.
        self.register_buffer('fbanks', torch.from_numpy(chroma_filters), persistent=False)
        self.spec = torchaudio.transforms.Spectrogram(
            n_fft=self.nfft,
            win_length=self.winlen,
            hop_length=self.winhop,
            power=2,
            center=True,
            pad=0,
            normalized=True,
        )

    def forward(self, wav: torch.Tensor) -> torch.Tensor:
        """Compute normalized chroma features, rearranged from 'b d t' to 'b t d'.

        When ``argmax`` is set, the result is a one-hot encoding of the strongest
        chroma bin of every frame.
        """
        length = wav.shape[-1]

        # in case we are getting a wav that was dropped out (nullified)
        # from the conditioner, make sure wav length is no less that nfft
        if length < self.nfft:
            missing = self.nfft - length
            left = missing // 2
            right = missing - left  # takes the extra sample when `missing` is odd
            wav = F.pad(wav, (left, right), 'constant', 0)
            assert wav.shape[-1] == self.nfft, f"expected len {self.nfft} but got {wav.shape[-1]}"

        spectrogram = self.spec(wav).squeeze(1)
        chroma = torch.einsum('cf,...ft->...ct', self.fbanks, spectrogram)
        chroma = F.normalize(chroma, p=self.norm, dim=-2, eps=1e-6)
        chroma = rearrange(chroma, 'b d t -> b t d')

        if self.argmax:
            strongest = chroma.argmax(-1, keepdim=True)
            chroma[:] = 0
            chroma.scatter_(dim=-1, index=strongest, value=1)

        return chroma
@dataclass
class Pattern:
    """Base implementation of a pattern over a sequence with multiple codebooks.

    The codebook pattern consists in a layout, defining for each sequence step
    the list of coordinates of each codebook timestep in the resulting interleaved sequence.
    The first item of the pattern is always an empty list in order to properly insert a special token
    to start with. For convenience, we also keep track of ``n_q`` the number of codebooks used for the pattern
    and ``timesteps`` the number of timesteps corresponding to the original sequence.

    The pattern provides convenient methods to build and revert interleaved sequences from it:
    ``build_pattern_sequence`` maps a given a dense input tensor of multi-codebook sequence from [B, K, T]
        to the interleaved sequence of shape [B, K, S] applying the pattern, with B being the batch size,
        K being the number of codebooks, T the number of original timesteps and S the number of sequence steps
        for the output sequence. The unfilled positions are replaced with a special token and the built sequence
        is returned along with a mask indicating valid tokens.
    ``revert_pattern_sequence`` maps back an interleaved sequence of shape [B, K, S] to the original alignment
        of codebooks across timesteps to an output tensor of shape [B, K, T], using again a special token and a mask
        to fill and specify invalid positions if needed.
    See the dedicated methods for more details.
    """
    # Pattern layout, for each sequence step, we have a list of coordinates
    # corresponding to the original codebook timestep and position.
    # The first list is always an empty list in order to properly insert
    # a special token to start with.
    layout: PatternLayout
    # Number of timesteps of the original (non-interleaved) sequence.
    timesteps: int
    # Number of codebooks.
    n_q: int

    def __post_init__(self):
        """Validate the layout and set up per-instance caches for the index builders."""
        assert len(self.layout) > 0
        assert self.layout[0] == []
        self._validate_layout()
        # The caches wrap the bound methods and are stored on the instance,
        # so each pattern owns its own cache of up to 100 entries per builder.
        self._build_reverted_sequence_scatter_indexes = lru_cache(100)(self._build_reverted_sequence_scatter_indexes)
        self._build_pattern_sequence_scatter_indexes = lru_cache(100)(self._build_pattern_sequence_scatter_indexes)
        logger.info("New pattern, time steps: %d, sequence steps: %d", self.timesteps, len(self.layout))

    def _validate_layout(self):
        """Runs checks on the layout to ensure a valid pattern is defined.
        A pattern is considered invalid if:
            - Multiple timesteps for a same codebook are defined in the same sequence step
            - The timesteps for a given codebook are not in ascending order as we advance in the sequence
              (this would mean that we have future timesteps before past timesteps).
        """
        # last timestep seen so far for each codebook
        q_timesteps = {q: 0 for q in range(self.n_q)}
        for s, seq_coords in enumerate(self.layout):
            if len(seq_coords) > 0:
                qs = set()
                for coord in seq_coords:
                    qs.add(coord.q)
                    last_q_timestep = q_timesteps[coord.q]
                    assert coord.t >= last_q_timestep, \
                        f"Past timesteps are found in the sequence for codebook = {coord.q} at step {s}"
                    q_timesteps[coord.q] = coord.t
                # each sequence step contains at max 1 coordinate per codebook
                assert len(qs) == len(seq_coords), \
                    f"Multiple entries for a same codebook are found at step {s}"

    @property
    def num_sequence_steps(self):
        """Number of layout steps, not counting the initial empty step."""
        return len(self.layout) - 1

    @property
    def max_delay(self):
        """Largest ``t + 1`` found in the layout (after the first step) minus ``timesteps``."""
        max_t_in_seq_coords = 0
        for seq_coords in self.layout[1:]:
            for coords in seq_coords:
                max_t_in_seq_coords = max(max_t_in_seq_coords, coords.t + 1)
        return max_t_in_seq_coords - self.timesteps

    @property
    def valid_layout(self):
        """Layout truncated to its first ``len(layout) - max_delay`` steps."""
        valid_step = len(self.layout) - self.max_delay
        return self.layout[:valid_step]

    def get_sequence_coords_with_timestep(self, t: int, q: tp.Optional[int] = None):
        """Get codebook coordinates in the layout that corresponds to the specified timestep t
        and optionally to the codebook q. Coordinates are returned as a tuple with the sequence step
        and the actual codebook coordinates.
        """
        assert t <= self.timesteps, "provided timesteps is greater than the pattern's number of timesteps"
        if q is not None:
            assert q <= self.n_q, "provided number of codebooks is greater than the pattern's number of codebooks"
        coords = []
        for s, seq_codes in enumerate(self.layout):
            for code in seq_codes:
                if code.t == t and (q is None or code.q == q):
                    coords.append((s, code))
        return coords

    def get_steps_with_timestep(self, t: int, q: tp.Optional[int] = None) -> tp.List[int]:
        """Sequence steps whose coordinates match timestep t (and codebook q when given)."""
        return [step for step, coords in self.get_sequence_coords_with_timestep(t, q)]

    def get_first_step_with_timesteps(self, t: int, q: tp.Optional[int] = None) -> tp.Optional[int]:
        """First sequence step matching timestep t (and codebook q when given), or None if there is none."""
        steps_with_timesteps = self.get_steps_with_timestep(t, q)
        return steps_with_timesteps[0] if len(steps_with_timesteps) > 0 else None

    def _build_pattern_sequence_scatter_indexes(self, timesteps: int, n_q: int, keep_only_valid_steps: bool,
                                                device: tp.Union[torch.device, str] = 'cpu'):
        """Build scatter indexes corresponding to the pattern, up to the provided sequence_steps.

        Args:
            timesteps (int): Maximum number of timesteps steps to consider.
            n_q (int): Number of codebooks, must match the pattern's ``n_q``.
            keep_only_valid_steps (bool): Restrict the pattern layout to match only valid steps.
            device (torch.device or str): Device for created tensors.
        Returns:
            indexes (torch.Tensor): Indexes corresponding to the sequence, of shape [K, S].
            mask (torch.Tensor): Mask corresponding to indexes that matches valid indexes, of shape [K, S].
        """
        assert n_q == self.n_q, f"invalid number of codebooks for the sequence and the pattern: {n_q} != {self.n_q}"
        assert timesteps <= self.timesteps, "invalid number of timesteps used to build the sequence from the pattern"
        # use the proper layout based on whether we limit ourselves to valid steps only or not,
        # note that using the valid_layout will result in a truncated sequence up to the valid steps
        ref_layout = self.valid_layout if keep_only_valid_steps else self.layout
        # single item indexing being super slow with pytorch vs. numpy, so we use numpy here
        indexes = torch.zeros(n_q, len(ref_layout), dtype=torch.long).numpy()
        mask = torch.zeros(n_q, len(ref_layout), dtype=torch.bool).numpy()
        # fill indexes with last sequence step value that will correspond to our special token
        # the last value is n_q * timesteps as we have flattened z and append special token as the last token
        # which will correspond to the index: n_q * timesteps
        indexes[:] = n_q * timesteps
        # iterate over the pattern and fill scattered indexes and mask
        for s, sequence_coords in enumerate(ref_layout):
            for coords in sequence_coords:
                if coords.t < timesteps:
                    # position of (q, t) in the [K, T] tensor flattened to K * T
                    indexes[coords.q, s] = coords.t + coords.q * timesteps
                    mask[coords.q, s] = 1
        indexes = torch.from_numpy(indexes).to(device)
        mask = torch.from_numpy(mask).to(device)
        return indexes, mask

    def build_pattern_sequence(self, z: torch.Tensor, special_token: int, keep_only_valid_steps: bool = False):
        """Build sequence corresponding to the pattern from the input tensor z.
        Non-pattern coordinates are filled with the special token.

        Args:
            z (torch.Tensor): Input tensor of multi-codebooks sequence, of shape [B, K, T].
            special_token (int): Special token used to fill non-pattern coordinates in the new sequence.
            keep_only_valid_steps (bool): Build a sequence from the pattern up to valid (= fully defined) steps.
                Steps that are beyond valid steps will be replaced by the special_token in that case.
        Returns:
            values (torch.Tensor): Interleaved sequence matching the pattern, of shape [B, K, S] with S
                the number of steps of the layout used (the valid layout when ``keep_only_valid_steps`` is set).
            indexes (torch.Tensor): Indexes corresponding to the interleaved sequence, of shape [K, S].
            mask (torch.Tensor): Mask corresponding to indexes that matches valid indexes of shape [K, S].
        """
        B, K, T = z.shape
        # the device is passed as a string to the cached index builder
        indexes, mask = self._build_pattern_sequence_scatter_indexes(
            T, K, keep_only_valid_steps=keep_only_valid_steps, device=str(z.device)
        )
        z = z.view(B, -1)
        # we append the special token as the last index of our flattened z tensor
        z = torch.cat([z, torch.zeros_like(z[:, :1]) + special_token], dim=1)
        values = z[:, indexes.view(-1)]
        values = values.view(B, K, indexes.shape[-1])
        return values, indexes, mask

    def _build_reverted_sequence_scatter_indexes(self, sequence_steps: int, n_q: int,
                                                 keep_only_valid_steps: bool = False,
                                                 is_model_output: bool = False,
                                                 device: tp.Union[torch.device, str] = 'cpu'):
        """Builds scatter indexes required to retrieve the original multi-codebook sequence
        from interleaving pattern.

        Args:
            sequence_steps (int): Sequence steps.
            n_q (int): Number of codebooks.
            keep_only_valid_steps (bool): Build a sequence from the pattern up to valid (= fully defined) steps.
                Steps that are beyond valid steps will be replaced by the special_token in that case.
            is_model_output (bool): Whether the sequence is a model output; in that case the initial
                (special token) step of the layout is skipped when computing the indexes.
            device (torch.device or str): Device for created tensors.
        Returns:
            indexes (torch.Tensor): Indexes for reconstructing the output, of shape [K, T].
            mask (torch.Tensor): Mask corresponding to indexes that matches valid indexes of shape [K, T].
        """
        ref_layout = self.valid_layout if keep_only_valid_steps else self.layout
        # TODO(jade): Do we want to further truncate to only valid timesteps here as well?
        timesteps = self.timesteps
        assert n_q == self.n_q, f"invalid number of codebooks for the sequence and the pattern: {n_q} != {self.n_q}"
        assert sequence_steps <= len(ref_layout), \
            f"sequence to revert is longer than the defined pattern: {sequence_steps} > {len(ref_layout)}"

        # ensure we take the appropriate indexes to keep the model output from the first special token as well
        if is_model_output:
            ref_layout = ref_layout[1:]

        # single item indexing being super slow with pytorch vs. numpy, so we use numpy here
        indexes = torch.zeros(n_q, timesteps, dtype=torch.long).numpy()
        mask = torch.zeros(n_q, timesteps, dtype=torch.bool).numpy()
        # fill indexes with last sequence step value that will correspond to our special token
        indexes[:] = n_q * sequence_steps
        for s, sequence_codes in enumerate(ref_layout):
            if s < sequence_steps:
                for code in sequence_codes:
                    if code.t < timesteps:
                        # position of (q, s) in the [K, S] tensor flattened to K * S
                        indexes[code.q, code.t] = s + code.q * sequence_steps
                        mask[code.q, code.t] = 1
        indexes = torch.from_numpy(indexes).to(device)
        mask = torch.from_numpy(mask).to(device)
        return indexes, mask

    def revert_pattern_sequence(self, s: torch.Tensor, special_token: int, keep_only_valid_steps: bool = False):
        """Revert a sequence built from the pattern back to the original multi-codebook sequence without interleaving.
        Non-pattern coordinates are filled with the special token.

        Args:
            s (torch.Tensor): Interleaved sequence tensor obtained from the pattern, of shape [B, K, S].
            special_token (int or float): Special token used to fill non-pattern coordinates in the new sequence.
            keep_only_valid_steps (bool): Use only the valid (= fully defined) steps of the layout.
        Returns:
            values (torch.Tensor): Reverted sequence in the original codebook alignment, of shape [B, K, T]
                with T the number of timesteps of the pattern.
            indexes (torch.Tensor): Indexes used to gather the reverted sequence, of shape [K, T].
            mask (torch.Tensor): Mask corresponding to indexes that matches valid indexes of shape [K, T].
        """
        B, K, S = s.shape
        indexes, mask = self._build_reverted_sequence_scatter_indexes(
            S, K, keep_only_valid_steps, is_model_output=False, device=str(s.device)
        )
        s = s.view(B, -1)
        # we append the special token as the last index of our flattened z tensor
        s = torch.cat([s, torch.zeros_like(s[:, :1]) + special_token], dim=1)
        values = s[:, indexes.view(-1)]
        values = values.view(B, K, indexes.shape[-1])
        return values, indexes, mask

    def revert_pattern_logits(self, logits: torch.Tensor, special_token: float, keep_only_valid_steps: bool = False):
        """Revert model logits obtained on a sequence built from the pattern
        back to a tensor matching the original sequence.

        This method is similar to ``revert_pattern_sequence`` with the following specificities:
        1. It is designed to work with the extra cardinality dimension
        2. We return the logits for the first sequence item that matches the special_token and
        which matching target in the original sequence is the first item of the sequence,
        while we skip the last logits as there is no matching target

        Args:
            logits (torch.Tensor): Model logits of shape [B, card, K, S].
            special_token (float): Value used to fill non-pattern coordinates.
            keep_only_valid_steps (bool): Use only the valid (= fully defined) steps of the layout.
        Returns:
            values (torch.Tensor): Reverted logits of shape [B, card, K, T].
            indexes (torch.Tensor): Indexes used to gather the reverted logits, of shape [K, T].
            mask (torch.Tensor): Mask corresponding to indexes that matches valid indexes of shape [K, T].
        """
        B, card, K, S = logits.shape
        indexes, mask = self._build_reverted_sequence_scatter_indexes(
            S, K, keep_only_valid_steps, is_model_output=True, device=logits.device
        )
        logits = logits.reshape(B, card, -1)
        # we append the special token as the last index of our flattened z tensor
        logits = torch.cat([logits, torch.zeros_like(logits[:, :, :1]) + special_token], dim=-1)  # [B, card, K x S]
        values = logits[:, :, indexes.view(-1)]
        values = values.view(B, card, K, indexes.shape[-1])
        return values, indexes, mask
class DelayedPatternProvider(CodebooksPatternProvider):
    """Pattern provider that shifts every codebook by its own delay.

    Each sequence step gathers, for every codebook, the timestep that is
    ``delay`` steps behind, so a step mixes codebooks from different timesteps.

    Example:
        Taking timesteps=4 and n_q=3, delays=None, the multi-codebook sequence:
        [[1, 2, 3, 4],
        [1, 2, 3, 4],
        [1, 2, 3, 4]]
        The resulting sequence obtained from the returned pattern is:
        [[S, 1, 2, 3, 4],
        [S, S, 1, 2, 3],
        [S, S, S, 1, 2]]
        (with S being a special token)

    Args:
        n_q (int): Number of codebooks.
        delays (list of int, optional): Delay for each of the codebooks, in ascending order.
            If delays not defined, each codebook is delayed by 1 compared to the previous one.
        flatten_first (int): Flatten the first N timesteps.
        empty_initial (int): Prepend with N empty list of coordinates.
    """

    def __init__(self, n_q: int, delays: tp.Optional[tp.List[int]] = None,
                 flatten_first: int = 0, empty_initial: int = 0):
        super().__init__(n_q)
        self.delays = list(range(n_q)) if delays is None else delays
        self.flatten_first = flatten_first
        self.empty_initial = empty_initial
        # One delay per codebook, given as a sorted list.
        assert len(self.delays) == self.n_q
        assert sorted(self.delays) == self.delays

    def get_pattern(self, timesteps: int) -> Pattern:
        # The layout always starts with an empty step (slot of the special token).
        layout: PatternLayout = [[]]
        layout.extend([] for _ in range(self.empty_initial))

        # Flattened prefix: one codebook per step for the first timesteps.
        flattened = min(timesteps, self.flatten_first)
        layout.extend(
            [LayoutCoord(t, q)]
            for t in range(flattened)
            for q in range(self.n_q)
        )

        # Delayed part: codebook q at step t holds timestep t - delay[q].
        last_step = timesteps + max(self.delays)
        for t in range(self.flatten_first, last_step):
            step = [
                LayoutCoord(t - delay, q)
                for q, delay in enumerate(self.delays)
                if t - delay >= self.flatten_first
            ]
            layout.append(step)
        return Pattern(layout, n_q=self.n_q, timesteps=timesteps)
The ``flattening`` parameter allows to specify the inner step for each of the codebook, allowing to define which codebook to flatten (or keep in parallel), for example taking n_q = 3, timesteps = 4 and flattening = [0, 1, 1]: [[1, 2, 3, 4], [1, 2, 3, 4], [1, 2, 3, 4]] will result into: [[S, 1, S, S, 2, S, S, 3, S, S, 4, S], [S, 1, S, S, 2, S, S, 3, S, S, 4, S], [1, S, S, 2, S, S, 3, S, S, 4, S, S]] 3. Flattening with delay. The ``delay`` parameter allows to further unroll the sequence of codebooks allowing to specify the delay per codebook. Note that the delay between codebooks flattened to the same inner timestep should be coherent. For example, taking n_q = 3, timesteps = 4, flattening = [0, 1, 1] and delays = [0, 3, 3]: [[1, 2, 3, 4], [1, 2, 3, 4], [1, 2, 3, 4]] will result into: [[S, S, S, 1, S, 2, S, 3, S, 4], [S, S, S, 1, S, 2, S, 3, S, 4], [1, 2, 3, S, 4, S, 5, S, 6, S]] Args: n_q (int): Number of codebooks. flattening (list of int, optional): Flattening schema over the codebooks. If not defined, the codebooks will be flattened to 1 codebook per step, meaning that the sequence will have n_q extra steps for each timestep. delays (list of int, optional): Delay for each of the codebooks. If not defined, no delay is added and therefore will default to [0] * ``n_q``. Note that two codebooks that will be flattened to the same inner step should have the same delay, otherwise the pattern is considered as invalid. 
""" FlattenedCodebook = namedtuple('FlattenedCodebook', ['codebooks', 'delay']) def __init__(self, n_q: int, flattening: tp.Optional[tp.List[int]] = None, delays: tp.Optional[tp.List[int]] = None): super().__init__(n_q) if flattening is None: flattening = list(range(n_q)) if delays is None: delays = [0] * n_q assert len(flattening) == n_q assert len(delays) == n_q assert sorted(flattening) == flattening assert sorted(delays) == delays self._flattened_codebooks = self._build_flattened_codebooks(delays, flattening) self.max_delay = max(delays) def _build_flattened_codebooks(self, delays: tp.List[int], flattening: tp.List[int]): """Build a flattened codebooks representation as a dictionary of inner step and the actual codebook indices corresponding to the flattened codebook. For convenience, we also store the delay associated to the flattened codebook to avoid maintaining an extra mapping. """ flattened_codebooks: dict = {} for q, (inner_step, delay) in enumerate(zip(flattening, delays)): if inner_step not in flattened_codebooks: flat_codebook = UnrolledPatternProvider.FlattenedCodebook(codebooks=[q], delay=delay) else: flat_codebook = flattened_codebooks[inner_step] assert flat_codebook.delay == delay, ( "Delay and flattening between codebooks is inconsistent: ", "two codebooks flattened to the same position should have the same delay." ) flat_codebook.codebooks.append(q) flattened_codebooks[inner_step] = flat_codebook return flattened_codebooks @property def _num_inner_steps(self): """Number of inner steps to unroll between timesteps in order to flatten the codebooks. """ return max([inner_step for inner_step in self._flattened_codebooks.keys()]) + 1 def num_virtual_steps(self, timesteps: int) -> int: return timesteps * self._num_inner_steps + 1 def get_pattern(self, timesteps: int) -> Pattern: """Builds pattern for delay across codebooks. Args: timesteps (int): Total number of timesteps. 
""" # the PatternLayout is built as a tuple of sequence position and list of coordinates # so that it can be reordered properly given the required delay between codebooks of given timesteps indexed_out: list = [(-1, [])] max_timesteps = timesteps + self.max_delay for t in range(max_timesteps): # for each timestep, we unroll the flattened codebooks, # emitting the sequence step with the corresponding delay for step in range(self._num_inner_steps): if step in self._flattened_codebooks: # we have codebooks at this virtual step to emit step_codebooks = self._flattened_codebooks[step] t_for_q = t + step_codebooks.delay coords = [LayoutCoord(t, q) for q in step_codebooks.codebooks] if t_for_q < max_timesteps and t < max_timesteps: indexed_out.append((t_for_q, coords)) else: # there is no codebook in this virtual step so we emit an empty list indexed_out.append((t, [])) out = [coords for _, coords in sorted(indexed_out)] return Pattern(out, n_q=self.n_q, timesteps=timesteps) class VALLEPattern(CodebooksPatternProvider): """Almost VALL-E style pattern. We further allow some delays for the codebooks other than the first one. Args: n_q (int): Number of codebooks. delays (list of int, optional): Delay for each of the codebooks. If delays not defined, each codebook is delayed by 1 compared to the previous one. 
""" def __init__(self, n_q: int, delays: tp.Optional[tp.List[int]] = None): super().__init__(n_q) if delays is None: delays = [0] * (n_q - 1) self.delays = delays assert len(self.delays) == self.n_q - 1 assert sorted(self.delays) == self.delays def get_pattern(self, timesteps: int) -> Pattern: out: PatternLayout = [[]] for t in range(timesteps): out.append([LayoutCoord(t, 0)]) max_delay = max(self.delays) for t in range(timesteps + max_delay): v = [] for q, delay in enumerate(self.delays): t_for_q = t - delay if t_for_q >= 0: v.append(LayoutCoord(t_for_q, q + 1)) out.append(v) return Pattern(out, n_q=self.n_q, timesteps=timesteps) class MusicLMPattern(CodebooksPatternProvider): """Almost MusicLM style pattern. This is equivalent to full flattening but in a different order. Args: n_q (int): Number of codebooks. group_by (int): Number of codebooks to group together. """ def __init__(self, n_q: int, group_by: int = 2): super().__init__(n_q) self.group_by = group_by def get_pattern(self, timesteps: int) -> Pattern: out: PatternLayout = [[]] for offset in range(0, self.n_q, self.group_by): for t in range(timesteps): for q in range(offset, offset + self.group_by): out.append([LayoutCoord(t, q)]) return Pattern(out, n_q=self.n_q, timesteps=timesteps) ================================================ FILE: audiocraft/modules/conditioners.py ================================================ # Copyright (c) Meta Platforms, Inc. and affiliates. # All rights reserved. # # This source code is licensed under the license found in the # LICENSE file in the root directory of this source tree. 
@dataclass
class ConditioningAttributes:
    """Container for the conditioning attributes of one sample, grouped by kind.

    Each kind (``text``, ``wav``, ``joint_embed``) is a dictionary mapping an
    attribute name to its condition. Kinds can also be accessed with item syntax,
    e.g. ``attributes["text"]``.
    """
    # attribute name -> optional text condition
    text: tp.Dict[str, tp.Optional[str]] = field(default_factory=dict)
    # attribute name -> waveform condition
    wav: tp.Dict[str, WavCondition] = field(default_factory=dict)
    # attribute name -> joint (wav + text) embedding condition
    joint_embed: tp.Dict[str, JointEmbedCondition] = field(default_factory=dict)

    def __getitem__(self, item):
        """Return the dictionary of conditions for the kind named ``item``."""
        return getattr(self, item)

    @property
    def text_attributes(self):
        """Names of the text attributes."""
        return self.text.keys()

    @property
    def wav_attributes(self):
        """Names of the wav attributes."""
        return self.wav.keys()

    @property
    def joint_embed_attributes(self):
        """Names of the joint embedding attributes."""
        return self.joint_embed.keys()

    @property
    def attributes(self):
        """Mapping from each kind to the names of its attributes."""
        return {
            "text": self.text_attributes,
            "wav": self.wav_attributes,
            "joint_embed": self.joint_embed_attributes,
        }

    def to_flat_dict(self):
        """Flatten all conditions into a single dict keyed by ``"<kind>.<attribute>"``."""
        return {
            **{f"text.{k}": v for k, v in self.text.items()},
            **{f"wav.{k}": v for k, v in self.wav.items()},
            **{f"joint_embed.{k}": v for k, v in self.joint_embed.items()}
        }

    @classmethod
    def from_flat_dict(cls, x):
        """Rebuild the attributes from a dict produced by ``to_flat_dict``.

        Args:
            x (dict): Mapping from ``"<kind>.<attribute>"`` to the condition value.
        Returns:
            ConditioningAttributes: Attributes populated from the flat dict.
        """
        out = cls()
        for k, v in x.items():
            # Split on the first dot only: the kind never contains a dot but the
            # attribute name may, and it must round-trip through to_flat_dict.
            kind, att = k.split(".", 1)
            out[kind][att] = v
        return out
class WhiteSpaceTokenizer(Tokenizer):
    """This tokenizer should be used for natural language descriptions.
    For example:
    ["he didn't, know he's going home.", 'shorter sentence'] =>
    [[78, 62, 31,  4, 78, 25, 19, 34],
    [59, 77,  0,  0,  0,  0,  0,  0]]

    Args:
        n_bins (int): Number of hashing bins, passed to ``hash_trick``.
        pad_idx (int): Index used for padding and for missing (None) texts.
        language (str): Name of the spaCy pipeline to load (downloaded if loading fails).
        lemma (bool): Whether to use the lemma of each token instead of its text.
        stopwords (bool): Whether to remove stop words.
    """
    PUNCTUATION = "?:!.,;"

    def __init__(self, n_bins: int, pad_idx: int = 0, language: str = "en_core_web_sm",
                 lemma: bool = True, stopwords: bool = True) -> None:
        self.n_bins = n_bins
        self.pad_idx = pad_idx
        self.lemma = lemma
        self.stopwords = stopwords
        try:
            self.nlp = spacy.load(language)
        except IOError:
            # the pipeline could not be loaded: download it once and retry
            spacy.cli.download(language)  # type: ignore
            self.nlp = spacy.load(language)

    @tp.no_type_check
    def __call__(self, texts: tp.List[tp.Optional[str]],
                 return_text: bool = False) -> tp.Tuple[torch.Tensor, torch.Tensor]:
        """Take a list of strings and convert them to a tensor of indices.

        Args:
            texts (list[str]): List of strings.
            return_text (bool, optional): Whether to return text as additional tuple item. Defaults to False.
        Returns:
            tuple[torch.Tensor, torch.Tensor]:
                - Indices of words in the LUT.
                - And a mask indicating where the padding tokens are
            When ``return_text`` is True, the normalized texts are returned as a third item.
        """
        output, lengths = [], []
        # work on a copy: normalized strings are written back into `texts`
        texts = deepcopy(texts)
        for i, text in enumerate(texts):
            # if current sample doesn't have a certain attribute, replace with pad token
            if text is None:
                output.append(torch.tensor([self.pad_idx], dtype=torch.long))
                lengths.append(0)
                continue

            # convert numbers to words
            text = re.sub(r"(\d+)", lambda x: num2words(int(x.group(0))), text)  # type: ignore
            # normalize text
            doc = self.nlp(text)  # type: ignore
            # remove stopwords
            if self.stopwords:
                doc = [w for w in doc if not w.is_stop]  # type: ignore
            # remove punctuation
            doc = [w for w in doc if w.text not in self.PUNCTUATION]  # type: ignore
            # lemmatize if needed
            words = [getattr(t, "lemma_" if self.lemma else "text") for t in doc]  # type: ignore

            texts[i] = " ".join(words)
            lengths.append(len(words))
            # convert to tensor; indices are kept as integers end-to-end, a float32
            # round-trip cannot represent every integer above 2**24 exactly
            tokens = torch.tensor([hash_trick(w, self.n_bins) for w in words], dtype=torch.long)
            output.append(tokens)

        mask = length_to_mask(torch.IntTensor(lengths)).int()
        padded_output = pad_sequence(output, padding_value=self.pad_idx).int().t()
        if return_text:
            return padded_output, mask, texts  # type: ignore
        return padded_output, mask
""" raise NotImplementedError() def forward(self, inputs: tp.Any) -> ConditionType: """Gets input that should be used as conditioning (e.g, genre, description or a waveform). Outputs a ConditionType, after the input data was embedded as a dense vector. Returns: ConditionType: - A tensor of size [B, T, D] where B is the batch size, T is the length of the output embedding and D is the dimension of the embedding. - And a mask indicating where the padding tokens. """ raise NotImplementedError() class TextConditioner(BaseConditioner): ... class LUTConditioner(TextConditioner): """Lookup table TextConditioner. Args: n_bins (int): Number of bins. dim (int): Hidden dim of the model (text-encoder/LUT). output_dim (int): Output dim of the conditioner. tokenizer (str): Name of the tokenizer. pad_idx (int, optional): Index for padding token. Defaults to 0. """ def __init__(self, n_bins: int, dim: int, output_dim: int, tokenizer: str, pad_idx: int = 0): super().__init__(dim, output_dim) self.embed = nn.Embedding(n_bins, dim) self.tokenizer: Tokenizer if tokenizer == 'whitespace': self.tokenizer = WhiteSpaceTokenizer(n_bins, pad_idx=pad_idx) elif tokenizer == 'noop': self.tokenizer = NoopTokenizer(n_bins, pad_idx=pad_idx) else: raise ValueError(f"unrecognized tokenizer `{tokenizer}`.") def tokenize(self, x: tp.List[tp.Optional[str]]) -> tp.Tuple[torch.Tensor, torch.Tensor]: device = self.embed.weight.device tokens, mask = self.tokenizer(x) tokens, mask = tokens.to(device), mask.to(device) return tokens, mask def forward(self, inputs: tp.Tuple[torch.Tensor, torch.Tensor]) -> ConditionType: tokens, mask = inputs embeds = self.embed(tokens) embeds = self.output_proj(embeds) embeds = (embeds * mask.unsqueeze(-1)) return embeds, mask class T5Conditioner(TextConditioner): """T5-based TextConditioner. Args: name (str): Name of the T5 model. output_dim (int): Output dim of the conditioner. finetune (bool): Whether to fine-tune T5 at train time. device (str): Device for T5 Conditioner. 
autocast_dtype (tp.Optional[str], optional): Autocast dtype. word_dropout (float, optional): Word dropout probability. normalize_text (bool, optional): Whether to apply text normalization. """ MODELS = ["t5-small", "t5-base", "t5-large", "t5-3b", "t5-11b", "google/flan-t5-small", "google/flan-t5-base", "google/flan-t5-large", "google/flan-t5-xl", "google/flan-t5-xxl"] MODELS_DIMS = { "t5-small": 512, "t5-base": 768, "t5-large": 1024, "t5-3b": 1024, "t5-11b": 1024, "google/flan-t5-small": 512, "google/flan-t5-base": 768, "google/flan-t5-large": 1024, "google/flan-t5-3b": 1024, "google/flan-t5-11b": 1024, } def __init__(self, name: str, output_dim: int, finetune: bool, device: str, autocast_dtype: tp.Optional[str] = 'float32', word_dropout: float = 0., normalize_text: bool = False): assert name in self.MODELS, f"Unrecognized t5 model name (should in {self.MODELS})" super().__init__(self.MODELS_DIMS[name], output_dim) self.device = device self.name = name self.finetune = finetune self.word_dropout = word_dropout if autocast_dtype is None or self.device == 'cpu': self.autocast = TorchAutocast(enabled=False) if self.device != 'cpu': logger.warning("T5 has no autocast, this might lead to NaN") else: dtype = getattr(torch, autocast_dtype) assert isinstance(dtype, torch.dtype) logger.info(f"T5 will be evaluated with autocast as {autocast_dtype}") self.autocast = TorchAutocast(enabled=True, device_type=self.device, dtype=dtype) # Let's disable logging temporarily because T5 will vomit some errors otherwise. 
# thanks https://gist.github.com/simon-weber/7853144 previous_level = logging.root.manager.disable logging.disable(logging.ERROR) with warnings.catch_warnings(): warnings.simplefilter("ignore") try: self.t5_tokenizer = T5Tokenizer.from_pretrained(name) t5 = T5EncoderModel.from_pretrained(name).train(mode=finetune) finally: logging.disable(previous_level) if finetune: self.t5 = t5 else: # this makes sure that the t5 models is not part # of the saved checkpoint self.__dict__['t5'] = t5.to(device) self.normalize_text = normalize_text if normalize_text: self.text_normalizer = WhiteSpaceTokenizer(1, lemma=True, stopwords=True) def tokenize(self, x: tp.List[tp.Optional[str]]) -> tp.Dict[str, torch.Tensor]: # if current sample doesn't have a certain attribute, replace with empty string entries: tp.List[str] = [xi if xi is not None else "" for xi in x] if self.normalize_text: _, _, entries = self.text_normalizer(entries, return_text=True) if self.word_dropout > 0. and self.training: new_entries = [] for entry in entries: words = [word for word in entry.split(" ") if random.random() >= self.word_dropout] new_entries.append(" ".join(words)) entries = new_entries empty_idx = torch.LongTensor([i for i, xi in enumerate(entries) if xi == ""]) inputs = self.t5_tokenizer(entries, return_tensors='pt', padding=True).to(self.device) mask = inputs['attention_mask'] mask[empty_idx, :] = 0 # zero-out index where the input is non-existant return inputs def forward(self, inputs: tp.Dict[str, torch.Tensor]) -> ConditionType: mask = inputs['attention_mask'] with torch.set_grad_enabled(self.finetune), self.autocast: embeds = self.t5(**inputs).last_hidden_state embeds = self.output_proj(embeds.to(self.output_proj.weight)) embeds = (embeds * mask.unsqueeze(-1)) return embeds, mask class WaveformConditioner(BaseConditioner): """Base class for all conditioners that take a waveform as input. 
Classes that inherit must implement `_get_wav_embedding` that outputs a continuous tensor, and `_downsampling_factor` that returns the down-sampling factor of the embedding model. Args: dim (int): The internal representation dimension. output_dim (int): Output dimension. device (tp.Union[torch.device, str]): Device. """ def __init__(self, dim: int, output_dim: int, device: tp.Union[torch.device, str]): super().__init__(dim, output_dim) self.device = device def tokenize(self, x: WavCondition) -> WavCondition: wav, length, sample_rate, path, seek_time = x assert length is not None return WavCondition(wav.to(self.device), length.to(self.device), sample_rate, path, seek_time) def _get_wav_embedding(self, x: WavCondition) -> torch.Tensor: """Gets as input a WavCondition and returns a dense embedding.""" raise NotImplementedError() def _downsampling_factor(self): """Returns the downsampling factor of the embedding model.""" raise NotImplementedError() def forward(self, x: WavCondition) -> ConditionType: """Extract condition embedding and mask from a waveform and its metadata. Args: x (WavCondition): Waveform condition containing raw waveform and metadata. Returns: ConditionType: a dense vector representing the conditioning along with its mask """ wav, lengths, *_ = x with torch.no_grad(): embeds = self._get_wav_embedding(x) embeds = embeds.to(self.output_proj.weight) embeds = self.output_proj(embeds) if lengths is not None: lengths = lengths / self._downsampling_factor() mask = length_to_mask(lengths, max_len=embeds.shape[1]).int() # type: ignore else: mask = torch.ones_like(embeds) embeds = (embeds * mask.unsqueeze(2).to(self.device)) return embeds, mask class ChromaStemConditioner(WaveformConditioner): """Chroma conditioner based on stems. The ChromaStemConditioner uses DEMUCS to first filter out drums and bass, as the drums and bass often dominate the chroma leading to the chroma features not containing information about the melody. 
Args: output_dim (int): Output dimension for the conditioner. sample_rate (int): Sample rate for the chroma extractor. n_chroma (int): Number of chroma bins for the chroma extractor. radix2_exp (int): Size of stft window for the chroma extractor (power of 2, e.g. 12 -> 2^12). duration (int): duration used during training. This is later used for correct padding in case we are using chroma as prefix. match_len_on_eval (bool, optional): if True then all chromas are padded to the training duration. Defaults to False. eval_wavs (str, optional): path to a dataset manifest with waveform, this waveforms are used as conditions during eval (for cases where we don't want to leak test conditions like MusicCaps). Defaults to None. n_eval_wavs (int, optional): limits the number of waveforms used for conditioning. Defaults to 0. device (tp.Union[torch.device, str], optional): Device for the conditioner. **kwargs: Additional parameters for the chroma extractor. """ def __init__(self, output_dim: int, sample_rate: int, n_chroma: int, radix2_exp: int, duration: float, match_len_on_eval: bool = True, eval_wavs: tp.Optional[str] = None, n_eval_wavs: int = 0, cache_path: tp.Optional[tp.Union[str, Path]] = None, device: tp.Union[torch.device, str] = 'cpu', **kwargs): from demucs import pretrained super().__init__(dim=n_chroma, output_dim=output_dim, device=device) self.autocast = TorchAutocast(enabled=device != 'cpu', device_type=self.device, dtype=torch.float32) self.sample_rate = sample_rate self.match_len_on_eval = match_len_on_eval self.duration = duration self.__dict__['demucs'] = pretrained.get_model('htdemucs').to(device) stem_sources: list = self.demucs.sources # type: ignore self.stem_indices = torch.LongTensor([stem_sources.index('vocals'), stem_sources.index('other')]).to(device) self.chroma = ChromaExtractor(sample_rate=sample_rate, n_chroma=n_chroma, radix2_exp=radix2_exp, **kwargs).to(device) self.chroma_len = self._get_chroma_len() self.eval_wavs: 
tp.Optional[torch.Tensor] = self._load_eval_wavs(eval_wavs, n_eval_wavs) self.cache = None if cache_path is not None: self.cache = EmbeddingCache(Path(cache_path) / 'wav', self.device, compute_embed_fn=self._get_full_chroma_for_cache, extract_embed_fn=self._extract_chroma_chunk) def _downsampling_factor(self) -> int: return self.chroma.winhop def _load_eval_wavs(self, path: tp.Optional[str], num_samples: int) -> tp.Optional[torch.Tensor]: """Load pre-defined waveforms from a json. These waveforms will be used for chroma extraction during evaluation. This is done to make the evaluation on MusicCaps fair (we shouldn't see the chromas of MusicCaps). """ if path is None: return None logger.info(f"Loading evaluation wavs from {path}") from audiocraft.data.audio_dataset import AudioDataset dataset: AudioDataset = AudioDataset.from_meta( path, segment_duration=self.duration, min_audio_duration=self.duration, sample_rate=self.sample_rate, channels=1) if len(dataset) > 0: eval_wavs = dataset.collater([dataset[i] for i in range(num_samples)]).to(self.device) logger.info(f"Using {len(eval_wavs)} evaluation wavs for chroma-stem conditioner") return eval_wavs else: raise ValueError("Could not find evaluation wavs, check lengths of wavs") def reset_eval_wavs(self, eval_wavs: tp.Optional[torch.Tensor]) -> None: self.eval_wavs = eval_wavs def has_eval_wavs(self) -> bool: return self.eval_wavs is not None def _sample_eval_wavs(self, num_samples: int) -> torch.Tensor: """Sample wavs from a predefined list.""" assert self.eval_wavs is not None, "Cannot sample eval wavs as no eval wavs provided." 
total_eval_wavs = len(self.eval_wavs) out = self.eval_wavs if num_samples > total_eval_wavs: out = self.eval_wavs.repeat(num_samples // total_eval_wavs + 1, 1, 1) return out[torch.randperm(len(out))][:num_samples] def _get_chroma_len(self) -> int: """Get length of chroma during training.""" dummy_wav = torch.zeros((1, int(self.sample_rate * self.duration)), device=self.device) dummy_chr = self.chroma(dummy_wav) return dummy_chr.shape[1] @torch.no_grad() def _get_stemmed_wav(self, wav: torch.Tensor, sample_rate: int) -> torch.Tensor: """Get parts of the wav that holds the melody, extracting the main stems from the wav.""" from demucs.apply import apply_model from demucs.audio import convert_audio with self.autocast: wav = convert_audio( wav, sample_rate, self.demucs.samplerate, self.demucs.audio_channels) # type: ignore stems = apply_model(self.demucs, wav, device=self.device) stems = stems[:, self.stem_indices] # extract relevant stems for melody conditioning mix_wav = stems.sum(1) # merge extracted stems to single waveform mix_wav = convert_audio(mix_wav, self.demucs.samplerate, self.sample_rate, 1) # type: ignore return mix_wav @torch.no_grad() def _extract_chroma(self, wav: torch.Tensor) -> torch.Tensor: """Extract chroma features from the waveform.""" with self.autocast: return self.chroma(wav) @torch.no_grad() def _compute_wav_embedding(self, wav: torch.Tensor, sample_rate: int) -> torch.Tensor: """Compute wav embedding, applying stem and chroma extraction.""" # avoid 0-size tensors when we are working with null conds if wav.shape[-1] == 1: return self._extract_chroma(wav) stems = self._get_stemmed_wav(wav, sample_rate) chroma = self._extract_chroma(stems) return chroma @torch.no_grad() def _get_full_chroma_for_cache(self, path: tp.Union[str, Path], x: WavCondition, idx: int) -> torch.Tensor: """Extract chroma from the whole audio waveform at the given path.""" wav, sr = audio_read(path) wav = wav[None].to(self.device) wav = convert_audio(wav, sr, 
self.sample_rate, to_channels=1) chroma = self._compute_wav_embedding(wav, self.sample_rate)[0] return chroma def _extract_chroma_chunk(self, full_chroma: torch.Tensor, x: WavCondition, idx: int) -> torch.Tensor: """Extract a chunk of chroma from the full chroma derived from the full waveform.""" wav_length = x.wav.shape[-1] seek_time = x.seek_time[idx] assert seek_time is not None, ( "WavCondition seek_time is required " "when extracting chroma chunks from pre-computed chroma.") full_chroma = full_chroma.float() frame_rate = self.sample_rate / self._downsampling_factor() target_length = int(frame_rate * wav_length / self.sample_rate) index = int(frame_rate * seek_time) out = full_chroma[index: index + target_length] out = F.pad(out[None], (0, 0, 0, target_length - out.shape[0]))[0] return out.to(self.device) @torch.no_grad() def _get_wav_embedding(self, x: WavCondition) -> torch.Tensor: """Get the wav embedding from the WavCondition. The conditioner will either extract the embedding on-the-fly computing it from the condition wav directly or will rely on the embedding cache to load the pre-computed embedding if relevant. """ sampled_wav: tp.Optional[torch.Tensor] = None if not self.training and self.eval_wavs is not None: warn_once(logger, "Using precomputed evaluation wavs!") sampled_wav = self._sample_eval_wavs(len(x.wav)) no_undefined_paths = all(p is not None for p in x.path) no_nullified_cond = x.wav.shape[-1] > 1 if sampled_wav is not None: chroma = self._compute_wav_embedding(sampled_wav, self.sample_rate) elif self.cache is not None and no_undefined_paths and no_nullified_cond: paths = [Path(p) for p in x.path if p is not None] chroma = self.cache.get_embed_from_cache(paths, x) else: assert all(sr == x.sample_rate[0] for sr in x.sample_rate), "All sample rates in batch should be equal." 
class JointEmbeddingConditioner(BaseConditioner):
    """Conditioner working on a joint embedding that can come from audio or text.

    Subclasses provide the actual embedding through `_get_embed`; this class
    optionally quantizes it with a residual vector quantizer, projects it and
    masks out the items that have no input.

    Args:
        dim (int): Dimension.
        output_dim (int): Output dimension.
        device (str): Device.
        attribute (str): Attribute used by the conditioner.
        autocast_dtype (str): Autocast for the conditioner.
        quantize (bool): Whether to quantize the CLAP embedding.
        n_q (int): Number of residual quantizers (used if quantize is true).
        bins (int): Quantizers' codebooks size (used if quantize is true).
        kwargs: Additional parameters for residual vector quantizer.
    """

    def __init__(self, dim: int, output_dim: int, device: str, attribute: str,
                 autocast_dtype: tp.Optional[str] = 'float32', quantize: bool = True,
                 n_q: int = 12, bins: int = 1024, **kwargs):
        super().__init__(dim=dim, output_dim=output_dim)
        self.device = device
        self.attribute = attribute
        # Autocast is only enabled when a dtype is requested and we are not on CPU.
        use_autocast = autocast_dtype is not None and device != 'cpu'
        if use_autocast:
            dtype = getattr(torch, autocast_dtype)
            assert isinstance(dtype, torch.dtype)
            logger.info(f"JointEmbeddingConditioner will be evaluated with autocast as {autocast_dtype}.")
            self.autocast = TorchAutocast(enabled=True, device_type=self.device, dtype=dtype)
        else:
            self.autocast = TorchAutocast(enabled=False)
            logger.warning("JointEmbeddingConditioner has no autocast, this might lead to NaN.")
        # Residual vector quantizer used to discretize the conditioning embedding (optional).
        self.quantizer: tp.Optional[ResidualVectorQuantizer] = (
            ResidualVectorQuantizer(dim, n_q=n_q, bins=bins, **kwargs) if quantize else None
        )

    def _get_embed(self, x: JointEmbedCondition) -> tp.Tuple[torch.Tensor, torch.Tensor]:
        """Get joint embedding in latent space from the inputs.

        Must be implemented by subclasses.

        Returns:
            tuple[torch.Tensor, torch.Tensor]: Tensor for the latent embedding
                and corresponding empty indexes.
        """
        raise NotImplementedError()

    def forward(self, x: JointEmbedCondition) -> ConditionType:
        """Embed the condition and return the projected embedding with its mask.

        Returns:
            ConditionType: `(embedding, mask)`; the embedding is viewed as
                `[-1, 1, output_dim]` and the mask holds one value per item
                (0 for the indexes reported as empty by `_get_embed`, 1 otherwise).
        """
        with self.autocast:
            embed, empty_idx = self._get_embed(x)
            if self.quantizer is None:
                latent = embed
            else:
                # The quantizer is fed a [-1, dim, 1] view, i.e. a single frame per item.
                quantized = self.quantizer(embed.view(-1, self.dim, 1), frame_rate=1)
                latent = quantized.x.view(-1, self.dim)
            projected = self.output_proj(latent).view(-1, 1, self.output_dim)
            n_items, n_steps = projected.shape[:2]
            mask = torch.ones(n_items, n_steps, device=projected.device)
            mask[empty_idx, :] = 0  # zero-out the indexes where the input is non-existent
            return projected * mask.unsqueeze(-1), mask

    def tokenize(self, x: JointEmbedCondition) -> JointEmbedCondition:
        """Return the condition unchanged: no tokenization happens at this level."""
        return x
This CLAP-based conditioner supports a caching mechanism over the computed embeddings for faster training. Args: dim (int): Dimension. output_dim (int): Output dimension. device (str): Device. attribute (str): Attribute used by the conditioner. quantize (bool): Whether to quantize the CLAP embedding. n_q (int): Number of residual quantizers (used if quantize is true). bins (int): Quantizers' codebooks size (used if quantize is true). checkpoint (str): Path to CLAP checkpoint. model_arch (str): CLAP model architecture. enable_fusion (bool): Enable fusion for CLAP model. sample_rate (int): Sample rate used by CLAP model. max_audio_length (float): Maximum audio length for CLAP model. audio_stride (float): Stride to use for getting a CLAP embedding on the full sequence. normalize (bool): Whether to normalize the CLAP embedding. text_p (float): Probability of using text representation instead of audio at train time. batch_size (Optional[int]): Batch size for CLAP embedding computation. autocast_dtype (str): Autocast for the conditioner. cache_path (Optional[str]): Path for pre-computed embeddings caching. kwargs: Additional parameters for residual vector quantizer. 
""" def __init__(self, dim: int, output_dim: int, device: str, attribute: str, quantize: bool, n_q: int, bins: int, checkpoint: tp.Union[str, Path], model_arch: str, enable_fusion: bool, sample_rate: int, max_audio_length: int, audio_stride: int, normalize: bool, text_p: bool, batch_size: tp.Optional[int] = None, autocast_dtype: tp.Optional[str] = 'float32', cache_path: tp.Optional[str] = None, **kwargs): try: import laion_clap # type: ignore except ImportError: raise ImportError("Please install CLAP to use the CLAPEmbeddingConditioner: 'pip install laion_clap'") checkpoint = AudioCraftEnvironment.resolve_reference_path(checkpoint) clap_tokenize = RobertaTokenizer.from_pretrained('roberta-base') clap_model = laion_clap.CLAP_Module(enable_fusion=enable_fusion, amodel=model_arch) load_clap_state_dict(clap_model, checkpoint) clap_model.eval() clap_model.to(device) super().__init__(dim=dim, output_dim=output_dim, device=device, attribute=attribute, autocast_dtype=autocast_dtype, quantize=quantize, n_q=n_q, bins=bins, **kwargs) self.checkpoint = checkpoint self.enable_fusion = enable_fusion self.model_arch = model_arch self.clap: laion_clap.CLAP_Module self.clap_tokenize: RobertaTokenizer self.clap_sample_rate = sample_rate self.clap_max_frames = int(self.clap_sample_rate * max_audio_length) self.clap_stride = int(self.clap_sample_rate * audio_stride) self.batch_size = batch_size or 1 self.normalize = normalize self.text_p = text_p self.__dict__['clap_tokenize'] = clap_tokenize self.__dict__['clap'] = clap_model self.wav_cache, self.text_cache = None, None if cache_path is not None: self.wav_cache = EmbeddingCache(Path(cache_path) / 'wav', self.device, compute_embed_fn=self._get_wav_embedding_for_cache, extract_embed_fn=self._extract_wav_embedding_chunk) self.text_cache = EmbeddingCache(Path(cache_path) / 'text', self.device, compute_embed_fn=self._get_text_embedding_for_cache) def _tokenizer(self, texts: tp.Union[str, tp.List[str]]) -> dict: # we use the default params 
def _compute_wav_embedding(self, wav: torch.Tensor, length: torch.Tensor,
                           sample_rates: tp.List[int], reduce_mean: bool = False) -> torch.Tensor:
    """Compute the CLAP audio embedding of a batch of waveforms.

    CLAP works on fixed-size inputs, so audio with at least `clap_max_frames`
    samples is cut into windows of `clap_max_frames` samples taken every
    `clap_stride` samples and each window is embedded separately. Shorter
    audio is embedded as a single window.

    Args:
        wav (torch.Tensor): Audio wav, of shape [B, C, T].
        length (torch.Tensor): Actual length of the audio for each item in the batch, of shape [B].
        sample_rates (list[int]): Sample rates for each sample in the batch.
        reduce_mean (bool): Whether to average the embeddings over the windows.
    Returns:
        torch.Tensor: Audio embedding of shape [B, F, D], F being the number of windows
            (F=1 if `reduce_mean` is True), D the dimension.
    """
    with torch.no_grad():
        mono = self._preprocess_wav(wav, length, sample_rates)  # [B, T]
        n_items, n_samples = mono.shape
        if n_samples < self.clap_max_frames:
            windows = mono.view(-1, 1, n_samples)  # [B, F, T] with F=1
        else:
            windows = mono.unfold(-1, self.clap_max_frames, self.clap_stride)  # [B, F, T]
        flat_windows = einops.rearrange(windows, 'b f t -> (b f) t')
        # Embed the windows `self.batch_size` at a time.
        chunk_embeds = [
            self.clap.get_audio_embedding_from_data(
                flat_windows[start:start + self.batch_size, ...], use_tensor=True)
            for start in range(0, flat_windows.size(0), self.batch_size)
        ]
        embed = einops.rearrange(torch.cat(chunk_embeds, dim=0), '(b f) d -> b f d', b=n_items)
    if reduce_mean:
        embed = embed.mean(dim=1, keepdim=True)
    return embed  # [B, F, D] with F=1 if reduce_mean is True
""" wav, sr = audio_read(path) # [C, T] wav = wav.unsqueeze(0).to(self.device) # [1, C, T] wav_len = torch.LongTensor([wav.shape[-1]]).to(self.device) embed = self._compute_wav_embedding(wav, wav_len, [sr], reduce_mean=False) # [B, F, D] return embed.squeeze(0) # [F, D] def _extract_wav_embedding_chunk(self, full_embed: torch.Tensor, x: JointEmbedCondition, idx: int) -> torch.Tensor: """Extract the chunk of embedding matching the seek_time and length from the full CLAP audio embedding. Args: full_embed (torch.Tensor): CLAP embedding computed on the full wave, of shape [F, D]. x (JointEmbedCondition): Joint embedding condition for the full batch. idx (int): Index considered for the given embedding to extract. Returns: torch.Tensor: Wav embedding averaged on sliding window, of shape [1, D]. """ sample_rate = x.sample_rate[idx] seek_time = x.seek_time[idx] seek_time = 0. if seek_time is None else seek_time clap_stride = int(self.clap_stride / self.clap_sample_rate) * sample_rate end_seek_time = seek_time + self.clap_max_frames / self.clap_sample_rate start_offset = int(seek_time * sample_rate // clap_stride) end_offset = int(end_seek_time * sample_rate // clap_stride) wav_embed = full_embed[start_offset:end_offset, ...] 
wav_embed = wav_embed.mean(dim=0, keepdim=True) return wav_embed.to(self.device) # [F, D] def _get_text_embedding(self, x: JointEmbedCondition) -> torch.Tensor: """Get CLAP embedding from a batch of text descriptions.""" no_nullified_cond = x.wav.shape[-1] > 1 # we don't want to read from cache when condition dropout if self.text_cache is not None and no_nullified_cond: assert all(p is not None for p in x.path), "Cache requires all JointEmbedCondition paths to be provided" paths = [Path(p) for p in x.path if p is not None] embed = self.text_cache.get_embed_from_cache(paths, x) else: text = [xi if xi is not None else "" for xi in x.text] embed = self._compute_text_embedding(text) if self.normalize: embed = torch.nn.functional.normalize(embed, p=2.0, dim=-1) return embed def _get_wav_embedding(self, x: JointEmbedCondition) -> torch.Tensor: """Get CLAP embedding from a batch of audio tensors (and corresponding sample rates).""" no_undefined_paths = all(p is not None for p in x.path) no_nullified_cond = x.wav.shape[-1] > 1 # we don't want to read from cache when condition dropout if self.wav_cache is not None and no_undefined_paths and no_nullified_cond: paths = [Path(p) for p in x.path if p is not None] embed = self.wav_cache.get_embed_from_cache(paths, x) else: embed = self._compute_wav_embedding(x.wav, x.length, x.sample_rate, reduce_mean=True) if self.normalize: embed = torch.nn.functional.normalize(embed, p=2.0, dim=-1) return embed def tokenize(self, x: JointEmbedCondition) -> JointEmbedCondition: # Trying to limit as much as possible sync points when the cache is warm. 
def dropout_condition(sample: ConditioningAttributes, condition_type: str, condition: str) -> ConditioningAttributes:
    """Utility function for nullifying an attribute inside a ConditioningAttributes object.

    If the condition is of type "wav", it is nullified using `nullify_wav`.
    If the condition is of type "joint_embed", it is nullified using `nullify_joint_embed`.
    If the condition is of type "text", its value is set to None.
    Works in-place.

    Args:
        sample (ConditioningAttributes): Sample to modify.
        condition_type (str): One of 'text', 'wav' or 'joint_embed'.
        condition (str): Name of the attribute to nullify; must exist for the given type.
    Returns:
        ConditioningAttributes: The same `sample` object, modified in place.
    Raises:
        ValueError: If `condition_type` is not supported or `condition` is not
            an attribute of that type in `sample`.
    """
    if condition_type not in ('text', 'wav', 'joint_embed'):
        raise ValueError(
            "dropout_condition got an unexpected condition type!"
            f" expected 'text', 'wav' or 'joint_embed' but got '{condition_type}'"
        )

    if condition not in getattr(sample, condition_type):
        # List the attributes of every supported type, joint_embed included.
        raise ValueError(
            "dropout_condition received an unexpected condition!"
            f" expected wav={sample.wav.keys()}, text={sample.text.keys()}"
            f" and joint_embed={sample.joint_embed.keys()}"
            f" but got '{condition}' of type '{condition_type}'!"
        )

    if condition_type == 'wav':
        wav_cond = sample.wav[condition]
        sample.wav[condition] = nullify_wav(wav_cond)
    elif condition_type == 'joint_embed':
        embed = sample.joint_embed[condition]
        sample.joint_embed[condition] = nullify_joint_embed(embed)
    else:
        sample.text[condition] = None

    return sample
""" if not self.training and not self.active_on_eval: return samples samples = deepcopy(samples) for condition_type, ps in self.p.items(): # for condition types [text, wav] for condition, p in ps.items(): # for attributes of each type (e.g., [artist, genre]) if torch.rand(1, generator=self.rng).item() < p: for sample in samples: dropout_condition(sample, condition_type, condition) return samples def __repr__(self): return f"AttributeDropout({dict(self.p)})" class ClassifierFreeGuidanceDropout(DropoutModule): """Classifier Free Guidance dropout. All attributes are dropped with the same probability. Args: p (float): Probability to apply condition dropout during training. seed (int): Random seed. """ def __init__(self, p: float, seed: int = 1234): super().__init__(seed=seed) self.p = p def forward(self, samples: tp.List[ConditioningAttributes]) -> tp.List[ConditioningAttributes]: """ Args: samples (list[ConditioningAttributes]): List of conditions. Returns: list[ConditioningAttributes]: List of conditions after all attributes were set to None. """ if not self.training: return samples # decide on which attributes to drop in a batched fashion drop = torch.rand(1, generator=self.rng).item() < self.p if not drop: return samples # nullify conditions of all attributes samples = deepcopy(samples) for condition_type in ["wav", "text"]: for sample in samples: for condition in sample.attributes[condition_type]: dropout_condition(sample, condition_type, condition) return samples def __repr__(self): return f"ClassifierFreeGuidanceDropout(p={self.p})" class ConditioningProvider(nn.Module): """Prepare and provide conditions given all the supported conditioners. Args: conditioners (dict): Dictionary of conditioners. device (torch.device or str, optional): Device for conditioners and output condition types. 
""" def __init__(self, conditioners: tp.Dict[str, BaseConditioner], device: tp.Union[torch.device, str] = "cpu"): super().__init__() self.device = device self.conditioners = nn.ModuleDict(conditioners) @property def joint_embed_conditions(self): return [m.attribute for m in self.conditioners.values() if isinstance(m, JointEmbeddingConditioner)] @property def has_joint_embed_conditions(self): return len(self.joint_embed_conditions) > 0 @property def text_conditions(self): return [k for k, v in self.conditioners.items() if isinstance(v, TextConditioner)] @property def wav_conditions(self): return [k for k, v in self.conditioners.items() if isinstance(v, WaveformConditioner)] @property def has_wav_condition(self): return len(self.wav_conditions) > 0 def tokenize(self, inputs: tp.List[ConditioningAttributes]) -> tp.Dict[str, tp.Any]: """Match attributes/wavs with existing conditioners in self, and compute tokenize them accordingly. This should be called before starting any real GPU work to avoid synchronization points. This will return a dict matching conditioner names to their arbitrary tokenized representations. Args: inputs (list[ConditioningAttributes]): List of ConditioningAttributes objects containing text and wav conditions. """ assert all([isinstance(x, ConditioningAttributes) for x in inputs]), ( "Got unexpected types input for conditioner! should be tp.List[ConditioningAttributes]", f" but types were {set([type(x) for x in inputs])}" ) output = {} text = self._collate_text(inputs) wavs = self._collate_wavs(inputs) joint_embeds = self._collate_joint_embeds(inputs) assert set(text.keys() | wavs.keys() | joint_embeds.keys()).issubset(set(self.conditioners.keys())), ( f"Got an unexpected attribute! 
Expected {self.conditioners.keys()}, ", f"got {text.keys(), wavs.keys(), joint_embeds.keys()}" ) for attribute, batch in chain(text.items(), wavs.items(), joint_embeds.items()): output[attribute] = self.conditioners[attribute].tokenize(batch) return output def forward(self, tokenized: tp.Dict[str, tp.Any]) -> tp.Dict[str, ConditionType]: """Compute pairs of `(embedding, mask)` using the configured conditioners and the tokenized representations. The output is for example: { "genre": (torch.Tensor([B, 1, D_genre]), torch.Tensor([B, 1])), "description": (torch.Tensor([B, T_desc, D_desc]), torch.Tensor([B, T_desc])), ... } Args: tokenized (dict): Dict of tokenized representations as returned by `tokenize()`. """ output = {} for attribute, inputs in tokenized.items(): condition, mask = self.conditioners[attribute](inputs) output[attribute] = (condition, mask) return output def _collate_text(self, samples: tp.List[ConditioningAttributes]) -> tp.Dict[str, tp.List[tp.Optional[str]]]: """Given a list of ConditioningAttributes objects, compile a dictionary where the keys are the attributes and the values are the aggregated input per attribute. For example: Input: [ ConditioningAttributes(text={"genre": "Rock", "description": "A rock song with a guitar solo"}, wav=...), ConditioningAttributes(text={"genre": "Hip-hop", "description": "A hip-hop verse"}, wav=...), ] Output: { "genre": ["Rock", "Hip-hop"], "description": ["A rock song with a guitar solo", "A hip-hop verse"] } Args: samples (list of ConditioningAttributes): List of ConditioningAttributes samples. Returns: dict[str, list[str, optional]]: A dictionary mapping an attribute name to text batch. 
""" out: tp.Dict[str, tp.List[tp.Optional[str]]] = defaultdict(list) texts = [x.text for x in samples] for text in texts: for condition in self.text_conditions: out[condition].append(text[condition]) return out def _collate_wavs(self, samples: tp.List[ConditioningAttributes]) -> tp.Dict[str, WavCondition]: """Generate a dict where the keys are attributes by which we fetch similar wavs, and the values are Tensors of wavs according to said attributes. *Note*: by the time the samples reach this function, each sample should have some waveform inside the "wav" attribute. It should be either: 1. A real waveform 2. A null waveform due to the sample having no similar waveforms (nullified by the dataset) 3. A null waveform due to it being dropped in a dropout module (nullified by dropout) Args: samples (list of ConditioningAttributes): List of ConditioningAttributes samples. Returns: dict[str, WavCondition]: A dictionary mapping an attribute name to wavs. """ wavs = defaultdict(list) lengths = defaultdict(list) sample_rates = defaultdict(list) paths = defaultdict(list) seek_times = defaultdict(list) out: tp.Dict[str, WavCondition] = {} for sample in samples: for attribute in self.wav_conditions: wav, length, sample_rate, path, seek_time = sample.wav[attribute] assert wav.dim() == 3, f"Got wav with dim={wav.dim()}, but expected 3 [1, C, T]" assert wav.size(0) == 1, f"Got wav [B, C, T] with shape={wav.shape}, but expected B == 1" # mono-channel conditioning wav = wav.mean(1, keepdim=True) # [1, 1, T] wavs[attribute].append(wav.flatten()) # [T] lengths[attribute].append(length) sample_rates[attribute].extend(sample_rate) paths[attribute].extend(path) seek_times[attribute].extend(seek_time) # stack all wavs to a single tensor for attribute in self.wav_conditions: stacked_wav, _ = collate(wavs[attribute], dim=0) out[attribute] = WavCondition( stacked_wav.unsqueeze(1), torch.cat(lengths[attribute]), sample_rates[attribute], paths[attribute], seek_times[attribute]) return out def 
_collate_joint_embeds(self, samples: tp.List[ConditioningAttributes]) -> tp.Dict[str, JointEmbedCondition]: """Generate a dict where the keys are attributes by which we compute joint embeddings, and the values are Tensors of pre-computed embeddings and the corresponding text attributes. Args: samples (list[ConditioningAttributes]): List of ConditioningAttributes samples. Returns: A dictionary mapping an attribute name to joint embeddings. """ texts = defaultdict(list) wavs = defaultdict(list) lengths = defaultdict(list) sample_rates = defaultdict(list) paths = defaultdict(list) seek_times = defaultdict(list) channels: int = 0 out = {} for sample in samples: for attribute in self.joint_embed_conditions: wav, text, length, sample_rate, path, seek_time = sample.joint_embed[attribute] assert wav.dim() == 3 if channels == 0: channels = wav.size(1) else: assert channels == wav.size(1), "not all audio has same number of channels in batch" assert wav.size(0) == 1, "Expecting single-wav batch in the collate method" wav = einops.rearrange(wav, "b c t -> (b c t)") # [1, C, T] => [C * T] wavs[attribute].append(wav) texts[attribute].extend(text) lengths[attribute].append(length) sample_rates[attribute].extend(sample_rate) paths[attribute].extend(path) seek_times[attribute].extend(seek_time) for attribute in self.joint_embed_conditions: stacked_texts = texts[attribute] stacked_paths = paths[attribute] stacked_seek_times = seek_times[attribute] stacked_wavs = pad_sequence(wavs[attribute]).to(self.device) stacked_wavs = einops.rearrange(stacked_wavs, "(c t) b -> b c t", c=channels) stacked_sample_rates = sample_rates[attribute] stacked_lengths = torch.cat(lengths[attribute]).to(self.device) assert stacked_lengths.size(0) == stacked_wavs.size(0) assert len(stacked_sample_rates) == stacked_wavs.size(0) assert len(stacked_texts) == stacked_wavs.size(0) out[attribute] = JointEmbedCondition( text=stacked_texts, wav=stacked_wavs, length=stacked_lengths, sample_rate=stacked_sample_rates, 
class ConditionFuser(StreamingModule):
    """Condition fuser handles the logic to combine the different conditions
    to the actual model input.

    Args:
        fuse2cond (tp.Dict[str, tp.List[str]]): A dictionary that says how to fuse
            each condition. For example:
            {
                "prepend": ["description"],
                "sum": ["genre", "bpm"],
                "cross": ["description"],
            }
        cross_attention_pos_emb (bool, optional): Use positional embeddings in cross attention.
        cross_attention_pos_emb_scale (float): Scale for positional embeddings in cross attention if used.
    """
    FUSING_METHODS = ["sum", "prepend", "cross", "input_interpolate"]

    def __init__(self, fuse2cond: tp.Dict[str, tp.List[str]], cross_attention_pos_emb: bool = False,
                 cross_attention_pos_emb_scale: float = 1.0):
        super().__init__()
        unknown_methods = [method for method in fuse2cond if method not in self.FUSING_METHODS]
        assert not unknown_methods, f"Got invalid fuse method, allowed methods: {self.FUSING_METHODS}"
        self.cross_attention_pos_emb = cross_attention_pos_emb
        self.cross_attention_pos_emb_scale = cross_attention_pos_emb_scale
        self.fuse2cond: tp.Dict[str, tp.List[str]] = fuse2cond
        # Reverse mapping: condition name -> fusing method (the last listed method wins).
        self.cond2fuse: tp.Dict[str, str] = {
            condition: fuse_method
            for fuse_method, conditions in fuse2cond.items()
            for condition in conditions
        }

    def forward(
        self,
        input: torch.Tensor,
        conditions: tp.Dict[str, ConditionType]
    ) -> tp.Tuple[torch.Tensor, tp.Optional[torch.Tensor]]:
        """Fuse the conditions to the provided model input.

        Note that the 'sum' and 'input_interpolate' methods add to `input` in place.

        Args:
            input (torch.Tensor): Transformer input.
            conditions (dict[str, ConditionType]): Dict of conditions.
        Returns:
            tuple[torch.Tensor, torch.Tensor]: The first tensor is the transformer input
                after the conditions have been fused. The second output tensor is the tensor
                used for cross-attention or None if no cross attention inputs exist.
        """
        _, T, _ = input.shape

        # The streaming state only holds offsets once a first step has been processed.
        first_step = 'offsets' not in self._streaming_state
        if first_step:
            offsets = torch.zeros(input.shape[0], dtype=torch.long, device=input.device)
        else:
            offsets = self._streaming_state['offsets']

        assert set(conditions.keys()).issubset(set(self.cond2fuse.keys())), \
            f"given conditions contain unknown attributes for fuser, " \
            f"expected {self.cond2fuse.keys()}, got {conditions.keys()}"

        cross_attention_output = None
        for cond_name, (cond, _cond_mask) in conditions.items():
            op = self.cond2fuse[cond_name]
            if op == 'cross':
                # All cross-attention conditions are concatenated along the time axis.
                if cross_attention_output is None:
                    cross_attention_output = cond
                else:
                    cross_attention_output = torch.cat([cross_attention_output, cond], dim=1)
            elif op == 'prepend':
                # Prepended conditions are only inserted on the first (streaming) step.
                if first_step:
                    input = torch.cat([cond, input], dim=1)
            elif op == 'sum':
                input += cond
            elif op == 'input_interpolate':
                # Resample the condition along time to the input length before adding it.
                channels_first = einops.rearrange(cond, "b t d -> b d t")
                resampled = F.interpolate(channels_first, size=input.shape[1])
                input += einops.rearrange(resampled, "b d t -> b t d")
            else:
                raise ValueError(f"unknown op ({op})")

        if self.cross_attention_pos_emb and cross_attention_output is not None:
            n_positions = cross_attention_output.shape[1]
            positions = torch.arange(n_positions, device=cross_attention_output.device).view(1, -1, 1)
            pos_emb = create_sin_embedding(positions, cross_attention_output.shape[-1])
            cross_attention_output = cross_attention_output + self.cross_attention_pos_emb_scale * pos_emb

        if self._is_streaming:
            self._streaming_state['offsets'] = offsets + T

        return input, cross_attention_output
def get_extra_padding_for_conv1d(x: torch.Tensor, kernel_size: int, stride: int,
                                 padding_total: int = 0) -> int:
    """Return the number of extra samples to append to the last dimension of `x`
    so that the last convolution window is full. See `pad_for_conv1d`.

    Args:
        x (torch.Tensor): Input whose last dimension is the one being convolved.
        kernel_size (int): Kernel size of the convolution.
        stride (int): Stride of the convolution.
        padding_total (int): Total fixed padding applied around the input.
    Returns:
        int: Amount of extra padding to add at the end.
    """
    current = x.shape[-1]
    # Possibly fractional number of frames, rounded up to the next full frame.
    frames = math.ceil((current - kernel_size + padding_total) / stride + 1)
    target = (frames - 1) * stride + (kernel_size - padding_total)
    return target - current
def pad1d(x: torch.Tensor, paddings: tp.Tuple[int, int], mode: str = 'constant', value: float = 0.):
    """Thin wrapper around `F.pad` that also supports reflect padding on inputs
    shorter than the requested padding.

    In reflect mode, when the input is not longer than the largest padding, zeros are
    first appended to the right so that the reflection is possible, and the same number
    of samples is removed from the end of the result.

    Args:
        x (torch.Tensor): Input, padded along its last dimension.
        paddings (tuple[int, int]): Non negative (left, right) padding.
        mode (str): Padding mode forwarded to `F.pad`.
        value (float): Fill value forwarded to `F.pad`.
    """
    left, right = paddings
    assert left >= 0 and right >= 0, (left, right)
    if mode != 'reflect':
        return F.pad(x, paddings, mode, value)

    size = x.shape[-1]
    largest = max(left, right)
    # Number of zeros temporarily added on the right, 0 when the input is long enough.
    filler = largest - size + 1 if size <= largest else 0
    if filler:
        x = F.pad(x, (0, filler))
    out = F.pad(x, paddings, mode, value)
    return out[..., :out.shape[-1] - filler]
class NormConvTranspose2d(nn.Module):
    """Wrapper around ConvTranspose2d and normalization applied to this conv
    to provide a uniform interface across normalization approaches.

    Args:
        *args: Positional arguments forwarded to `nn.ConvTranspose2d`.
        norm (str): Normalization method, one of `CONV_NORMALIZATIONS`.
        norm_kwargs (dict): Extra keyword arguments for the normalization module.
            It is only unpacked, never modified.
        **kwargs: Keyword arguments forwarded to `nn.ConvTranspose2d`.
    """
    def __init__(self, *args, norm: str = 'none', norm_kwargs: tp.Dict[str, tp.Any] = {}, **kwargs):
        super().__init__()
        self.convtr = apply_parametrization_norm(nn.ConvTranspose2d(*args, **kwargs), norm)
        self.norm = get_norm_module(self.convtr, causal=False, norm=norm, **norm_kwargs)
        # Expose the normalization name, like the other Norm* wrappers of this module do.
        self.norm_type = norm

    def forward(self, x):
        x = self.convtr(x)
        x = self.norm(x)
        return x
handling of asymmetric or causal padding and normalization. """ def __init__(self, in_channels: int, out_channels: int, kernel_size: int, stride: int = 1, causal: bool = False, norm: str = 'none', trim_right_ratio: float = 1., norm_kwargs: tp.Dict[str, tp.Any] = {}): super().__init__() self.convtr = NormConvTranspose1d(in_channels, out_channels, kernel_size, stride, causal=causal, norm=norm, norm_kwargs=norm_kwargs) self.causal = causal self.trim_right_ratio = trim_right_ratio assert self.causal or self.trim_right_ratio == 1., \ "`trim_right_ratio` != 1.0 only makes sense for causal convolutions" assert self.trim_right_ratio >= 0. and self.trim_right_ratio <= 1. def forward(self, x): kernel_size = self.convtr.convtr.kernel_size[0] stride = self.convtr.convtr.stride[0] padding_total = kernel_size - stride y = self.convtr(x) # We will only trim fixed padding. Extra padding from `pad_for_conv1d` would be # removed at the very end, when keeping only the right length for the output, # as removing it here would require also passing the length at the matching layer # in the encoder. if self.causal: # Trim the padding on the right according to the specified ratio # if trim_right_ratio = 1.0, trim everything from right padding_right = math.ceil(padding_total * self.trim_right_ratio) padding_left = padding_total - padding_right y = unpad1d(y, (padding_left, padding_right)) else: # Asymmetric padding required for odd strides padding_right = padding_total // 2 padding_left = padding_total - padding_right y = unpad1d(y, (padding_left, padding_right)) return y ================================================ FILE: audiocraft/modules/diffusion_schedule.py ================================================ # Copyright (c) Meta Platforms, Inc. and affiliates. # All rights reserved. # # This source code is licensed under the license found in the # LICENSE file in the root directory of this source tree. 
def betas_from_alpha_bar(alpha_bar):
    """Recover the per-step betas from the cumulative product of alphas.

    Args:
        alpha_bar (torch.Tensor): 1-dim tensor with the cumulative product of `1 - beta`.
    Returns:
        torch.Tensor: Tensor of betas with the same length, dtype and device as `alpha_bar`.
    """
    # The first alpha is alpha_bar[0] itself, the following ones are ratios of consecutive
    # values. Slicing (rather than building a new `torch.Tensor`) keeps dtype and device.
    alphas = torch.cat([alpha_bar[:1], alpha_bar[1:] / alpha_bar[:-1]])
    return 1 - alphas
class NoiseSchedule:
    """Noise schedule for diffusion.

    Args:
        beta_t0 (float): Variance of the first diffusion step.
        beta_t1 (float): Variance of the last diffusion step.
        beta_exp (float): Power schedule exponent.
        num_steps (int): Number of diffusion steps.
        variance (str): Choice of the sigma value for the denoising eq. Choices: "beta", "beta_tilde" or "none".
        clip (float): Clipping value for the denoising steps.
        rescale (float): Rescaling value to avoid vanishing signals, unused by default (i.e 1).
        device: Device on which the betas are created.
        repartition (str): Shape of the schedule, only the power schedule is supported.
        alpha_sigmoid (dict): Not used by this class.
        n_bands (int, optional): Must be None, multi-band schedules are rejected by the constructor.
        sample_processor (SampleProcessor): Module that normalizes data to better match the gaussian distribution.
        noise_scale (float): Scaling factor for the noise.
    """
    def __init__(self, beta_t0: float = 1e-4, beta_t1: float = 0.02, num_steps: int = 1000, variance: str = 'beta',
                 clip: float = 5., rescale: float = 1., device='cuda', beta_exp: float = 1,
                 repartition: str = "power", alpha_sigmoid: dict = {}, n_bands: tp.Optional[int] = None,
                 sample_processor: SampleProcessor = SampleProcessor(), noise_scale: float = 1.0, **kwargs):

        self.beta_t0 = beta_t0
        self.beta_t1 = beta_t1
        self.variance = variance
        self.num_steps = num_steps
        self.clip = clip
        self.sample_processor = sample_processor
        self.rescale = rescale
        self.n_bands = n_bands
        self.noise_scale = noise_scale
        assert n_bands is None
        if repartition == "power":
            # Linear interpolation in the `1 / beta_exp` power domain, then back to betas.
            self.betas = torch.linspace(beta_t0 ** (1 / beta_exp), beta_t1 ** (1 / beta_exp), num_steps,
                                        device=device, dtype=torch.float) ** beta_exp
        else:
            raise RuntimeError('Not implemented')
        # Dedicated, seeded generator used to sample the integer training steps.
        self.rng = random.Random(1234)

    def get_beta(self, step: tp.Union[int, torch.Tensor]):
        """Return the beta value(s) for the given step(s)."""
        if self.n_bands is None:
            return self.betas[step]
        else:
            return self.betas[:, step]  # [n_bands, len(step)]

    def get_initial_noise(self, x: torch.Tensor):
        """Return Gaussian noise to start the reverse process from, given a reference tensor `x`."""
        if self.n_bands is None:
            return torch.randn_like(x)
        # Keep the noise on the same device / dtype as the reference tensor.
        return torch.randn((x.size(0), self.n_bands, x.size(2)), device=x.device, dtype=x.dtype)

    def get_alpha_bar(self, step: tp.Optional[tp.Union[int, torch.Tensor]] = None) -> torch.Tensor:
        """Return 'alpha_bar', either for a given step, or as a tensor with its value for each step.

        Args:
            step (int or torch.Tensor, optional): If None, the value for every step is returned.
                If an int, a scalar tensor is returned (`step=-1` gives 1, the empty product).
                If a tensor of steps, the result has shape `[len(step), 1, 1]`.
        """
        if step is None:
            return (1 - self.betas).cumprod(dim=-1)  # works for single and multi bands
        if isinstance(step, int):
            return (1 - self.betas[:step + 1]).prod()
        else:
            return (1 - self.betas).cumprod(dim=0)[step].view(-1, 1, 1)

    def get_training_item(self, x: torch.Tensor, tensor_step: bool = False) -> TrainingItem:
        """Create a noisy data item for diffusion model training:

        Args:
            x (torch.Tensor): clean audio data torch.tensor(bs, 1, T)
            tensor_step (bool): If tensor_step = false, only one step t is sampled,
                the whole batch is diffused to the same step and t is an int.
                If tensor_step = true, t is a tensor of size (x.size(0),)
                and every element of the batch is diffused to an independently sampled step.
        Returns:
            TrainingItem: The noisy sample, the noise that was added and the sampled step.
        """
        step: tp.Union[int, torch.Tensor]
        if tensor_step:
            bs = x.size(0)
            step = torch.randint(0, self.num_steps, size=(bs,), device=x.device)
        else:
            step = self.rng.randrange(self.num_steps)
        alpha_bar = self.get_alpha_bar(step)  # scalar for an int step, [batch_size, 1, 1] for a tensor step

        x = self.sample_processor.project_sample(x)
        noise = torch.randn_like(x)
        noisy = (alpha_bar.sqrt() / self.rescale) * x + (1 - alpha_bar).sqrt() * noise * self.noise_scale
        return TrainingItem(noisy, noise, step)

    def generate(self, model: torch.nn.Module, initial: tp.Optional[torch.Tensor] = None,
                 condition: tp.Optional[torch.Tensor] = None, return_list: bool = False):
        """Full ddpm reverse process.

        Args:
            model (nn.Module): Diffusion model.
            initial (tensor): Initial Noise.
            condition (tensor): Input conditionning Tensor (e.g. encodec compressed representation).
            return_list (bool): Whether to return the whole process or only the sampled point.
        Returns:
            The list of iterates (moved to CPU, except the initial one) if `return_list` is set,
            otherwise the final sample passed through `sample_processor.return_sample`.
        """
        alpha_bar = self.get_alpha_bar(step=self.num_steps - 1)
        current = initial
        iterates = [initial]
        for step in range(self.num_steps)[::-1]:
            with torch.no_grad():
                estimate = model(current, step, condition=condition).sample
            alpha = 1 - self.betas[step]
            previous = (current - (1 - alpha) / (1 - alpha_bar).sqrt() * estimate) / alpha.sqrt()
            # For step == 0 this queries step -1, for which alpha_bar is 1.
            previous_alpha_bar = self.get_alpha_bar(step=step - 1)
            if step == 0:
                sigma2 = 0
            elif self.variance == 'beta':
                sigma2 = 1 - alpha
            elif self.variance == 'beta_tilde':
                sigma2 = (1 - previous_alpha_bar) / (1 - alpha_bar) * (1 - alpha)
            elif self.variance == 'none':
                sigma2 = 0
            else:
                raise ValueError(f'Invalid variance type {self.variance}')

            if sigma2 > 0:
                previous += sigma2**0.5 * torch.randn_like(previous) * self.noise_scale
            if self.clip:
                previous = previous.clamp(-self.clip, self.clip)
            current = previous
            alpha_bar = previous_alpha_bar
            if step == 0:
                # Undo the rescaling applied to the clean signal in `get_training_item`.
                previous *= self.rescale
            if return_list:
                iterates.append(previous.cpu())

        if return_list:
            return iterates
        else:
            return self.sample_processor.return_sample(previous)

    def generate_subsampled(self, model: torch.nn.Module, initial: torch.Tensor, step_list: tp.Optional[list] = None,
                            condition: tp.Optional[torch.Tensor] = None, return_list: bool = False):
        """Reverse process that only goes through Markov chain states in step_list.

        Args:
            model (nn.Module): Diffusion model.
            initial (tensor): Initial Noise.
            step_list (list, optional): Decreasing list of steps to go through. Defaults to
                every 50th step starting from the last one (`num_steps - 1`), followed by 0.
            condition (tensor): Input conditionning Tensor (e.g. encodec compressed representation).
            return_list (bool): Whether to return the whole process or only the sampled point.
        """
        if step_list is None:
            # Derived from `num_steps` so that the first step matches the `alpha_bar` used below
            # and always is a valid index into the betas.
            step_list = list(range(self.num_steps))[::-50] + [0]
        alpha_bar = self.get_alpha_bar(step=self.num_steps - 1)
        alpha_bars_subsampled = (1 - self.betas).cumprod(dim=0)[list(reversed(step_list))].cpu()
        betas_subsampled = betas_from_alpha_bar(alpha_bars_subsampled)
        current = initial * self.noise_scale
        iterates = [current]
        for idx, step in enumerate(step_list[:-1]):
            with torch.no_grad():
                estimate = model(current, step, condition=condition).sample * self.noise_scale
            # `betas_subsampled` follows the reversed step list, hence the indexing from the end.
            alpha = 1 - betas_subsampled[-1 - idx]
            previous = (current - (1 - alpha) / (1 - alpha_bar).sqrt() * estimate) / alpha.sqrt()
            previous_alpha_bar = self.get_alpha_bar(step_list[idx + 1])
            if step == step_list[-2]:
                # Last denoising step: no noise is added.
                sigma2 = 0
                previous_alpha_bar = torch.tensor(1.0)
            else:
                sigma2 = (1 - previous_alpha_bar) / (1 - alpha_bar) * (1 - alpha)
            if sigma2 > 0:
                previous += sigma2**0.5 * torch.randn_like(previous) * self.noise_scale
            if self.clip:
                previous = previous.clamp(-self.clip, self.clip)
            current = previous
            alpha_bar = previous_alpha_bar
            # NOTE(review): the loop skips the last entry of `step_list`, so this only triggers
            # when 0 appears earlier in the list; confirm whether `rescale` is meant to be
            # applied on the final iteration instead.
            if step == 0:
                previous *= self.rescale
            if return_list:
                iterates.append(previous.cpu())
        if return_list:
            return iterates
        else:
            return self.sample_processor.return_sample(previous)
class RotaryEmbedding(nn.Module):
    """Rotary positional embedding (RoPE) from [Su et al 2022](https://arxiv.org/abs/2104.09864).

    Args:
        dim (int): Embedding dimension (twice the number of frequencies).
        max_period (float): Maximum period of the rotation frequencies.
        xpos (bool): Use xPos, applies an exponential decay to rotation matrix.
        scale (float): Scale of positional embedding, set to 0 to deactivate.
        device (torch.device, optional): Device on which to initialize the module.
        dtype (torch.dtype): dtype to use to generate the embedding.
    """
    def __init__(self, dim: int, max_period: float = 10000.0, xpos: bool = False,
                 scale: float = 1.0, device=None, dtype: torch.dtype = torch.float32):
        super().__init__()
        assert dim % 2 == 0
        self.scale = scale
        assert dtype in [torch.float64, torch.float32]
        self.dtype = dtype

        adim = torch.arange(0, dim, 2, device=device, dtype=dtype)[: (dim // 2)]
        frequencies = 1.0 / (max_period ** (adim / dim))
        self.register_buffer("frequencies", frequencies)
        # Lazily built cache of complex rotations, one row per position. It is a plain
        # attribute (not a buffer), so it does not follow the module across devices.
        self.rotation: tp.Optional[torch.Tensor] = None

        self.xpos = XPos(dim, device=device, dtype=dtype) if xpos else None

    def get_rotation(self, start: int, end: int):
        """Create complex rotation tensor, cache values for fast computation.

        The cache is rebuilt when it is missing, too short, or when it lives on another
        device than the `frequencies` buffer (e.g. after the module has been moved).

        Args:
            start (int): First position (inclusive).
            end (int): Last position (exclusive).
        Returns:
            torch.Tensor: Complex rotations for positions `start` to `end`, of shape [T, C/2].
        """
        assert isinstance(self.frequencies, torch.Tensor)  # Satisfy type checker.
        stale = (
            self.rotation is None
            or end > self.rotation.shape[0]
            or self.rotation.device != self.frequencies.device
        )
        if stale:
            idx = torch.arange(end, device=self.frequencies.device, dtype=self.dtype)
            angles = torch.outer(idx, self.frequencies)
            self.rotation = torch.polar(torch.ones_like(angles), angles)
        assert self.rotation is not None  # Satisfy type checker.
        return self.rotation[start:end]

    def rotate(self, x: torch.Tensor, start: int = 0, invert_decay: bool = False):
        """Apply rope rotation to query or key tensor.

        Args:
            x (torch.Tensor): Tensor to rotate, time is expected on dimension 1 and the
                embedding (of even size) on the last dimension.
            start (int): Position of the first time step of `x`.
            invert_decay (bool): Use the inverse of the xPos decay (no effect without xPos).
        Returns:
            torch.Tensor: Rotated tensor, with the same shape and dtype as `x`.
        """
        T = x.shape[1]
        # Insert broadcasting dimensions for the batch (0) and for dimension 2 of `x`.
        rotation = self.get_rotation(start, start + T).unsqueeze(0).unsqueeze(2)

        if self.xpos is not None:
            decay = self.xpos.get_decay(start, start + T).unsqueeze(0).unsqueeze(2)
        else:
            decay = 1.0

        if invert_decay:
            decay = decay ** -1

        # Consecutive pairs of the last dimension are interpreted as complex numbers.
        x_complex = torch.view_as_complex(x.to(self.dtype).reshape(*x.shape[:-1], -1, 2))
        # Interpolate between the identity (scale == 0) and the full rotation (scale == 1).
        scaled_rotation = (rotation * decay) * self.scale + (1.0 - self.scale)
        x_out = torch.view_as_real(x_complex * scaled_rotation).flatten(-2)

        return x_out.type_as(x)

    def rotate_qk(self, query: torch.Tensor, key: torch.Tensor, start: int = 0):
        """ Apply rope rotation to both query and key tensors.
        Supports streaming mode, in which query and key are not expected to have the same shape.
        In streaming mode, key will be of length [P + C] with P the cached past timesteps, but
        query will be [C] (typically C == 1).

        Args:
            query (torch.Tensor): Query to rotate.
            key (torch.Tensor): Key to rotate.
            start (int): Start index of the sequence for time offset.
        Returns:
            tuple[torch.Tensor, torch.Tensor]: The rotated query and the rotated key.
        """
        query_timesteps = query.shape[1]
        key_timesteps = key.shape[1]
        # The query covers the last time steps of the key.
        streaming_offset = key_timesteps - query_timesteps

        query_out = self.rotate(query, start + streaming_offset)
        key_out = self.rotate(key, start, invert_decay=True)

        return query_out, key_out
""" def __init__(self, dim: int, kernel_sizes: tp.List[int] = [3, 1], dilations: tp.List[int] = [1, 1], activation: str = 'ELU', activation_params: dict = {'alpha': 1.0}, norm: str = 'none', norm_params: tp.Dict[str, tp.Any] = {}, causal: bool = False, pad_mode: str = 'reflect', compress: int = 2, true_skip: bool = True): super().__init__() assert len(kernel_sizes) == len(dilations), 'Number of kernel sizes should match number of dilations' act = getattr(nn, activation) hidden = dim // compress block = [] for i, (kernel_size, dilation) in enumerate(zip(kernel_sizes, dilations)): in_chs = dim if i == 0 else hidden out_chs = dim if i == len(kernel_sizes) - 1 else hidden block += [ act(**activation_params), StreamableConv1d(in_chs, out_chs, kernel_size=kernel_size, dilation=dilation, norm=norm, norm_kwargs=norm_params, causal=causal, pad_mode=pad_mode), ] self.block = nn.Sequential(*block) self.shortcut: nn.Module if true_skip: self.shortcut = nn.Identity() else: self.shortcut = StreamableConv1d(dim, dim, kernel_size=1, norm=norm, norm_kwargs=norm_params, causal=causal, pad_mode=pad_mode) def forward(self, x): return self.shortcut(x) + self.block(x) class SEANetEncoder(nn.Module): """SEANet encoder. Args: channels (int): Audio channels. dimension (int): Intermediate representation dimension. n_filters (int): Base width for the model. n_residual_layers (int): nb of residual layers. ratios (Sequence[int]): kernel size and stride ratios. The encoder uses downsampling ratios instead of upsampling ratios, hence it will use the ratios in the reverse order to the ones specified here that must match the decoder order. We use the decoder order as some models may only employ the decoder. activation (str): Activation function. activation_params (dict): Parameters to provide to the activation function. norm (str): Normalization method. norm_params (dict): Parameters to provide to the underlying normalization used along with the convolution. 
kernel_size (int): Kernel size for the initial convolution. last_kernel_size (int): Kernel size for the initial convolution. residual_kernel_size (int): Kernel size for the residual layers. dilation_base (int): How much to increase the dilation with each layer. causal (bool): Whether to use fully causal convolution. pad_mode (str): Padding mode for the convolutions. true_skip (bool): Whether to use true skip connection or a simple (streamable) convolution as the skip connection in the residual network blocks. compress (int): Reduced dimensionality in residual branches (from Demucs v3). lstm (int): Number of LSTM layers at the end of the encoder. disable_norm_outer_blocks (int): Number of blocks for which we don't apply norm. For the encoder, it corresponds to the N first blocks. """ def __init__(self, channels: int = 1, dimension: int = 128, n_filters: int = 32, n_residual_layers: int = 3, ratios: tp.List[int] = [8, 5, 4, 2], activation: str = 'ELU', activation_params: dict = {'alpha': 1.0}, norm: str = 'none', norm_params: tp.Dict[str, tp.Any] = {}, kernel_size: int = 7, last_kernel_size: int = 7, residual_kernel_size: int = 3, dilation_base: int = 2, causal: bool = False, pad_mode: str = 'reflect', true_skip: bool = True, compress: int = 2, lstm: int = 0, disable_norm_outer_blocks: int = 0): super().__init__() self.channels = channels self.dimension = dimension self.n_filters = n_filters self.ratios = list(reversed(ratios)) del ratios self.n_residual_layers = n_residual_layers self.hop_length = np.prod(self.ratios) self.n_blocks = len(self.ratios) + 2 # first and last conv + residual blocks self.disable_norm_outer_blocks = disable_norm_outer_blocks assert self.disable_norm_outer_blocks >= 0 and self.disable_norm_outer_blocks <= self.n_blocks, \ "Number of blocks for which to disable norm is invalid." \ "It should be lower or equal to the actual number of blocks in the network and greater or equal to 0." 
class SEANetDecoder(nn.Module):
    """SEANet decoder.

    Args:
        channels (int): Audio channels.
        dimension (int): Intermediate representation dimension.
        n_filters (int): Base width for the model.
        n_residual_layers (int): nb of residual layers.
        ratios (Sequence[int]): kernel size and stride ratios.
        activation (str): Activation function.
        activation_params (dict): Parameters to provide to the activation function.
        final_activation (str): Final activation function after all convolutions.
        final_activation_params (dict): Parameters to provide to the final activation function.
        norm (str): Normalization method.
        norm_params (dict): Parameters to provide to the underlying normalization used along with the convolution.
        kernel_size (int): Kernel size for the initial convolution.
        last_kernel_size (int): Kernel size for the final convolution.
        residual_kernel_size (int): Kernel size for the residual layers.
        dilation_base (int): How much to increase the dilation with each layer.
        causal (bool): Whether to use fully causal convolution.
        pad_mode (str): Padding mode for the convolutions.
        true_skip (bool): Whether to use true skip connection or a simple
            (streamable) convolution as the skip connection in the residual network blocks.
        compress (int): Reduced dimensionality in residual branches (from Demucs v3).
        lstm (int): Number of LSTM layers placed right after the initial convolution of the decoder.
        disable_norm_outer_blocks (int): Number of blocks for which we don't apply norm.
            For the decoder, it corresponds to the N last blocks.
        trim_right_ratio (float): Ratio for trimming at the right of the transposed convolution under the causal setup.
            If equal to 1.0, it means that all the trimming is done at the right.
    """
    def __init__(self, channels: int = 1, dimension: int = 128, n_filters: int = 32, n_residual_layers: int = 3,
                 ratios: tp.List[int] = [8, 5, 4, 2], activation: str = 'ELU', activation_params: dict = {'alpha': 1.0},
                 final_activation: tp.Optional[str] = None, final_activation_params: tp.Optional[dict] = None,
                 norm: str = 'none', norm_params: tp.Dict[str, tp.Any] = {}, kernel_size: int = 7,
                 last_kernel_size: int = 7, residual_kernel_size: int = 3, dilation_base: int = 2, causal: bool = False,
                 pad_mode: str = 'reflect', true_skip: bool = True, compress: int = 2, lstm: int = 0,
                 disable_norm_outer_blocks: int = 0, trim_right_ratio: float = 1.0):
        super().__init__()
        self.dimension = dimension
        self.channels = channels
        self.n_filters = n_filters
        # Store a copy: keeping a reference would expose the shared default list
        # (or the caller's list) to modifications made through the instance.
        self.ratios = list(ratios)
        del ratios
        self.n_residual_layers = n_residual_layers
        self.hop_length = np.prod(self.ratios)
        self.n_blocks = len(self.ratios) + 2  # first and last conv + residual blocks
        self.disable_norm_outer_blocks = disable_norm_outer_blocks
        assert self.disable_norm_outer_blocks >= 0 and self.disable_norm_outer_blocks <= self.n_blocks, \
            "Number of blocks for which to disable norm is invalid." \
            "It should be lower or equal to the actual number of blocks in the network and greater or equal to 0."

        act = getattr(nn, activation)
        # The width starts at its maximum and is halved after each upsampling block.
        mult = int(2 ** len(self.ratios))
        model: tp.List[nn.Module] = [
            StreamableConv1d(dimension, mult * n_filters, kernel_size,
                             norm='none' if self.disable_norm_outer_blocks == self.n_blocks else norm,
                             norm_kwargs=norm_params, causal=causal, pad_mode=pad_mode)
        ]

        if lstm:
            model += [StreamableLSTM(mult * n_filters, num_layers=lstm)]

        # Upsample to raw audio scale
        for i, ratio in enumerate(self.ratios):
            # Blocks are counted from the output side when deciding where norm is disabled.
            block_norm = 'none' if self.disable_norm_outer_blocks >= self.n_blocks - (i + 1) else norm
            # Add upsampling layers
            model += [
                act(**activation_params),
                StreamableConvTranspose1d(mult * n_filters, mult * n_filters // 2,
                                          kernel_size=ratio * 2, stride=ratio,
                                          norm=block_norm, norm_kwargs=norm_params,
                                          causal=causal, trim_right_ratio=trim_right_ratio),
            ]
            # Add residual layers
            for j in range(n_residual_layers):
                model += [
                    SEANetResnetBlock(mult * n_filters // 2, kernel_sizes=[residual_kernel_size, 1],
                                      dilations=[dilation_base ** j, 1],
                                      activation=activation, activation_params=activation_params,
                                      norm=block_norm, norm_params=norm_params, causal=causal,
                                      pad_mode=pad_mode, compress=compress, true_skip=true_skip)]

            mult //= 2

        # Add final layers
        model += [
            act(**activation_params),
            StreamableConv1d(n_filters, channels, last_kernel_size,
                             norm='none' if self.disable_norm_outer_blocks >= 1 else norm,
                             norm_kwargs=norm_params, causal=causal, pad_mode=pad_mode)
        ]
        # Add optional final activation to decoder (eg. tanh)
        if final_activation is not None:
            final_act = getattr(nn, final_activation)
            final_activation_params = final_activation_params or {}
            model += [
                final_act(**final_activation_params)
            ]
        self.model = nn.Sequential(*model)

    def forward(self, z):
        y = self.model(z)
        return y
# # This source code is licensed under the license found in the # LICENSE file in the root directory of this source tree. """ Streaming module API that should be implemented by all Streaming components, """ from contextlib import contextmanager import typing as tp from torch import nn import torch State = tp.Dict[str, torch.Tensor] class StreamingModule(nn.Module): """Common API for streaming components. Each streaming component has a streaming state, which is just a dict[str, Tensor]. By convention, the first dim of each tensor must be the batch size. Don't use dots in the key names, as this would clash with submodules (like in state_dict). If `self._is_streaming` is True, the component should use and remember the proper state inside `self._streaming_state`. To set a streaming component in streaming state, use with module.streaming(): ... This will automatically reset the streaming state when exiting the context manager. This also automatically propagates to all streaming children module. Some module might also implement the `StreamingModule.flush` method, although this one is trickier, as all parents module must be StreamingModule and implement it as well for it to work properly. See `StreamingSequential` after. """ def __init__(self) -> None: super().__init__() self._streaming_state: State = {} self._is_streaming = False def _apply_named_streaming(self, fn: tp.Any): for name, module in self.named_modules(): if isinstance(module, StreamingModule): fn(name, module) def _set_streaming(self, streaming: bool): def _set_streaming(name, module): module._is_streaming = streaming self._apply_named_streaming(_set_streaming) @contextmanager def streaming(self): """Context manager to enter streaming mode. 
Reset streaming state on exit.""" self._set_streaming(True) try: yield finally: self._set_streaming(False) self.reset_streaming() def reset_streaming(self): """Reset the streaming state.""" def _reset(name: str, module: StreamingModule): module._streaming_state.clear() self._apply_named_streaming(_reset) def get_streaming_state(self) -> State: """Return the streaming state, including that of sub-modules.""" state: State = {} def _add(name: str, module: StreamingModule): if name: name += "." for key, value in module._streaming_state.items(): state[name + key] = value self._apply_named_streaming(_add) return state def set_streaming_state(self, state: State): """Set the streaming state, including that of sub-modules.""" state = dict(state) def _set(name: str, module: StreamingModule): if name: name += "." module._streaming_state.clear() for key, value in list(state.items()): # complexity is not ideal here, but probably fine. if key.startswith(name): local_key = key[len(name):] if '.' not in local_key: module._streaming_state[local_key] = value del state[key] self._apply_named_streaming(_set) assert len(state) == 0, list(state.keys()) def flush(self, x: tp.Optional[torch.Tensor] = None): """Flush any remaining outputs that were waiting for completion. Typically, for convolutions, this will add the final padding and process the last buffer. This should take an optional argument `x`, which will be provided if a module before this one in the streaming pipeline has already spitted out a flushed out buffer. """ if x is None: return None else: return self(x) class StreamingSequential(StreamingModule, nn.Sequential): """A streaming compatible alternative of `nn.Sequential`. 
""" def flush(self, x: tp.Optional[torch.Tensor] = None): for module in self: if isinstance(module, StreamingModule): x = module.flush(x) elif x is not None: x = module(x) return x ================================================ FILE: audiocraft/modules/transformer.py ================================================ # Copyright (c) Meta Platforms, Inc. and affiliates. # All rights reserved. # # This source code is licensed under the license found in the # LICENSE file in the root directory of this source tree. """ Transformer model, with streaming support, xformer attention support and easy causal attention with a potentially finite receptive field. See `StreamingTransformer` for more information. Unlike regular PyTorch Transformer, we make the hard choice that batches are first. """ import typing as tp from einops import rearrange import torch import torch.nn as nn from torch.nn import functional as F from torch.utils.checkpoint import checkpoint as torch_checkpoint from xformers import ops from .rope import RotaryEmbedding from .streaming import StreamingModule _efficient_attention_backend: str = 'torch' def set_efficient_attention_backend(backend: str = 'torch'): # Using torch by default, it seems a bit faster on older P100 GPUs (~20% faster). global _efficient_attention_backend assert _efficient_attention_backend in ['xformers', 'torch'] _efficient_attention_backend = backend def _get_attention_time_dimension() -> int: if _efficient_attention_backend == 'torch': return 2 else: return 1 def _is_profiled() -> bool: # Return true if we are currently running with a xformers profiler activated. try: from xformers.profiler import profiler except ImportError: return False return profiler._Profiler._CURRENT_PROFILER is not None def create_norm_fn(norm_type: str, dim: int, **kwargs) -> nn.Module: """Create normalization module for transformer encoder layer. Args: norm_type (str): Normalization method. dim (int): Dimension of the normalized layer. 
**kwargs (dict): Additional parameters for normalization layer. Returns: nn.Module: Normalization module. """ if norm_type == 'layer_norm': return nn.LayerNorm(dim, eps=1e-5, **kwargs) else: raise ValueError(f"Unknown norm type: {norm_type}") def create_sin_embedding(positions: torch.Tensor, dim: int, max_period: float = 10000, dtype: torch.dtype = torch.float32) -> torch.Tensor: """Create sinusoidal positional embedding, with shape `[B, T, C]`. Args: positions (torch.Tensor): LongTensor of positions. dim (int): Dimension of the embedding. max_period (float): Maximum period of the cosine/sine functions. dtype (torch.dtype or str): dtype to use to generate the embedding. Returns: torch.Tensor: Sinusoidal positional embedding. """ # We aim for BTC format assert dim % 2 == 0 half_dim = dim // 2 positions = positions.to(dtype) adim = torch.arange(half_dim, device=positions.device, dtype=dtype).view(1, 1, -1) max_period_tensor = torch.full([], max_period, device=positions.device, dtype=dtype) # avoid sync point phase = positions / (max_period_tensor ** (adim / (half_dim - 1))) return torch.cat([torch.cos(phase), torch.sin(phase)], dim=-1) def expand_repeated_kv(x: torch.Tensor, n_rep: int) -> torch.Tensor: """torch.repeat_interleave(x, dim=2, repeats=n_rep) from xlformers.""" if n_rep == 1: return x if _efficient_attention_backend == 'torch': bs, n_kv_heads, slen, head_dim = x.shape return ( x[:, :, None, :, :] .expand(bs, n_kv_heads, n_rep, slen, head_dim) .reshape(bs, n_kv_heads * n_rep, slen, head_dim) ) else: bs, slen, n_kv_heads, head_dim = x.shape return ( x[:, :, :, None, :] .expand(bs, slen, n_kv_heads, n_rep, head_dim) .reshape(bs, slen, n_kv_heads * n_rep, head_dim) ) class LayerScale(nn.Module): """Layer scale from [Touvron et al 2021] (https://arxiv.org/pdf/2103.17239.pdf). This rescales diagonally the residual outputs close to 0, with a learnt scale. Args: channels (int): Number of channels. init (float): Initial scale. 
channel_last (bool): If True, expect `[*, C]` shaped tensors, otherwise, `[*, C, T]`. device (torch.device or str, optional): Device on which to initialize the module. dtype (torch.dtype, optional): dtype to use to initialize the module. """ def __init__(self, channels: int, init: float = 1e-4, channel_last: bool = True, device=None, dtype=None): super().__init__() self.channel_last = channel_last self.scale = nn.Parameter( torch.full((channels,), init, requires_grad=True, device=device, dtype=dtype)) def forward(self, x: torch.Tensor): if self.channel_last: return self.scale * x else: return self.scale[:, None] * x class StreamingMultiheadAttention(StreamingModule): """Similar to `nn.MultiheadAttention` but with support for streaming, causal evaluation. Args: embed_dim (int): Dimension to project to. num_heads (int): Number of heads. dropout (float): Dropout level. bias (bool): Use bias in projections. causal (bool): Causal mask applied automatically. past_context (int, optional): Receptive field for the causal mask, infinite if None. custom (bool): Use custom MHA implementation, for testing / benchmarking. memory_efficient (bool): Use xformers based memory efficient attention. attention_as_float32 (bool): Perform the attention as float32 (especially important with memory_efficient as autocast won't do this automatically). rope (`RotaryEmbedding`, optional): Rope embedding to use. cross_attention: Should be true when used as a cross attention. All keys and values must be available at once, streaming is only for the queries. Cannot be used with `causal` or `rope` (as it wouldn't make sens to interpret the time steps in the keys relative to those in the queries). safe_streaming (bool): Bug fix, will go away with xformers update. qk_layer_norm (bool): Layer normalization applied to queries and keys before dot product. kv_repeat (int): If > 1, will repeat keys and queries multiple times (need to divide num_heads). 
This will lead to faster decoding time on A100 or other GPUs with tensorcore. device (torch.device, optional): Device on which to initialize. dtype (torch.dtype, optional): dtype to use. """ def __init__(self, embed_dim: int, num_heads: int, dropout: float = 0.0, bias: bool = True, causal: bool = False, past_context: tp.Optional[int] = None, custom: bool = False, memory_efficient: bool = False, attention_as_float32: bool = False, rope: tp.Optional[RotaryEmbedding] = None, cross_attention: bool = False, safe_streaming: bool = True, qk_layer_norm: bool = False, kv_repeat: int = 1, device=None, dtype=None): super().__init__() factory_kwargs = {'device': device, 'dtype': dtype} if past_context is not None: assert causal self.embed_dim = embed_dim self.causal = causal self.past_context = past_context self.memory_efficient = memory_efficient self.attention_as_float32 = attention_as_float32 self.rope = rope self.cross_attention = cross_attention self.safe_streaming = safe_streaming self.num_heads = num_heads self.dropout = dropout self.kv_repeat = kv_repeat if cross_attention: assert not causal, "Causal cannot work with cross attention." assert rope is None, "Rope cannot work with cross attention." if memory_efficient: _verify_xformers_memory_efficient_compat() self.custom = _is_custom(custom, memory_efficient) if self.custom: out_dim = embed_dim assert num_heads % kv_repeat == 0 assert not cross_attention or kv_repeat == 1 num_kv = num_heads // kv_repeat kv_dim = (embed_dim // num_heads) * num_kv out_dim += 2 * kv_dim in_proj = nn.Linear(embed_dim, out_dim, bias=bias, **factory_kwargs) # We try to follow the default PyTorch MHA convention, to easily compare results. 
self.in_proj_weight = in_proj.weight self.in_proj_bias = in_proj.bias if bias: self.in_proj_bias.data.zero_() # Following Pytorch convention self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias, **factory_kwargs) if bias: self.out_proj.bias.data.zero_() else: assert not qk_layer_norm assert kv_repeat == 1 self.mha = nn.MultiheadAttention( embed_dim, num_heads, dropout=dropout, bias=bias, batch_first=True, **factory_kwargs) self.qk_layer_norm = qk_layer_norm if qk_layer_norm: assert self.custom assert kv_repeat == 1 ln_dim = embed_dim self.q_layer_norm = nn.LayerNorm(ln_dim) self.k_layer_norm = nn.LayerNorm(ln_dim) def _load_from_state_dict(self, state_dict, prefix, *args, **kwargs): if not self.custom: # Support compat with regular MHA keys = [n for n, _ in self.mha.named_parameters()] for key in keys: if prefix + key in state_dict: state_dict[prefix + "mha." + key] = state_dict.pop(prefix + key) super()._load_from_state_dict(state_dict, prefix, *args, **kwargs) def _get_mask(self, current_steps: int, device: torch.device, dtype: torch.dtype): # Return a causal mask, accounting for potentially stored past keys/values # We actually return a bias for the attention score, as this has the same # convention both in the builtin MHA in Pytorch, and Xformers functions. time_dim = _get_attention_time_dimension() if self.memory_efficient: from xformers.ops import LowerTriangularMask if current_steps == 1: # If we only have one step, then we do not need a mask. 
return None elif 'past_keys' in self._streaming_state: raise RuntimeError("Not supported at the moment") else: # Then we can safely use a lower triangular mask return LowerTriangularMask() if self._streaming_state: past_keys = self._streaming_state['past_keys'] past_steps = past_keys.shape[time_dim] else: past_steps = 0 queries_pos = torch.arange( past_steps, current_steps + past_steps, device=device).view(-1, 1) keys_pos = torch.arange(past_steps + current_steps, device=device).view(1, -1) delta = queries_pos - keys_pos valid = delta >= 0 if self.past_context is not None: valid &= (delta <= self.past_context) return torch.where( valid, torch.zeros([], device=device, dtype=dtype), torch.full([], float('-inf'), device=device, dtype=dtype)) def _complete_kv(self, k, v): time_dim = _get_attention_time_dimension() if self.cross_attention: # With cross attention we assume all keys and values # are already available, and streaming is with respect # to the queries only. return k, v # Complete the key/value pair using the streaming state. if self._streaming_state: pk = self._streaming_state['past_keys'] nk = torch.cat([pk, k], dim=time_dim) if v is k: nv = nk else: pv = self._streaming_state['past_values'] nv = torch.cat([pv, v], dim=time_dim) else: nk = k nv = v assert nk.shape[time_dim] == nv.shape[time_dim] offset = 0 if self.past_context is not None: offset = max(0, nk.shape[time_dim] - self.past_context) if self._is_streaming: self._streaming_state['past_keys'] = nk[:, offset:] if v is not k: self._streaming_state['past_values'] = nv[:, offset:] if 'offset' in self._streaming_state: self._streaming_state['offset'] += offset else: self._streaming_state['offset'] = torch.tensor(0) return nk, nv def _apply_rope(self, query: torch.Tensor, key: torch.Tensor): # TODO: fix and verify layout. assert _efficient_attention_backend == 'xformers', "Rope not supported with torch attn." # Apply rope embeddings to query and key tensors. 
assert self.rope is not None if 'past_keys' in self._streaming_state: past_keys_offset = self._streaming_state['past_keys'].shape[1] else: past_keys_offset = 0 if 'offset' in self._streaming_state: past_context_offset = int(self._streaming_state['offset'].item()) else: past_context_offset = 0 streaming_offset = past_context_offset + past_keys_offset return self.rope.rotate_qk(query, key, start=streaming_offset) def forward(self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, key_padding_mask=None, need_weights=False, attn_mask=None, average_attn_weights=True, is_causal=False): assert attn_mask is None assert not is_causal, ("New param added in torch 2.0.1 not supported, " "use the causal args in the constructor.") time_dim = _get_attention_time_dimension() if time_dim == 2: layout = "b h t d" else: layout = "b t h d" dtype = query.dtype if self._is_streaming: assert self.causal or self.cross_attention, \ "Streaming only available for causal or cross attention" if self.causal: # At the moment we specialize only for the self-attention case. assert query.shape[1] == key.shape[1], "Causal only for same length query / key / value" assert value.shape[1] == key.shape[1], "Causal only for same length query / key / value" attn_mask = self._get_mask(query.shape[1], query.device, query.dtype) if self.custom: # custom implementation assert need_weights is False assert key_padding_mask is None if self.cross_attention: # Different queries, keys, values, we have to spit manually the weights # before applying the linear. dim = self.in_proj_weight.shape[0] // 3 if self.in_proj_bias is None: bias_q, bias_k, bias_v = None, None, None else: bias_q = self.in_proj_bias[:dim] bias_k = self.in_proj_bias[dim: 2 * dim] bias_v = self.in_proj_bias[2 * dim:] q = nn.functional.linear(query, self.in_proj_weight[:dim], bias_q) # todo: when streaming, we could actually save k, v and check the shape actually match. 
k = nn.functional.linear(key, self.in_proj_weight[dim: 2 * dim], bias_k) v = nn.functional.linear(value, self.in_proj_weight[2 * dim:], bias_v) if self.qk_layer_norm is True: q = self.q_layer_norm(q) k = self.k_layer_norm(k) q, k, v = [rearrange(x, f"b t (h d) -> {layout}", h=self.num_heads) for x in [q, k, v]] else: if not _is_profiled(): # profiling breaks that propertysomehow. assert query is key, "specialized implementation" assert value is key, "specialized implementation" projected = nn.functional.linear(query, self.in_proj_weight, self.in_proj_bias) if self.kv_repeat == 1: if time_dim == 2: bound_layout = "b h p t d" else: bound_layout = "b t p h d" packed = rearrange(projected, f"b t (p h d) -> {bound_layout}", p=3, h=self.num_heads) q, k, v = ops.unbind(packed, dim=2) else: embed_dim = self.embed_dim per_head_dim = (embed_dim // self.num_heads) kv_heads = self.num_heads // self.kv_repeat q = projected[:, :, :embed_dim] start = embed_dim end = start + per_head_dim * kv_heads k = projected[:, :, start: end] v = projected[:, :, end:] q = rearrange(q, f"b t (h d) -> {layout}", h=self.num_heads) k = rearrange(k, f"b t (h d) -> {layout}", h=kv_heads) v = rearrange(v, f"b t (h d) -> {layout}", h=kv_heads) if self.qk_layer_norm is True: assert self.kv_repeat == 1 q, k = [rearrange(x, f"{layout} -> b t (h d)") for x in [q, k]] q = self.q_layer_norm(q) k = self.k_layer_norm(k) q, k = [rearrange(x, f"b t (h d) -> {layout}", h=self.num_heads) for x in [q, k]] if self.rope: q, k = self._apply_rope(q, k) k, v = self._complete_kv(k, v) if self.kv_repeat > 1: k = expand_repeated_kv(k, self.kv_repeat) v = expand_repeated_kv(v, self.kv_repeat) if self.attention_as_float32: q, k, v = [x.float() for x in [q, k, v]] if self.memory_efficient: p = self.dropout if self.training else 0 if _efficient_attention_backend == 'torch': x = torch.nn.functional.scaled_dot_product_attention( q, k, v, is_causal=attn_mask is not None, dropout_p=p) else: x = ops.memory_efficient_attention(q, 
k, v, attn_mask, p=p) else: # We include the dot product as float32, for consistency # with the other implementations that include that step # as part of the attention. Note that when using `autocast`, # the einsums would be done as bfloat16, but the softmax # would be done as bfloat16, so `attention_as_float32` will # extend a bit the range of operations done in float32, # although this should make no difference. q = q / q.shape[-1] ** 0.5 key_layout = layout.replace('t', 'k') query_layout = layout if self._is_streaming and self.safe_streaming and q.device.type == 'cuda': with torch.autocast(device_type=q.device.type, dtype=torch.float32): pre_w = torch.einsum(f"{query_layout},{key_layout}-> b h t k", q, k) else: pre_w = torch.einsum(f"{query_layout},{key_layout}-> b h t k", q, k) if attn_mask is not None: pre_w = pre_w + attn_mask w = torch.softmax(pre_w, dim=-1) w = F.dropout(w, self.dropout, training=self.training).to(v) # Key and value have the same format. x = torch.einsum(f"b h t k, {key_layout} -> {layout}", w, v) x = x.to(dtype) x = rearrange(x, f"{layout} -> b t (h d)", h=self.num_heads) x = self.out_proj(x) else: key, value = self._complete_kv(key, value) if self.attention_as_float32: query, key, value = [x.float() for x in [query, key, value]] x, _ = self.mha( query, key, value, key_padding_mask, need_weights, attn_mask, average_attn_weights) x = x.to(dtype) return x, None class StreamingTransformerLayer(nn.TransformerEncoderLayer): """TransformerLayer with Streaming / Causal support. This also integrates cross_attention, when passing `cross_attention=True`, rather than having two separate classes like in PyTorch. Args: d_model (int): Dimension of the data. num_heads (int): Number of heads. dim_feedforward (int): Intermediate dimension of FF module. dropout (float): Dropout both for MHA and FF. bias_ff (bool): Use bias for FF. bias_attn (bool): Use bias for MHA. causal (bool): Causal mask applied automatically. 
past_context (int, optional): Receptive field for the causal mask, infinite if None.
        custom (bool): Use custom MHA implementation, for testing / benchmarking.
        memory_efficient (bool): Use xformers based memory efficient attention.
        attention_as_float32 (bool): Perform the attention as float32
            (especially important with memory_efficient as autocast won't do this automatically).
        qk_layer_norm (bool): Layer normalization applied to queries and keys before dot product in attention.
        qk_layer_norm_cross (bool): Same for the cross attention.
        cross_attention (bool): If True, expect to get secondary input for cross-attention.
            The cross-attention module is a `StreamingMultiheadAttention` built with the same
            `custom` / `memory_efficient` / `attention_as_float32` settings as the self-attention.
        layer_scale (float, optional): If not None, LayerScale will be used with the given value as initial scale.
        rope (`RotaryEmbedding`, optional): Rope embedding to use.
        attention_dropout (float, optional): If not None, separate the value of the dimension dropout
            in FFN and of the attention dropout.
        kv_repeat (int): If > 1, will repeat keys and values multiple times (need to divide num_heads).
            This will lead to faster decoding time on A100 or other GPUs with tensorcore.
        norm (str): Normalization method handed to `create_norm_fn` for `norm1` and `norm2`.
        device (torch.device, optional): Device on which to initialize.
        dtype (torch.dtype, optional): dtype to use.
        **kwargs: See `nn.TransformerEncoderLayer`.
    """
    def __init__(self, d_model: int, num_heads: int, dim_feedforward: int = 2048, dropout: float = 0.1,
                 bias_ff: bool = True, bias_attn: bool = True, causal: bool = False,
                 past_context: tp.Optional[int] = None, custom: bool = False,
                 memory_efficient: bool = False, attention_as_float32: bool = False,
                 qk_layer_norm: bool = False, qk_layer_norm_cross: bool = False,
                 cross_attention: bool = False, layer_scale: tp.Optional[float] = None,
                 rope: tp.Optional[RotaryEmbedding] = None, attention_dropout: tp.Optional[float] = None,
                 kv_repeat: int = 1, norm: str = 'layer_norm', device=None, dtype=None, **kwargs):
        super().__init__(d_model, num_heads, dim_feedforward, dropout,
                         device=device, dtype=dtype, batch_first=True, **kwargs)
        factory_kwargs = {'device': device, 'dtype': dtype}
        # Redefine self_attn to our streaming multi-head attention
        # These settings are shared by the self-attention and the cross-attention modules.
        attn_kwargs: tp.Dict[str, tp.Any] = {
            'embed_dim': d_model,
            'num_heads': num_heads,
            'dropout': dropout if attention_dropout is None else attention_dropout,
            'bias': bias_attn,
            'custom': custom,
            'memory_efficient': memory_efficient,
            'attention_as_float32': attention_as_float32,
        }
        self.self_attn: StreamingMultiheadAttention = StreamingMultiheadAttention(
            causal=causal, past_context=past_context, rope=rope, qk_layer_norm=qk_layer_norm,
            kv_repeat=kv_repeat, **attn_kwargs, **factory_kwargs)  # type: ignore
        # Redefine feedforward layers to expose bias parameter
        self.linear1 = nn.Linear(d_model, dim_feedforward, bias=bias_ff, **factory_kwargs)
        self.linear2 = nn.Linear(dim_feedforward, d_model, bias=bias_ff, **factory_kwargs)

        # Scales applied to the self-attention (1) and feed-forward (2) residual branches.
        self.layer_scale_1: nn.Module
        self.layer_scale_2: nn.Module
        if layer_scale is None:
            self.layer_scale_1 = nn.Identity()
            self.layer_scale_2 = nn.Identity()
        else:
            self.layer_scale_1 = LayerScale(d_model, layer_scale, **factory_kwargs)
            self.layer_scale_2 = LayerScale(d_model, layer_scale, **factory_kwargs)

        self.cross_attention: tp.Optional[nn.Module] = None
        if cross_attention:
            self.cross_attention = StreamingMultiheadAttention(
                cross_attention=True, qk_layer_norm=qk_layer_norm_cross,
                **attn_kwargs, **factory_kwargs)
            # Norm and dropout
            self.dropout_cross = nn.Dropout(dropout)
            # eps value matching that used in PyTorch reference implementation.
            self.norm_cross = nn.LayerNorm(d_model, eps=1e-5, **factory_kwargs)
            self.layer_scale_cross: nn.Module
            if layer_scale is None:
                self.layer_scale_cross = nn.Identity()
            else:
                self.layer_scale_cross = LayerScale(d_model, layer_scale, **factory_kwargs)
        # Replace the norms created by the parent constructor with the requested kind.
        self.norm1 = create_norm_fn(norm, d_model, **factory_kwargs)  # type: ignore
        self.norm2 = create_norm_fn(norm, d_model, **factory_kwargs)  # type: ignore

    def _cross_attention_block(self, src: torch.Tensor,
                               cross_attention_src: torch.Tensor) -> torch.Tensor:
        """Attend from `src` to `cross_attention_src` and apply the cross-attention dropout."""
        assert self.cross_attention is not None
        # queries are from src, keys and values from cross_attention_src.
        x = self.cross_attention(
            src, cross_attention_src, cross_attention_src, need_weights=False)[0]
        return self.dropout_cross(x)  # type: ignore

    def forward(self, src: torch.Tensor, src_mask: tp.Optional[torch.Tensor] = None,  # type: ignore
                src_key_padding_mask: tp.Optional[torch.Tensor] = None,
                cross_attention_src: tp.Optional[torch.Tensor] = None):
        """Apply self-attention, optional cross-attention, then the feed-forward block.

        `cross_attention_src` must be given if and only if the layer was built
        with `cross_attention=True`. `self.norm_first` selects between the
        pre-norm and the post-norm arrangement of the residual branches.
        """
        if self.cross_attention is None:
            assert cross_attention_src is None
        else:
            assert cross_attention_src is not None
        x = src
        if self.norm_first:
            x = x + self.layer_scale_1(
                self._sa_block(self.norm1(x), src_mask, src_key_padding_mask))
            if cross_attention_src is not None:
                x = x + self.layer_scale_cross(
                    self._cross_attention_block(
                        self.norm_cross(x), cross_attention_src))
            x = x + self.layer_scale_2(self._ff_block(self.norm2(x)))
        else:
            x = self.norm1(x + self.layer_scale_1(
                self._sa_block(x, src_mask, src_key_padding_mask)))
            if cross_attention_src is not None:
                # NOTE(review): the queries here are the layer input `src`, not `x`
                # (the self-attention output), unlike the pre-norm branch above.
                # Confirm this is intended before changing it, as existing
                # post-norm checkpoints would depend on the current behavior.
                x = self.norm_cross(
                    x + self.layer_scale_cross(
                        self._cross_attention_block(src, cross_attention_src)))
            x = self.norm2(x + self.layer_scale_2(self._ff_block(x)))
        return x


class StreamingTransformer(StreamingModule):
    """Transformer with Streaming / Causal
support. Args: d_model (int): Dimension of the data. num_heads (int): Number of heads. dim_feedforward (int): Intermediate dimension of FF module. dropout (float): Dropout both for MHA and FF. bias_ff (bool): Use bias for FF. bias_attn (bool): Use bias for MHA. causal (bool): Causal mask applied automatically. past_context (int, optional): Receptive field for the causal mask, infinite if None. custom (bool): Use custom MHA implementation, for testing / benchmarking. memory_efficient (bool): Use xformers based memory efficient attention. attention_as_float32 (bool): Perform the attention as float32 (especially important with memory_efficient as autocast won't do this automatically). cross_attention (bool): If True, expect to get secondary input for cross-attention. layer_scale (float, optional): If not None, LayerScale will be used with the given value as initial scale. positional_embedding (str): Positional embedding strategy (sin, rope, or sin_rope). max_period (float): Maximum period of the time embedding. positional_scale (float): Scale of positional embedding, set to 0 to deactivate. xpos (bool): Apply xpos exponential decay to positional embedding (rope only). lr (float, optional): learning rate override through the `make_optim_group` API. weight_decay (float, optional): Weight_decay override through the `make_optim_group` API. layer_class: (subclass of `StreamingTransformerLayer): class to use to initialize the layers, allowing further customization outside of AudioCraft. checkpointing (str): Checkpointing strategy to reduce memory usage. No checkpointing if set to 'none'. Per layer checkpointing using PyTorch if set to 'torch' (entire layer checkpointed, i.e. linears are evaluated twice, minimal memory usage, but maximal runtime). Finally, `xformers_default` provide a policy for opting-out some operations of the checkpointing like linear layers and attention, providing a middle ground between speed and memory. 
device (torch.device, optional): Device on which to initialize. dtype (torch.dtype, optional): dtype to use. **kwargs: See `nn.TransformerEncoderLayer`. """ def __init__(self, d_model: int, num_heads: int, num_layers: int, dim_feedforward: int = 2048, dropout: float = 0.1, bias_ff: bool = True, bias_attn: bool = True, causal: bool = False, past_context: tp.Optional[int] = None, custom: bool = False, memory_efficient: bool = False, attention_as_float32: bool = False, cross_attention: bool = False, layer_scale: tp.Optional[float] = None, positional_embedding: str = 'sin', max_period: float = 10_000, positional_scale: float = 1., xpos: bool = False, lr: tp.Optional[float] = None, weight_decay: tp.Optional[float] = None, layer_class: tp.Type[StreamingTransformerLayer] = StreamingTransformerLayer, checkpointing: str = 'none', device=None, dtype=None, **kwargs): super().__init__() assert d_model % num_heads == 0 self.positional_embedding = positional_embedding self.max_period = max_period self.positional_scale = positional_scale self.weight_decay = weight_decay self.lr = lr assert positional_embedding in ['sin', 'rope', 'sin_rope'] self.rope: tp.Optional[RotaryEmbedding] = None if self.positional_embedding in ['rope', 'sin_rope']: assert _is_custom(custom, memory_efficient) self.rope = RotaryEmbedding(d_model // num_heads, max_period=max_period, xpos=xpos, scale=positional_scale, device=device) self.checkpointing = checkpointing assert checkpointing in ['none', 'torch', 'xformers_default', 'xformers_mm'] if self.checkpointing.startswith('xformers'): _verify_xformers_internal_compat() self.layers = nn.ModuleList() for idx in range(num_layers): self.layers.append( layer_class( d_model=d_model, num_heads=num_heads, dim_feedforward=dim_feedforward, dropout=dropout, bias_ff=bias_ff, bias_attn=bias_attn, causal=causal, past_context=past_context, custom=custom, memory_efficient=memory_efficient, attention_as_float32=attention_as_float32, cross_attention=cross_attention, 
layer_scale=layer_scale, rope=self.rope, device=device, dtype=dtype, **kwargs)) if self.checkpointing != 'none': for layer in self.layers: # see audiocraft/optim/fsdp.py, magic signal to indicate this requires fixing the # backward hook inside of FSDP... layer._magma_checkpointed = True # type: ignore assert layer.layer_drop == 0., "Need further checking" # type: ignore def _apply_layer(self, layer, *args, **kwargs): method = self.checkpointing if method == 'none': return layer(*args, **kwargs) elif method == 'torch': return torch_checkpoint(layer, *args, use_reentrant=False, **kwargs) elif method.startswith('xformers'): from xformers.checkpoint_fairinternal import checkpoint, _get_default_policy if method == 'xformers_default': # those operations will be saved, and not recomputed. # According to Francisco we can get smarter policies but this is a good start. allow_list = [ "xformers.efficient_attention_forward_cutlass.default", "xformers_flash.flash_fwd.default", "aten.addmm.default", "aten.mm.default", ] elif method == 'xformers_mm': # those operations will be saved, and not recomputed. # According to Francisco we can get smarter policies but this is a good start. 
def make_optim_group(self):
    """Build the optimizer parameter group for this module.

    Returns:
        dict: Always contains `"params"` (list of all parameters of this module).
            `"lr"` and `"weight_decay"` are only present when the matching
            attribute on the module is not None.
    """
    group = {"params": list(self.parameters())}
    # Only forward overrides that were explicitly configured, so that the
    # optimizer defaults apply otherwise.
    overrides = (("lr", self.lr), ("weight_decay", self.weight_decay))
    for option, value in overrides:
        if value is None:
            continue
        group[option] = value
    return group
def _verify_xformers_internal_compat():
    """Check that the fairinternal xformers checkpointing helpers can be imported.

    Raises:
        ImportError: If `checkpoint` or `_get_default_policy` cannot be imported from
            `xformers.checkpoint_fairinternal`. The original import failure is attached
            as the explicit cause, so the traceback shows what actually went wrong.
    """
    try:
        from xformers.checkpoint_fairinternal import checkpoint, _get_default_policy  # noqa
    except ImportError as err:
        # Chain explicitly: without `from err` the traceback reads as if a second,
        # unrelated error happened while handling the first one.
        raise ImportError(
            "Francisco's fairinternal xformers is not installed. Please install it and try again.\n"
            "To install on AWS and Azure, run \n"
            "FORCE_CUDA=1 TORCH_CUDA_ARCH_LIST='8.0'\\\n"
            "pip install -U git+https://git@github.com/fairinternal/xformers.git#egg=xformers\n"
            "To install on FAIR Cluster, run \n"
            "FORCE_CUDA=1 TORCH_CUDA_ARCH_LIST='6.0;7.0'\\\n"
            "pip install -U git+https://git@github.com/fairinternal/xformers.git#egg=xformers\n") from err
class CosineLRScheduler(_LRScheduler):
    """Cosine LR scheduler with a linear warmup.

    The learning rate grows linearly from 0 to the base value during the first
    `warmup_steps` steps, follows a cosine curve down to `lr_min_ratio * base_lr`
    until `total_steps`, and stays at `lr_min_ratio * base_lr` afterwards.

    Args:
        optimizer (Optimizer): Torch optimizer.
        total_steps (int): Total number of steps.
        warmup_steps (int): Number of warmup steps.
        lr_min_ratio (float): Minimum learning rate, expressed as a ratio of the base learning rate.
        cycle_length (float): Divides the decay progress fed to the cosine. With 1.0 the
            minimum ratio is reached exactly at `total_steps`.
    """
    def __init__(self, optimizer: Optimizer, total_steps: int, warmup_steps: int,
                 lr_min_ratio: float = 0.0, cycle_length: float = 1.0):
        self.warmup_steps = warmup_steps
        assert self.warmup_steps >= 0
        self.total_steps = total_steps
        assert self.total_steps >= 0
        self.lr_min_ratio = lr_min_ratio
        self.cycle_length = cycle_length
        # The parent constructor performs an initial step, so every attribute
        # read by `get_lr` must be set before calling it.
        super().__init__(optimizer)

    def _get_sched_lr(self, lr: float, step: int):
        """Return the scheduled learning rate for base value `lr` at `step`."""
        if step < self.warmup_steps:
            lr_ratio = step / self.warmup_steps
            lr = lr_ratio * lr
        elif step <= self.total_steps:
            decay_steps = self.total_steps - self.warmup_steps
            # This branch is only reached with warmup_steps <= step <= total_steps, so
            # decay_steps is never negative. It is zero when there is no decay phase at
            # all; treat that as the very start of the cosine instead of dividing by zero.
            s = (step - self.warmup_steps) / decay_steps if decay_steps > 0 else 0.0
            lr_ratio = self.lr_min_ratio + 0.5 * (1 - self.lr_min_ratio) * \
                (1. + math.cos(math.pi * s / self.cycle_length))
            lr = lr_ratio * lr
        else:
            lr_ratio = self.lr_min_ratio
            lr = lr_ratio * lr
        return lr

    def get_lr(self):
        """Return the scheduled learning rate of every parameter group."""
        return [self._get_sched_lr(lr, self.last_epoch) for lr in self.base_lrs]
def step(self, closure=None):
    """Performs a single optimization step.

    The step runs in two passes over the parameter groups: the first one updates
    the Adam moving averages and accumulates the statistics used to estimate `d`,
    the second one writes the new estimate back to every group and applies the
    parameter update.

    Args:
        closure (callable, optional): A closure that reevaluates the model and returns the loss.

    Returns:
        The value returned by `closure`, or None when no closure is given.

    Raises:
        RuntimeError: If a parameter group uses a learning rate that is neither 0
            nor the largest learning rate across the groups.
    """
    loss = None
    if closure is not None:
        loss = closure()

    # Statistics accumulated over all parameters during the first pass.
    g_sq = 0.0
    sksq_weighted = 0.0
    sk_l1 = 0.0

    # The largest lr across groups is used as the common lr; every group must
    # use either this value or 0 (checked below).
    lr = max(group['lr'] for group in self.param_groups)

    # Shared D-Adaptation state is read from the first group. The updated values
    # are written back to every group in the second pass.
    group = self.param_groups[0]
    gsq_weighted = group['gsq_weighted']
    d = group['d']
    # Step size actually used for this step: current D estimate times lr.
    dlr = d*lr

    growth_rate = group['growth_rate']
    decouple = group['decouple']
    fsdp_in_use = group['fsdp_in_use']
    log_every = group['log_every']

    beta1, beta2 = group['betas']

    # First pass: update moving averages and accumulate statistics.
    for group in self.param_groups:
        group_lr = group['lr']
        decay = group['weight_decay']
        k = group['k']
        eps = group['eps']

        if group_lr not in [lr, 0.0]:
            raise RuntimeError("Setting different lr values in different parameter "
                               "groups is only supported for values of 0")

        for p in group['params']:
            if p.grad is None:
                continue
            # Any parameter carrying this attribute switches on the cross-rank
            # reduction of the statistics below.
            if hasattr(p, "_fsdp_flattened"):
                fsdp_in_use = True
            grad = p.grad.data

            # Apply weight decay (coupled variant)
            if decay != 0 and not decouple:
                grad.add_(p.data, alpha=decay)

            state = self.state[p]

            # State initialization
            if 'step' not in state:
                state['step'] = 0
                state['s'] = torch.zeros_like(p.data, memory_format=torch.preserve_format).detach()
                # Exponential moving average of gradient values
                state['exp_avg'] = torch.zeros_like(p.data, memory_format=torch.preserve_format).detach()
                # Exponential moving average of squared gradient values
                state['exp_avg_sq'] = torch.zeros_like(
                    to_real(p.data), memory_format=torch.preserve_format).detach()

            exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']

            # |grad|^2, kept real even for complex gradients.
            grad_grad = to_real(grad * grad.conj())

            # Adam EMA updates
            # Groups with lr == 0 are skipped entirely: neither their moving
            # averages nor the statistics are touched.
            if group_lr > 0:
                # Note that dlr is folded into exp_avg here, which is why the
                # final update below uses a plain factor of -1.
                exp_avg.mul_(beta1).add_(grad, alpha=dlr*(1-beta1))
                exp_avg_sq.mul_(beta2).add_(grad_grad, alpha=1-beta2)

                denom = exp_avg_sq.sqrt().add_(eps)

                # grad_grad is divided in place; it is not used afterwards.
                g_sq += grad_grad.div_(denom).sum().item()

                s = state['s']
                s.mul_(beta2).add_(grad, alpha=dlr*(1-beta2))
                sksq_weighted += to_real(s * s.conj()).div_(denom).sum().item()
                sk_l1 += s.abs().sum().item()

        ######

    gsq_weighted = beta2*gsq_weighted + g_sq*(dlr**2)*(1-beta2)
    d_hat = d

    # if we have not done any progress, return
    # if we have any gradients available, will have sk_l1 > 0 (unless \|g\|=0)
    if sk_l1 == 0:
        return loss

    if lr > 0.0:
        if fsdp_in_use:
            # Sum the three statistics across ranks. The buffer is allocated
            # on 'cuda' unconditionally.
            dist_tensor = torch.zeros(3, device='cuda')
            dist_tensor[0] = sksq_weighted
            dist_tensor[1] = gsq_weighted
            dist_tensor[2] = sk_l1
            dist.all_reduce(dist_tensor, op=dist.ReduceOp.SUM)
            global_sksq_weighted = dist_tensor[0]
            global_gsq_weighted = dist_tensor[1]
            global_sk_l1 = dist_tensor[2]
        else:
            global_sksq_weighted = sksq_weighted
            global_gsq_weighted = gsq_weighted
            global_sk_l1 = sk_l1

        d_hat = (global_sksq_weighted/(1-beta2) - global_gsq_weighted)/global_sk_l1
        # d never decreases, and grows by at most a factor of growth_rate per step.
        d = max(d, min(d_hat, d*growth_rate))

    # k is the step counter of the last group visited in the first pass.
    if log_every > 0 and k % log_every == 0:
        logger.info(
            f"(k={k}) dlr: {dlr:1.1e} d_hat: {d_hat:1.1e}, d: {d:1.8}. "
            f"sksq_weighted={global_sksq_weighted:1.1e} gsq_weighted={global_gsq_weighted:1.1e} "
            f"sk_l1={global_sk_l1:1.1e}{' (FSDP)' if fsdp_in_use else ''}")

    # Second pass: store the new estimates in every group and update the parameters.
    for group in self.param_groups:
        group['gsq_weighted'] = gsq_weighted
        group['d'] = d

        group_lr = group['lr']
        decay = group['weight_decay']
        k = group['k']
        eps = group['eps']

        for p in group['params']:
            if p.grad is None:
                continue
            grad = p.grad.data

            state = self.state[p]

            exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']

            state['step'] += 1

            denom = exp_avg_sq.sqrt().add_(eps)
            denom = denom.type(p.type())

            # Apply weight decay (decoupled variant)
            # Uses the dlr computed before `d` was updated in this call.
            if decay != 0 and decouple and group_lr > 0:
                p.data.add_(p.data, alpha=-decay * dlr)

            # Take step
            p.data.addcdiv_(exp_avg, denom, value=-1)

        group['k'] = k + 1

    return loss
def _get_all_non_persistent_buffers_set(module: nn.Module, root: str = "") -> set:
    """Return the dotted names of all non-persistent buffers found under `module`.

    Args:
        module (nn.Module): Module to inspect.
        root (str): Prefix prepended (with a dot) to every returned name. Empty for no prefix.
    Returns:
        set: Fully qualified names of the non-persistent buffers.
    """
    def _qualified(local_name: str) -> str:
        # Prefix a name with `root`, unless `root` is empty.
        return f"{root}.{local_name}" if len(root) > 0 else local_name

    collected: set = set()
    for child_name, child in module.named_modules():
        if child_name != '':
            # Descendants are handled by recursing with the extended prefix.
            collected.update(_get_all_non_persistent_buffers_set(child, _qualified(child_name)))
            continue
        # The empty name is `module` itself: take its own non-persistent buffers.
        collected.update({_qualified(buff_name) for buff_name in module._non_persistent_buffers_set})
    return collected


def _get_named_tensors(module: nn.Module):
    """List the `(name, tensor)` pairs of `module`: parameters first, then persistent buffers."""
    skipped = _get_all_non_persistent_buffers_set(module)
    tensors = list(module.named_parameters())
    tensors.extend((name, buffer) for (name, buffer) in module.named_buffers() if name not in skipped)
    return tensors


class ModuleDictEMA:
    """Exponential Moving Average over a nn.ModuleDict.

    A detached copy of every floating point parameter and persistent buffer is kept
    in `state`, indexed by module name and then by tensor name.

    Args:
        module_dict (nn.ModuleDict): Modules to track.
        decay (float): Decay of the moving average.
        unbias (bool): If True, weight each update with `1 / count` where `count`
            follows `count = count * decay + 1`; otherwise use the fixed weight `1 - decay`.
        device (torch.device or str): Device holding the averaged tensors. If it
            evaluates to False, each tensor stays on the device of its source.
    """
    def __init__(self, module_dict: nn.ModuleDict, decay: float = 0.999, unbias: bool = True,
                 device: tp.Union[torch.device, str] = 'cpu'):
        self.module_dict = module_dict
        self.decay = decay
        self.unbias = unbias
        self.device = device
        self.count = 0
        self.state: dict = defaultdict(dict)
        self._init()

    def _float_tensors(self):
        # Yield (module name, tensor name, tensor) for every tracked floating point tensor.
        for module_name, module in self.module_dict.items():
            for key, val in _get_named_tensors(module):
                if val.is_floating_point():
                    yield module_name, key, val

    def _init(self):
        """Populate `state` with a copy of every tracked tensor not stored yet."""
        for module_name, key, val in self._float_tensors():
            slot = self.state[module_name]
            if key in slot:
                continue
            target = self.device or val.device
            slot[key] = val.detach().to(target, copy=True)

    def step(self):
        """Blend the current values of the tracked tensors into the running average."""
        if not self.unbias:
            w = 1 - self.decay
        else:
            self.count = self.count * self.decay + 1
            w = 1 / self.count
        for module_name, key, val in self._float_tensors():
            target = self.device or val.device
            averaged = self.state[module_name][key]
            averaged.mul_(1 - w)
            averaged.add_(val.detach().to(target), alpha=w)

    def state_dict(self):
        """Return the averaged tensors together with the update counter."""
        return {'state': self.state, 'count': self.count}

    def load_state_dict(self, state):
        """Copy the content of `state` (as produced by `state_dict`) into this EMA, in place."""
        self.count = state['count']
        for module_name, tensors in state['state'].items():
            slot = self.state[module_name]
            for key, val in tensors.items():
                slot[key].copy_(val)
def is_fsdp_used() -> bool:
    """Return whether we are using FSDP.

    This reads the `fsdp.use` entry of the config of the current Dora XP.
    False is returned outside of an XP, or when the config has no `fsdp` entry.
    """
    # A bit of a hack but should work from anywhere.
    if not dora.is_xp():
        return False
    cfg = dora.get_xp().cfg
    if not hasattr(cfg, 'fsdp'):
        return False
    return cfg.fsdp.use
from ..modules.transformer import StreamingTransformerLayer from ..modules.conditioners import ConditioningProvider _fix_post_backward_hook() assert cfg.use sharding_strategy_dict = { "no_shard": ShardingStrategy.NO_SHARD, "shard_grad_op": ShardingStrategy.SHARD_GRAD_OP, "full_shard": ShardingStrategy.FULL_SHARD, } dtype_dict = { "float32": torch.float32, "float16": torch.float16, "bfloat16": torch.bfloat16, } mixed_precision_config = MixedPrecision( param_dtype=dtype_dict[cfg.param_dtype], reduce_dtype=dtype_dict[cfg.reduce_dtype], buffer_dtype=dtype_dict[cfg.buffer_dtype], ) sharding_strategy_config = sharding_strategy_dict[cfg.sharding_strategy] # The following is going to require being a bit smart # when doing LM, because this would flush the weights for every time step # during generation. One possiblity is to use hybrid sharding: # See: https://pytorch.org/docs/master/fsdp.html#torch.distributed.fsdp.ShardingStrategy assert sharding_strategy_config != ShardingStrategy.FULL_SHARD, \ "Not supported at the moment, requires a bit more work." local_rank = dora.distrib.get_distrib_spec().local_rank assert local_rank < torch.cuda.device_count(), "Please upgrade Dora!" auto_wrap_policy = None if block_classes is None: block_classes = {StreamingTransformerLayer, ConditioningProvider} if cfg.per_block: auto_wrap_policy = ModuleWrapPolicy(block_classes) wrapped = _FSDPFixStateDict( model, sharding_strategy=sharding_strategy_config, mixed_precision=mixed_precision_config, device_id=local_rank, sync_module_states=True, use_orig_params=True, auto_wrap_policy=auto_wrap_policy, ) # type: ignore FSDP.set_state_dict_type(wrapped, StateDictType.LOCAL_STATE_DICT) # type: ignore # Let the wrapped model know about the wrapping! # We use __dict__ to avoid it going into the state dict. # This is a bit dirty, but needed during generation, as otherwise # the wrapped model would call itself and bypass FSDP. 
def purge_fsdp(model: FSDP):
    """Purge the FSDP cached shard inside the model. This should
    allow setting the best state or switching to the EMA.

    Args:
        model (FSDP): Model whose FSDP sub-modules should be resharded.
    """
    from torch.distributed.fsdp._runtime_utils import _reshard  # type: ignore
    for fsdp_module in FSDP.fsdp_modules(model):
        handles = fsdp_module._handles
        if handles:
            # Only the first handle is inspected to decide whether anything is cached.
            flat_param = handles[0]._get_padded_unsharded_flat_param()
            cached_size: int = flat_param._typed_storage()._size()  # type: ignore
            if cached_size != 0:
                # One `True` flag per handle.
                _reshard(fsdp_module, handles, [True] * len(handles))
_hook_fixed = False


def _fix_post_backward_hook():
    """Patch the FSDP post backward hook so that it works with checkpointed modules.

    The patch is installed at most once per process: `_runtime_utils._post_backward_hook`
    is replaced by a wrapper that first resets the FSDP training state when the wrapped
    module is flagged as checkpointed, and then calls the original hook.

    A module is considered checkpointed when either its `_audiocraft_checkpointed` or
    its `_magma_checkpointed` attribute is truthy. Both names are accepted because the
    transformer layers are flagged with `_magma_checkpointed`.
    """
    global _hook_fixed
    if _hook_fixed:
        return
    _hook_fixed = True

    from torch.distributed.fsdp import _runtime_utils
    from torch.distributed.fsdp._common_utils import TrainingState, HandleTrainingState
    old_hook = _runtime_utils._post_backward_hook
    checkpoint_flags = ('_audiocraft_checkpointed', '_magma_checkpointed')

    def _post_backward_hook(state, handle, *args, **kwargs):
        wrapped = state._fsdp_wrapped_module
        checkpointed = any(getattr(wrapped, flag, False) for flag in checkpoint_flags)
        if checkpointed:
            # there will be one more forward in the backward with checkpointing and that will
            # massively confuse FSDP, so we have to make it think everything
            # is going according to the plan.
            state.training_state = TrainingState.FORWARD_BACKWARD
            handle._training_state = HandleTrainingState.BACKWARD_PRE
        old_hook(state, handle, *args, **kwargs)

    _runtime_utils._post_backward_hook = _post_backward_hook
class LinearWarmupLRScheduler(_LRScheduler):
    """Linear warmup LR scheduler.

    The learning rate moves linearly from `warmup_init_lr` to the base learning
    rate over the first `warmup_steps` steps and is left untouched afterwards.

    Args:
        optimizer (Optimizer): Torch optimizer.
        warmup_steps (int): Number of warmup steps.
        warmup_init_lr (tp.Optional[float]): Initial learning rate
            during warmup phase. None is treated as 0.
    """
    def __init__(self, optimizer: Optimizer, warmup_steps: int, warmup_init_lr: tp.Optional[float] = 0):
        self.warmup_init_lr = warmup_init_lr
        self.warmup_steps = warmup_steps
        # Must come last: the parent constructor already queries `get_lr`.
        super().__init__(optimizer)

    def _get_sched_lr(self, lr: float, step: int):
        """Return the learning rate for base value `lr` at `step`."""
        if step >= self.warmup_steps:
            # Warmup is over, the base learning rate is used as is.
            return lr
        start_lr = self.warmup_init_lr or 0
        per_step = (lr - start_lr) / self.warmup_steps
        return start_lr + step * per_step

    def get_lr(self):
        """Return the scheduled learning rate of every parameter group."""
        current = self.last_epoch
        return [self._get_sched_lr(base_lr, current) for base_lr in self.base_lrs]
""" def __init__(self, optimizer: Optimizer, warmup_steps: int, total_steps: int, end_lr: float = 0., zero_lr_warmup_steps: int = 0, power: float = 1.): self.warmup_steps = warmup_steps self.total_steps = total_steps self.end_lr = end_lr self.zero_lr_warmup_steps = zero_lr_warmup_steps self.power = power super().__init__(optimizer) def _get_sched_lr(self, lr: float, step: int): if self.zero_lr_warmup_steps > 0 and step <= self.zero_lr_warmup_steps: lr = 0 elif self.warmup_steps > 0 and step <= self.warmup_steps + self.zero_lr_warmup_steps: lr_ratio = (step - self.zero_lr_warmup_steps) / float(self.warmup_steps) lr = lr_ratio * lr elif step >= self.total_steps: lr = self.end_lr else: total_warmup_steps = self.warmup_steps + self.zero_lr_warmup_steps lr_range = lr - self.end_lr pct_remaining = 1 - (step - total_warmup_steps) / (self.total_steps - total_warmup_steps) lr = lr_range * pct_remaining ** self.power + self.end_lr return lr def get_lr(self): return [self._get_sched_lr(base_lr, self.last_epoch) for base_lr in self.base_lrs] ================================================ FILE: audiocraft/py.typed ================================================ ================================================ FILE: audiocraft/quantization/__init__.py ================================================ # Copyright (c) Meta Platforms, Inc. and affiliates. # All rights reserved. # # This source code is licensed under the license found in the # LICENSE file in the root directory of this source tree. """RVQ.""" # flake8: noqa from .vq import ResidualVectorQuantizer from .base import BaseQuantizer, DummyQuantizer, QuantizedResult ================================================ FILE: audiocraft/quantization/base.py ================================================ # Copyright (c) Meta Platforms, Inc. and affiliates. # All rights reserved. # # This source code is licensed under the license found in the # LICENSE file in the root directory of this source tree. 
@dataclass
class QuantizedResult:
    """Output of a quantizer forward pass."""
    x: torch.Tensor
    codes: torch.Tensor
    bandwidth: torch.Tensor  # bandwidth in kb/s used, per batch item.
    penalty: tp.Optional[torch.Tensor] = None
    metrics: dict = field(default_factory=dict)


class BaseQuantizer(nn.Module):
    """Base class for quantizers.

    Every method and property here raises `NotImplementedError` and is meant to be
    overridden by the concrete quantizers.
    """

    def forward(self, x: torch.Tensor, frame_rate: int) -> QuantizedResult:
        """
        Given input tensor x, returns first the quantized (or approximately quantized)
        representation along with quantized codes, bandwidth, and any penalty term for the loss.
        Finally, this returns a dict of metrics to update logging etc.
        Frame rate must be passed so that the bandwidth is properly computed.
        """
        raise NotImplementedError()

    def encode(self, x: torch.Tensor) -> torch.Tensor:
        """Encode a given input tensor with the specified sample rate at the given bandwidth."""
        raise NotImplementedError()

    def decode(self, codes: torch.Tensor) -> torch.Tensor:
        """Decode the given codes to the quantized representation."""
        raise NotImplementedError()

    @property
    def total_codebooks(self):
        """Total number of codebooks."""
        raise NotImplementedError()

    @property
    def num_codebooks(self):
        """Number of active codebooks."""
        raise NotImplementedError()

    def set_num_codebooks(self, n: int):
        """Set the number of active codebooks."""
        raise NotImplementedError()


class DummyQuantizer(BaseQuantizer):
    """Fake quantizer that actually does not perform any quantization.
    """
    def __init__(self):
        super().__init__()

    def forward(self, x: torch.Tensor, frame_rate: int):
        """Return `x` untouched, with itself (plus a codebook dimension) as codes.

        The reported bandwidth counts 32 bits per value of the codes, scaled by
        `frame_rate`, divided by 1000 and by the size of the first dimension of `x`.
        """
        codes = x.unsqueeze(1)
        per_item = codes.numel() * 32 * frame_rate / 1000 / len(x)
        # `.to(x)` aligns the bandwidth tensor with the dtype and device of the input.
        bandwidth = torch.tensor(per_item).to(x)
        return QuantizedResult(x, codes, bandwidth)

    def encode(self, x: torch.Tensor) -> torch.Tensor:
        """Encode a given input tensor with the specified sample rate at the given bandwidth.
        In the case of the DummyQuantizer, the codes are actually identical to the input and resulting
        quantized representation as no quantization is done.
        """
        codes = x.unsqueeze(1)
        return codes

    def decode(self, codes: torch.Tensor) -> torch.Tensor:
        """Decode the given codes to the quantized representation.
        In the case of the DummyQuantizer, the codes are actually identical to the input and resulting
        quantized representation as no quantization is done.
        """
        restored = codes.squeeze(1)
        return restored

    @property
    def total_codebooks(self):
        """Total number of codebooks."""
        return 1

    @property
    def num_codebooks(self):
        """Number of active codebooks, always equal to `total_codebooks` here."""
        return self.total_codebooks

    def set_num_codebooks(self, n: int):
        """Set the number of active codebooks. Not supported: always raises `AttributeError`."""
        raise AttributeError("Cannot override the number of codebooks for the dummy quantizer")
def exists(val: tp.Optional[tp.Any]) -> bool:
    """Tell whether ``val`` is anything other than ``None``."""
    return not (val is None)


def default(val: tp.Any, d: tp.Any) -> tp.Any:
    """Return ``val``, falling back to ``d`` when ``val`` is ``None``."""
    if val is None:
        return d
    return val


def l2norm(t):
    """Normalize ``t`` to unit L2 norm along its last dimension."""
    return F.normalize(t, dim=-1, p=2)


def ema_inplace(moving_avg, new, decay: float):
    """Update ``moving_avg`` in place as ``decay * moving_avg + (1 - decay) * new``."""
    buffer = moving_avg.data
    buffer.mul_(decay)
    buffer.add_(new, alpha=(1 - decay))


def laplace_smoothing(x, n_categories: int, epsilon: float = 1e-5):
    """Additively smooth the counts in ``x`` and normalize by the smoothed total."""
    smoothed = x + epsilon
    total = x.sum() + n_categories * epsilon
    return smoothed / total


def uniform_init(*shape: int):
    """Create a tensor of the given shape filled by Kaiming uniform initialization."""
    return nn.init.kaiming_uniform_(torch.empty(shape))


def sample_vectors(samples, num: int):
    """Pick ``num`` rows of ``samples`` at random.

    Rows are drawn without replacement when there are at least ``num`` of them,
    and with replacement otherwise.
    """
    available = samples.shape[0]
    device = samples.device
    if available < num:
        picked = torch.randint(0, available, (num,), device=device)
    else:
        picked = torch.randperm(available, device=device)[:num]
    return samples[picked]


def kmeans(samples, num_clusters: int, num_iters: int = 10):
    """Run k-means on 2-D ``samples`` (one vector per row).

    Args:
        samples: Tensor of shape [n, d].
        num_clusters (int): Number of centroids.
        num_iters (int): Number of refinement iterations.
    Returns:
        tuple: The centroids of shape [num_clusters, d] and the per-cluster
            sample counts from the last iteration.
    """
    feat_dim = samples.shape[-1]
    centroids = sample_vectors(samples, num_clusters)

    for _ in range(num_iters):
        # [n, 1, d] - [1, c, d] -> pairwise differences of shape [n, c, d]
        deltas = samples.unsqueeze(1) - centroids.unsqueeze(0)
        neg_sq_dists = -(deltas ** 2).sum(dim=-1)
        assignment = neg_sq_dists.max(dim=-1).indices

        counts = torch.bincount(assignment, minlength=num_clusters)
        empty = counts == 0
        # avoid dividing by zero for clusters that received no sample
        safe_counts = counts.masked_fill(empty, 1)

        sums = assignment.new_zeros(num_clusters, feat_dim, dtype=samples.dtype)
        sums.scatter_add_(0, assignment.unsqueeze(1).repeat(1, feat_dim), samples)
        updated = sums / safe_counts[..., None]

        # empty clusters keep their previous centroid
        centroids = torch.where(empty[..., None], centroids, updated)

    return centroids, counts


def orthogonal_loss_fn(t):
    """Orthogonality penalty on the rows of ``t``.

    Implements eq (2) from https://arxiv.org/abs/2112.00384
    """
    count = t.shape[0]
    unit_codes = l2norm(t)
    eye = torch.eye(count, device=t.device)
    similarity = einsum("i d, j d -> i j", unit_codes, unit_codes)
    return ((similarity - eye) ** 2).sum() / (count ** 2)
class EuclideanCodebook(nn.Module):
    """Codebook with Euclidean distance.

    Args:
        dim (int): Dimension.
        codebook_size (int): Codebook size.
        kmeans_init (bool): Whether to use k-means to initialize the codebooks.
            If set to true, run the k-means algorithm on the first training batch and use
            the learned centroids as initialization.
        kmeans_iters (int): Number of iterations used for k-means algorithm at initialization.
        decay (float): Decay for exponential moving average over the codebooks.
        epsilon (float): Epsilon value for numerical stability.
        threshold_ema_dead_code (int): Threshold for dead code expiration. Replace any codes
            that have an exponential moving average cluster size less than the specified threshold with
            randomly selected vector from the current batch.
    """
    def __init__(
        self,
        dim: int,
        codebook_size: int,
        kmeans_init: bool = False,
        kmeans_iters: int = 10,
        decay: float = 0.8,
        epsilon: float = 1e-5,
        threshold_ema_dead_code: int = 2,
    ):
        super().__init__()
        self.decay = decay
        # With k-means init the codebook starts at zero and is filled on the first forward pass.
        init_fn: tp.Union[tp.Callable[..., torch.Tensor], tp.Any] = uniform_init if not kmeans_init else torch.zeros
        embed = init_fn(codebook_size, dim)

        self.codebook_size = codebook_size

        self.kmeans_iters = kmeans_iters
        self.epsilon = epsilon
        self.threshold_ema_dead_code = threshold_ema_dead_code

        # `inited` is a one-element flag tensor, already set when no k-means init is requested.
        self.register_buffer("inited", torch.Tensor([not kmeans_init]))
        self.register_buffer("cluster_size", torch.zeros(codebook_size))
        self.register_buffer("embed", embed)
        self.register_buffer("embed_avg", embed.clone())

    @torch.jit.ignore
    def init_embed_(self, data):
        """Initialize the codebook with k-means on `data`, once; no-op when already initialized."""
        if self.inited:
            return

        embed, cluster_size = kmeans(data, self.codebook_size, self.kmeans_iters)
        self.embed.data.copy_(embed)
        self.embed_avg.data.copy_(embed.clone())
        self.cluster_size.data.copy_(cluster_size)
        self.inited.data.copy_(torch.Tensor([True]))
        # Make sure all buffers across workers are in sync after initialization
        flashy.distrib.broadcast_tensors(self.buffers())

    def replace_(self, samples, mask):
        """Overwrite the codebook entries selected by `mask` with vectors sampled from `samples`."""
        modified_codebook = torch.where(
            mask[..., None], sample_vectors(samples, self.codebook_size), self.embed
        )
        self.embed.data.copy_(modified_codebook)

    def expire_codes_(self, batch_samples):
        """Replace codes whose EMA cluster size is below the threshold, then broadcast the buffers.

        Does nothing when the threshold is 0 or when no code is expired.
        """
        if self.threshold_ema_dead_code == 0:
            return

        expired_codes = self.cluster_size < self.threshold_ema_dead_code
        if not torch.any(expired_codes):
            return

        batch_samples = rearrange(batch_samples, "... d -> (...) d")
        self.replace_(batch_samples, mask=expired_codes)
        flashy.distrib.broadcast_tensors(self.buffers())

    def preprocess(self, x):
        """Flatten all leading dimensions of `x`, keeping the last one."""
        x = rearrange(x, "... d -> (...) d")
        return x

    def quantize(self, x):
        """Return, for each row of `x`, the index of the closest codebook entry."""
        embed = self.embed.t()
        # negated squared Euclidean distance, expanded as -(|x|^2 - 2 x.e + |e|^2)
        dist = -(
            x.pow(2).sum(1, keepdim=True)
            - 2 * x @ embed
            + embed.pow(2).sum(0, keepdim=True)
        )
        embed_ind = dist.max(dim=-1).indices
        return embed_ind

    def postprocess_emb(self, embed_ind, shape):
        """Reshape flat indices back to `shape` without its last dimension."""
        return embed_ind.view(*shape[:-1])

    def dequantize(self, embed_ind):
        """Look up the codebook vectors for the given indices."""
        quantize = F.embedding(embed_ind, self.embed)
        return quantize

    def encode(self, x):
        """Map `x` to codebook indices (no codebook update)."""
        shape = x.shape
        # pre-process
        x = self.preprocess(x)
        # quantize
        embed_ind = self.quantize(x)
        # post-process
        embed_ind = self.postprocess_emb(embed_ind, shape)
        return embed_ind

    def decode(self, embed_ind):
        """Map codebook indices back to their vectors."""
        quantize = self.dequantize(embed_ind)
        return quantize

    def forward(self, x):
        """Quantize `x` and, in training mode, update the codebook with EMA statistics.

        Returns:
            tuple: The quantized vectors and the matching codebook indices.
        """
        shape, dtype = x.shape, x.dtype
        x = self.preprocess(x)
        self.init_embed_(x)

        embed_ind = self.quantize(x)
        embed_onehot = F.one_hot(embed_ind, self.codebook_size).type(dtype)
        embed_ind = self.postprocess_emb(embed_ind, shape)
        quantize = self.dequantize(embed_ind)

        if self.training:
            # We do the expiry of code at that point as buffers are in sync
            # and all the workers will take the same decision.
            self.expire_codes_(x)
            ema_inplace(self.cluster_size, embed_onehot.sum(0), self.decay)
            embed_sum = x.t() @ embed_onehot
            ema_inplace(self.embed_avg, embed_sum.t(), self.decay)
            # smoothed cluster sizes, rescaled so that they sum to the raw total
            cluster_size = (
                laplace_smoothing(self.cluster_size, self.codebook_size, self.epsilon)
                * self.cluster_size.sum()
            )
            embed_normalized = self.embed_avg / cluster_size.unsqueeze(1)
            self.embed.data.copy_(embed_normalized)

        return quantize, embed_ind
class VectorQuantization(nn.Module):
    """Vector quantization implementation.
    Currently supports only euclidean distance.

    Args:
        dim (int): Dimension
        codebook_size (int): Codebook size
        codebook_dim (int): Codebook dimension. If not defined, uses the specified dimension in dim.
        decay (float): Decay for exponential moving average over the codebooks.
        epsilon (float): Epsilon value for numerical stability.
        kmeans_init (bool): Whether to use kmeans to initialize the codebooks.
        kmeans_iters (int): Number of iterations used for kmeans initialization.
        channels_last (bool): Channels are the last dimension in the input tensors.
        commitment_weight (float): Weight for commitment loss.
        orthogonal_reg_weight (float): Orthogonal regularization weights.
        orthogonal_reg_active_codes_only (bool): Apply orthogonal regularization only on active codes.
        orthogonal_reg_max_codes (optional int): Maximum number of codes to consider
            for orthogonal regularization.
        threshold_ema_dead_code (int): Threshold for dead code expiration. Replace any codes
            that have an exponential moving average cluster size less than the specified threshold with
            randomly selected vector from the current batch.
    """
    def __init__(
        self,
        dim: int,
        codebook_size: int,
        codebook_dim: tp.Optional[int] = None,
        decay: float = 0.8,
        epsilon: float = 1e-5,
        kmeans_init: bool = False,
        kmeans_iters: int = 10,
        threshold_ema_dead_code: int = 2,
        channels_last: bool = False,
        commitment_weight: float = 1.,
        orthogonal_reg_weight: float = 0.0,
        orthogonal_reg_active_codes_only: bool = False,
        orthogonal_reg_max_codes: tp.Optional[int] = None,
    ):
        super().__init__()
        _codebook_dim: int = default(codebook_dim, dim)

        # Linear projections are only needed when the codebook dimension differs from `dim`.
        requires_projection = _codebook_dim != dim
        self.project_in = (nn.Linear(dim, _codebook_dim) if requires_projection else nn.Identity())
        self.project_out = (nn.Linear(_codebook_dim, dim) if requires_projection else nn.Identity())

        self.epsilon = epsilon
        self.commitment_weight = commitment_weight

        self.orthogonal_reg_weight = orthogonal_reg_weight
        self.orthogonal_reg_active_codes_only = orthogonal_reg_active_codes_only
        self.orthogonal_reg_max_codes = orthogonal_reg_max_codes

        self._codebook = EuclideanCodebook(dim=_codebook_dim, codebook_size=codebook_size,
                                           kmeans_init=kmeans_init, kmeans_iters=kmeans_iters,
                                           decay=decay, epsilon=epsilon,
                                           threshold_ema_dead_code=threshold_ema_dead_code)
        self.codebook_size = codebook_size

        self.channels_last = channels_last

    @property
    def codebook(self):
        """The `embed` buffer of the underlying codebook."""
        return self._codebook.embed

    @property
    def inited(self):
        """The `inited` flag of the underlying codebook."""
        return self._codebook.inited

    def _preprocess(self, x):
        # move channels last ("b d n" -> "b n d") unless the input already is
        if not self.channels_last:
            x = rearrange(x, "b d n -> b n d")
        return x

    def _postprocess(self, quantize):
        # undo `_preprocess`: back to "b d n" unless channels are expected last
        if not self.channels_last:
            quantize = rearrange(quantize, "b n d -> b d n")
        return quantize

    def encode(self, x):
        """Return the codebook indices for `x`."""
        x = self._preprocess(x)
        x = self.project_in(x)
        embed_in = self._codebook.encode(x)
        return embed_in

    def decode(self, embed_ind):
        """Return the quantized representation for the given codebook indices."""
        quantize = self._codebook.decode(embed_ind)
        quantize = self.project_out(quantize)
        quantize = self._postprocess(quantize)
        return quantize

    def forward(self, x):
        """Quantize `x`.

        Returns:
            tuple: The quantized tensor, the codebook indices and a one-element loss
                tensor (left at zero outside of training).
        """
        device = x.device
        x = self._preprocess(x)

        x = self.project_in(x)
        quantize, embed_ind = self._codebook(x)

        if self.training:
            # forward value is `quantize`, gradient flows to `x` only
            quantize = x + (quantize - x).detach()

        loss = torch.tensor([0.0], device=device, requires_grad=self.training)

        if self.training:
            if self.commitment_weight > 0:
                commit_loss = F.mse_loss(quantize.detach(), x)
                loss = loss + commit_loss * self.commitment_weight

            if self.orthogonal_reg_weight > 0:
                codebook = self.codebook

                if self.orthogonal_reg_active_codes_only:
                    # only calculate orthogonal loss for the activated codes for this batch
                    unique_code_ids = torch.unique(embed_ind)
                    codebook = codebook[unique_code_ids]

                num_codes = codebook.shape[0]
                if exists(self.orthogonal_reg_max_codes) and num_codes > self.orthogonal_reg_max_codes:
                    # random subset of at most `orthogonal_reg_max_codes` codes
                    rand_ids = torch.randperm(num_codes, device=device)[:self.orthogonal_reg_max_codes]
                    codebook = codebook[rand_ids]

                orthogonal_reg_loss = orthogonal_loss_fn(codebook)
                loss = loss + orthogonal_reg_loss * self.orthogonal_reg_weight

        quantize = self.project_out(quantize)
        quantize = self._postprocess(quantize)

        return quantize, embed_ind, loss
class ResidualVectorQuantization(nn.Module):
    """Residual vector quantization implementation.

    Follows Algorithm 1. in https://arxiv.org/pdf/2107.03312.pdf
    """
    def __init__(self, *, num_quantizers, **kwargs):
        super().__init__()
        stages = []
        for _ in range(num_quantizers):
            stages.append(VectorQuantization(**kwargs))
        self.layers = nn.ModuleList(stages)

    def forward(self, x, n_q: tp.Optional[int] = None):
        """Quantize `x` with the first `n_q` layers (all of them when `n_q` is not given).

        Returns:
            tuple: The sum of the per-layer quantized outputs, the stacked
                per-layer indices and the stacked per-layer losses.
        """
        n_q = n_q or len(self.layers)
        remaining = x
        total = 0.0
        indices_per_layer = []
        losses_per_layer = []

        for stage in self.layers[:n_q]:
            stage_out, stage_indices, stage_loss = stage(remaining)
            # the next layer quantizes what this one could not represent
            remaining = remaining - stage_out
            total = total + stage_out
            indices_per_layer.append(stage_indices)
            losses_per_layer.append(stage_loss)

        out_losses = torch.stack(losses_per_layer)
        out_indices = torch.stack(indices_per_layer)
        return total, out_indices, out_losses

    def encode(self, x: torch.Tensor, n_q: tp.Optional[int] = None) -> torch.Tensor:
        """Return the stacked indices produced by the first `n_q` layers for `x`."""
        n_q = n_q or len(self.layers)
        remaining = x
        collected = []
        for stage in self.layers[:n_q]:
            stage_indices = stage.encode(remaining)
            remaining = remaining - stage.decode(stage_indices)
            collected.append(stage_indices)
        return torch.stack(collected)

    def decode(self, q_indices: torch.Tensor) -> torch.Tensor:
        """Sum the decoded outputs of each layer for its slice of `q_indices`."""
        total = torch.tensor(0.0, device=q_indices.device)
        for position, stage_indices in enumerate(q_indices):
            total = total + self.layers[position].decode(stage_indices)
        return total
class ResidualVectorQuantizer(BaseQuantizer):
    """Residual Vector Quantizer.

    Args:
        dimension (int): Dimension of the codebooks.
        n_q (int): Number of residual vector quantizers used.
        q_dropout (bool): Random quantizer drop out at train time.
        bins (int): Codebook size.
        decay (float): Decay for exponential moving average over the codebooks.
        kmeans_init (bool): Whether to use kmeans to initialize the codebooks.
        kmeans_iters (int): Number of iterations used for kmeans initialization.
        threshold_ema_dead_code (int): Threshold for dead code expiration. Replace any codes
            that have an exponential moving average cluster size less than the specified threshold with
            randomly selected vector from the current batch.
        orthogonal_reg_weight (float): Orthogonal regularization weights.
        orthogonal_reg_active_codes_only (bool): Apply orthogonal regularization only on active codes.
        orthogonal_reg_max_codes (optional int): Maximum number of codes to consider
            for orthogonal regularization.
    """
    def __init__(
        self,
        dimension: int = 256,
        n_q: int = 8,
        q_dropout: bool = False,
        bins: int = 1024,
        decay: float = 0.99,
        kmeans_init: bool = True,
        kmeans_iters: int = 10,
        threshold_ema_dead_code: int = 2,
        orthogonal_reg_weight: float = 0.0,
        orthogonal_reg_active_codes_only: bool = False,
        orthogonal_reg_max_codes: tp.Optional[int] = None,
    ):
        super().__init__()
        # `max_n_q` is fixed at construction; `n_q` can later be lowered with `set_num_codebooks`.
        self.max_n_q = n_q
        self.n_q = n_q
        self.q_dropout = q_dropout
        self.dimension = dimension
        self.bins = bins
        self.decay = decay
        self.kmeans_init = kmeans_init
        self.kmeans_iters = kmeans_iters
        self.threshold_ema_dead_code = threshold_ema_dead_code
        self.orthogonal_reg_weight = orthogonal_reg_weight
        self.orthogonal_reg_active_codes_only = orthogonal_reg_active_codes_only
        self.orthogonal_reg_max_codes = orthogonal_reg_max_codes
        vq_kwargs = dict(
            dim=dimension,
            codebook_size=bins,
            num_quantizers=n_q,
            decay=decay,
            kmeans_init=kmeans_init,
            kmeans_iters=kmeans_iters,
            threshold_ema_dead_code=threshold_ema_dead_code,
            orthogonal_reg_weight=orthogonal_reg_weight,
            orthogonal_reg_active_codes_only=orthogonal_reg_active_codes_only,
            orthogonal_reg_max_codes=orthogonal_reg_max_codes,
            channels_last=False,
        )
        self.vq = ResidualVectorQuantization(**vq_kwargs)

    def forward(self, x: torch.Tensor, frame_rate: int):
        """Quantize `x` and report the bandwidth used.

        At train time with `q_dropout` enabled, the number of quantizers used is
        drawn uniformly between 1 and `n_q` (inclusive).
        """
        use_dropout = self.training and self.q_dropout
        if use_dropout:
            active_q = int(torch.randint(1, self.n_q + 1, (1,)).item())
        else:
            active_q = self.n_q
        bw_per_q = math.log2(self.bins) * frame_rate / 1000
        quantized, codes, commit_loss = self.vq(x, n_q=active_q)
        # codes is [B, K, T], with T frames, K nb of codebooks.
        codes = codes.transpose(0, 1)
        bandwidth = torch.tensor(active_q * bw_per_q).to(x)
        penalty = torch.mean(commit_loss)
        return QuantizedResult(quantized, codes, bandwidth, penalty=penalty)

    def encode(self, x: torch.Tensor) -> torch.Tensor:
        """Encode a given input tensor with the specified frame rate at the given bandwidth.
        The RVQ encode method sets the appropriate number of quantizer to use
        and returns indices for each quantizer.
        """
        # vq.encode output is transposed so that codes is [B, K, T],
        # with T frames, K nb of codebooks.
        return self.vq.encode(x, n_q=self.n_q).transpose(0, 1)

    def decode(self, codes: torch.Tensor) -> torch.Tensor:
        """Decode the given codes to the quantized representation."""
        # codes is [B, K, T], with T frames, K nb of codebooks, vq.decode expects [K, B, T].
        return self.vq.decode(codes.transpose(0, 1))

    @property
    def total_codebooks(self):
        return self.max_n_q

    @property
    def num_codebooks(self):
        return self.n_q

    def set_num_codebooks(self, n: int):
        assert n > 0 and n <= self.max_n_q
        self.n_q = n
class AudioGenSolver(musicgen.MusicGenSolver):
    """Solver for AudioGen re-implementation training task.

    Note that this implementation does not strictly follow
    the method proposed in https://arxiv.org/abs/2209.15352
    but is derived from MusicGen's training pipeline.

    More information can be found in the AudioGen model card.
    """
    # Only difference with the parent solver visible here: the dataset type is set to SOUND.
    DATASET_TYPE: builders.DatasetType = builders.DatasetType.SOUND
def __init__(self, cfg: omegaconf.DictConfig):
    """Set up logging, best-state tracking, dataloaders, model, EMA and profiling helpers.

    Args:
        cfg (omegaconf.DictConfig): Solver configuration. The sections read here are
            `device`, `logging`, `fsdp`, `autocast`, `execute_only`, `optim`,
            `profiler` and `deadlock`.
    """
    super().__init__()
    self.logger.info(f"Instantiating solver {self.__class__.__name__} for XP {self.xp.sig}")
    self.logger.info(f"All XP logs are stored in {self.xp.folder}")
    self.cfg = cfg
    self.device = cfg.device
    self.model: nn.Module
    self._continue_best_source_keys = ['best_state', 'fsdp_best_state']
    self._fsdp_modules: tp.List[fsdp.FSDP] = []
    self._ema_sources: nn.ModuleDict = nn.ModuleDict()
    self.ema: tp.Optional[optim.ModuleDictEMA] = None
    self.dataloaders: tp.Dict[str, torch.utils.data.DataLoader] = dict()
    self._log_updates = self.cfg.logging.get('log_updates', 10)
    if self.cfg.logging.log_tensorboard:
        self.init_tensorboard(**self.cfg.get('tensorboard'))
    # NOTE(review): the trailing `and self` looks like a leftover; confirm it is intended.
    if self.cfg.logging.log_wandb and self:
        self.init_wandb(**self.cfg.get('wandb'))
    # keep a copy of the best performing state for stateful objects
    # used for evaluation and generation stages
    dtype_best: tp.Optional[torch.dtype] = None
    if self.cfg.fsdp.use:
        dtype_best = getattr(torch, self.cfg.fsdp.param_dtype)  # type: ignore
        assert isinstance(dtype_best, torch.dtype)
    elif self.cfg.autocast:
        dtype_best = getattr(torch, self.cfg.autocast_dtype)  # type: ignore
        assert isinstance(dtype_best, torch.dtype)
    self.best_state: BestStateDictManager = BestStateDictManager(dtype=dtype_best)
    # Hacky support for keeping a copy of the full best state in rank0.
    self.fsdp_best_state: tp.Dict[str, tp.Any] = {}
    self.register_stateful('best_state', 'fsdp_best_state')  # register best_state object to keep it in state_dict
    self._new_best_state: bool = False  # should save a new checkpoint
    # instantiate datasets and appropriate number of updates per epoch
    self.build_dataloaders()
    if self.cfg.execute_only is None:
        assert 'train' in self.dataloaders, "The train dataset split must be provided."
        assert 'valid' in self.dataloaders, "The valid dataset split must be provided."
    self.train_updates_per_epoch = len(self.dataloaders['train']) if 'train' in self.dataloaders else 0
    # an explicit `optim.updates_per_epoch` overrides the dataloader length
    if self.cfg.optim.updates_per_epoch:
        self.train_updates_per_epoch = self.cfg.optim.updates_per_epoch
    self.total_updates = self.train_updates_per_epoch * self.cfg.optim.epochs
    # instantiate model & exponential moving average on the model
    self.build_model()
    self.logger.info("Model hash: %s", model_hash(self.model))
    assert 'model' in self.stateful.sources, \
        "Please register the model to stateful with self.register_stateful('model') in build_model."
    self.profiler = Profiler(self.model, **self.cfg.profiler)
    self.initialize_ema()
    self.register_stateful('ema')
    assert self.ema is None or 'ema' in self.stateful.sources, \
        "Please register the ema to stateful with self.register_stateful('ema') in build_model."
    self.deadlock_detect = DeadlockDetect(**self.cfg.deadlock)
    # basic statistics on the trained model
    model_size = sum(p.numel() for p in self.model.parameters() if p.requires_grad) / 1e6
    # one copy of grad, one copy of momentum, one copy of denominator and model weights.
    # and 4 bytes for each float!
    mem_usage = model_size * 4 * 4 / 1000
    self.logger.info("Model size: %.2f M params", model_size)
    self.logger.info("Base memory usage, with model, grad and optim: %.2f GB", mem_usage)
@property
def autocast(self):
    """Convenient autocast (or not) using the solver configuration."""
    # NOTE(review): `self.autocast_dtype` is not set in the visible part of this class;
    # presumably provided elsewhere (e.g. by subclasses) -- confirm.
    return TorchAutocast(enabled=self.cfg.autocast, device_type=self.device, dtype=self.autocast_dtype)


def _get_state_source(self, name) -> flashy.state.StateDictSource:
    # Internal utility to get a state source from the solver
    return self.stateful.sources[name]


@property
def best_metric_name(self) -> tp.Optional[str]:
    """Metric name used to identify the best state. This metric should be stored in the metrics
    used on the stage for best state identification (most likely, `valid`). If None, then
    no best state is saved.
    """
    return None


def register_best_state(self, *args: str):
    """Register state sources in `BestStateDictManager` to keep their best states along with their
    latest states. The best state will be used at evaluation stages instead of the latest states.

    Shortcut around `BestStateDictManager.register` method. You can pass any number of
    attribute, included nested attributes and those will be included into the checkpoints
    and automatically restored when `BaseSolver.restore` is called.

    Raises:
        AssertionError: If one of the names was not registered in stateful first.
    """
    for name in args:
        # Check before the lookup: `_get_state_source` indexes `self.stateful.sources`
        # and would otherwise fail with a bare KeyError before this message can be shown.
        assert name in self.stateful.sources, "Registered states in best should be registered in stateful first!"
        state_source = self._get_state_source(name)
        self.best_state.register(name, state_source)


def register_ema(self, *args: str):
    """Register state sources for exponential moving average.

    The registered sources are used to instantiate a ModuleDictEMA instance.
    The ModuleDictEMA keeps a `nn.ModuleDict` module that is updated when self.ema.step() is called
    and swapped with the original state sources with self.swap_ema_state() method.

    Usage:
        self.register_ema('model')
    """
    assert self.ema is None, "Cannot register state source to already instantiated EMA."
    for name in args:
        self._ema_sources[name] = getattr(self, name)
def wrap_with_fsdp(self, model: torch.nn.Module, *args, **kwargs):
    """Wrap `model` using the solver FSDP configuration and keep track of FSDP-wrapped modules."""
    model = fsdp.wrap_with_fsdp(self.cfg.fsdp, model, *args, **kwargs)
    # only remember the module when the wrapping actually produced an FSDP instance
    if isinstance(model, fsdp.FSDP):
        self._fsdp_modules.append(model)
    return model


def update_best_state_from_stage(self, stage_name: str = 'valid'):
    """Update latest best state based on pending metrics of a given stage. This method relies
    on the `BestStateDictManager.update` method to update the best state_dict with latest weights
    if the registered states happen to match to the best performing setup.

    Args:
        stage_name (str): Stage whose pending metrics hold the best metric.
    """
    if self.best_metric_name is None:
        # when no best metric is defined, the last state is always the best
        self._new_best_state = True
        self.logger.info("Updating best state with current state.")
    else:
        assert stage_name in self._pending_metrics, f"Metrics for stage {stage_name} not found."
        assert self.best_metric_name in self._pending_metrics[stage_name], \
            f"Best metric not found in {stage_name} metrics. Cannot register best state"
        current_score = self._pending_metrics[stage_name][self.best_metric_name]
        all_best_metric_scores = [
            past_metrics[stage_name][self.best_metric_name]
            for past_metrics in self.history
        ]
        all_best_metric_scores.append(current_score)
        # lower is better: the current state is the best when it reaches the minimum
        best_score = min(all_best_metric_scores)
        self._new_best_state = current_score == best_score
        if self._new_best_state:
            # `inf` sentinel so that `min` also works when there is no past score
            old_best = min(all_best_metric_scores[:-1] + [float('inf')])
            self.logger.info(
                f"New best state with {self.best_metric_name}={current_score:.3f} (was {old_best:.3f})")

    if self._new_best_state:
        if self.cfg.fsdp.use:
            # this will give an empty state dict on all ranks but the rank 0
            # which will have a copy in memory of the full model.
            with fsdp.switch_to_full_state_dict(self._fsdp_modules):
                for name in self.best_state.states.keys():
                    state_source = self._get_state_source(name)
                    self.best_state.update(name, state_source)
                # we save to a different dict.
                self.fsdp_best_state.update(self.best_state.state_dict())
            # We cannot efficiently load fsdp_best_state when using FSDP,
            # so we have do do a second pass, with the local shards.
        for name in self.best_state.states.keys():
            state_source = self._get_state_source(name)
            self.best_state.update(name, state_source)
def _load_new_state_dict(self, state_dict: dict) -> dict:
    # Load each new state into its source; hand back copies of the states they replaced.
    previous = {}
    for name, new_state in state_dict.items():
        source = self._get_state_source(name)
        previous[name] = copy_state(source.state_dict())
        source.load_state_dict(new_state)
    return previous


@contextmanager
def swap_best_state(self):
    """Context manager running its body with the best states loaded, restoring the original ones on exit."""
    self.logger.debug(f"Swapping to best state for: {', '.join(self.best_state.state_dict().keys())}")
    saved = self._load_new_state_dict(self.best_state.state_dict())
    try:
        yield
    finally:
        self.logger.debug("Swapping back from best to original state")
        for name, original in saved.items():
            self._get_state_source(name).load_state_dict(original)


@contextmanager
def swap_ema_state(self):
    """Context manager running its body with the EMA states loaded, restoring the original ones on exit.

    Does nothing more than yielding when no EMA is defined.
    """
    if self.ema is None:
        yield
        return
    ema_state_dict = self.ema.state_dict()['state']
    self.logger.debug(f"Swapping to EMA state for: {', '.join(ema_state_dict.keys())}")
    saved = self._load_new_state_dict(ema_state_dict)
    try:
        yield
    finally:
        self.logger.debug("Swapping back from EMA state to original state")
        for name, original in saved.items():
            self._get_state_source(name).load_state_dict(original)


@property
def is_training(self):
    """Whether the current stage is the `train` stage."""
    return 'train' == self.current_stage


def log_model_summary(self, model: nn.Module):
    """Log model summary, architecture and size of the model."""
    self.logger.info(model)
    n_params = sum(p.numel() for p in model.parameters())
    # 4 bytes per parameter, reported in units of 2 ** 20 bytes
    size_mb = n_params * 4 / 2 ** 20
    self.logger.info("Size: %.1f MB", size_mb)


@abstractmethod
def build_model(self):
    """Method to implement to initialize model."""
    ...


def initialize_ema(self):
    """Initialize exponential moving average with the registered sources.
    EMA object is created if the optim.ema.model.decay value is non-null.
    """
    from .builders import get_ema
    self.ema = get_ema(self._ema_sources, self.cfg.optim.ema)
    if self.ema is None:
        self.logger.info('No EMA on the model.')
        return
    assert self.cfg.optim.ema.updates > 0
    self.logger.info(
        f'Initializing EMA on the model with decay = {self.ema.decay}'
        f' every {self.cfg.optim.ema.updates} updates'
    )
@abstractmethod
def build_dataloaders(self):
    """Method to implement to initialize dataloaders."""
    ...


@abstractmethod
def show(self):
    """Method to log any information without running the job."""
    ...


@property
def log_updates(self):
    # convenient access to log updates
    return self._log_updates


def checkpoint_path(self, **kwargs):
    """Path of the default checkpoint inside the XP folder (`use_fsdp` defaults to the solver config)."""
    kwargs.setdefault('use_fsdp', self.cfg.fsdp.use)
    filename = checkpoint.checkpoint_name(**kwargs)
    return self.folder / filename


def epoch_checkpoint_path(self, epoch: int, **kwargs):
    """Path of the checkpoint kept for the given epoch inside the XP folder."""
    kwargs.setdefault('use_fsdp', self.cfg.fsdp.use)
    filename = checkpoint.checkpoint_name(str(epoch), **kwargs)
    return self.folder / filename


def checkpoint_path_with_name(self, name: str, **kwargs):
    """Path of the checkpoint with the given name inside the XP folder."""
    kwargs.setdefault('use_fsdp', self.cfg.fsdp.use)
    filename = checkpoint.checkpoint_name(name=name, **kwargs)
    return self.folder / filename


def save_checkpoints(self):
    """Save checkpoint, optionally keeping a copy for a given epoch."""
    is_sharded = self.cfg.fsdp.use
    # without sharding, only rank zero writes checkpoints
    if not (flashy.distrib.is_rank_zero() or is_sharded):
        return
    self.logger.info("Model hash: %s", model_hash(self.model))
    state = self.state_dict()
    epoch = self.epoch - 1  # pushing metrics will increase the epoch in Flashy, so we do -1 here
    ckpt_cfg = self.cfg.checkpoint

    # save minimal state_dict as new checkpoint every X epoch
    if ckpt_cfg.save_every and epoch % ckpt_cfg.save_every == 0:
        kept = ckpt_cfg.keep_every_states
        if kept is not None and len(kept) > 0:
            minimal_state = {
                name: source for name, source in state.items()
                if name in kept
            }
        else:
            minimal_state = state
        checkpoint.save_checkpoint(minimal_state, self.epoch_checkpoint_path(epoch), is_sharded)

    # save checkpoint as latest checkpoint
    if ckpt_cfg.save_last:
        checkpoint.save_checkpoint(state, self.checkpoint_path(), is_sharded)

    # flush any stale checkpoint to reduce disk footprint
    checkpoint.flush_stale_checkpoints(self.checkpoint_path())
def load_from_pretrained(self, name: str) -> dict:
    """Load a pretrained model state by name; solvers supporting it must override this.

    Raises:
        NotImplementedError: Always, in this base implementation.
    """
    raise NotImplementedError("Solver does not provide a way to load pretrained models.")


def load_checkpoints(self, load_best: bool = False, ignore_state_keys: tp.List[str] = []) -> tp.Optional[dict]:
    """Load last checkpoint or the one specified in continue_from.

    Args:
        load_best (bool): Whether to load from best state dict or not.
            Best state dict is always used when not loading the current xp.
        ignore_state_keys (list of str): List of sources to ignore when loading the state, e.g. `optimizer`.
            The list is only read, never modified.
    Returns:
        state (dict, optional): The loaded state dictionary.
    Raises:
        RuntimeError: If `continue_from` cannot be resolved, or if the epoch differs across workers
            after loading.
    """
    # load checkpoints from xp folder or cfg.continue_from
    is_sharded = self.cfg.fsdp.use
    load_from_path: tp.Optional[Path] = None
    checkpoint_source: tp.Optional[checkpoint.CheckpointSource] = None

    if load_best:
        self.logger.info("Trying to load state_dict from best state.")

    state: tp.Optional[dict] = None
    rank0_checkpoint_path = self.checkpoint_path(use_fsdp=False)
    current_checkpoint_path = self.checkpoint_path()
    _pretrained_prefix = '//pretrained/'
    continue_pretrained = (self.cfg.continue_from or '').startswith(_pretrained_prefix)
    # an existing checkpoint in the current XP takes precedence over `continue_from`
    if rank0_checkpoint_path.exists():
        self.logger.info(f"Loading existing checkpoint: {current_checkpoint_path}")
        load_from_path = current_checkpoint_path
        checkpoint.check_sharded_checkpoint(current_checkpoint_path, rank0_checkpoint_path)
        checkpoint_source = checkpoint.CheckpointSource.CURRENT_XP
    elif self.cfg.continue_from and not continue_pretrained:
        self.logger.info(f"Continuing from provided checkpoint: {self.cfg.continue_from}")
        # we're always continuing from consolidated checkpoints: self.cfg.use_fsdp and not continue_best
        load_from_path = checkpoint.resolve_checkpoint_path(self.cfg.continue_from, use_fsdp=False)
        if load_from_path is None:
            self.logger.error('Could not resolve the continue_from checkpoint %s', self.cfg.continue_from)
            raise RuntimeError(f'Could not resolve continue_from checkpoint {self.cfg.continue_from}')
        checkpoint_source = checkpoint.CheckpointSource.OTHER

    if load_from_path is not None:
        state = checkpoint.load_checkpoint(load_from_path, is_sharded)
    elif continue_pretrained:
        self.logger.info("Loading a pretrained model. Ignoring 'load_best' and 'ignore_state_keys' params.")
        state = self.load_from_pretrained(self.cfg.continue_from[len(_pretrained_prefix):])
        checkpoint_source = checkpoint.CheckpointSource.PRETRAINED
        load_best = True

    # checkpoints are not from the current xp, we only retrieve the best state
    if checkpoint_source is not None and checkpoint_source != checkpoint.CheckpointSource.CURRENT_XP:
        assert state is not None
        self.logger.info("Checkpoint source is not the current xp: Load state_dict from best state.")
        load_best = True
        state = {key: state[key] for key in self._continue_best_source_keys if key in state}
        # loaded checkpoints are FSDP checkpoints: we're reading the best state
        # from FSDP and we drop the regular best_state
        if 'fsdp_best_state' in state and state['fsdp_best_state']:
            state.pop('best_state', None)
            self.logger.info("... Loaded checkpoint has FSDP best state")
        # FSDP is enabled in the solver, if the loaded checkpoints do not have FSDP support
        # then we're initializing FSDP best state with the regular best state
        elif self.cfg.fsdp.use:
            if 'fsdp_best_state' not in state or not state['fsdp_best_state']:
                # we swap non-FSDP checkpoints best_state to FSDP-compatible best state
                state['fsdp_best_state'] = state.pop('best_state')
                self.logger.info("... Loaded checkpoint does not have FSDP best state. Use regular best state")

    if state is not None:
        if load_best:
            self.logger.info("Ignoring keys when loading best %r", ignore_state_keys)
            for key in set(ignore_state_keys):
                if key in state:
                    state.pop(key)
            has_best_state = 'best_state' in state or 'fsdp_best_state' in state
            # single message string (implicit concatenation) rather than a tuple of two strings
            assert has_best_state, ("Trying to load best state but neither 'best_state'"
                                    " or 'fsdp_best_state' found in checkpoints.")
        self.load_state_dict(state)

    # for FSDP, let's make extra sure nothing bad happened with out of sync
    # checkpoints across workers.
    epoch = float(self.epoch)
    avg_epoch = flashy.distrib.average_metrics({'epoch': epoch})['epoch']
    if avg_epoch != epoch:
        raise RuntimeError(
            f"Inconsistent loading of checkpoints happened, our epoch is {epoch} "
            f"but average of epochs is {avg_epoch}, at least one gpu must have a "
            "different epoch number.")

    # on load_best, properly reinitialize state_dict, best states and ema
    # otherwise we load from the current xp and don't alter anything
    if load_best:
        self.logger.info("Loading state_dict from best state.")
        if not self.cfg.fsdp.use and self.fsdp_best_state:
            # loading from an FSDP checkpoint but with FSDP deactivated
            self.logger.info("... Loading from FSDP best state dict.")
            self.best_state.load_state_dict(self.fsdp_best_state)

        # if load_best, we permanently override the regular state_dict with the best state
        if self.cfg.fsdp.use:
            self.logger.info("FSDP is used, loading from FSDP best state.")
            with fsdp.switch_to_full_state_dict(self._fsdp_modules):
                # this might be really fragile but okay for now.
                self.load_state_dict(self.fsdp_best_state)
        else:
            # we permanently swap the stateful objects to their best state
            self._load_new_state_dict(self.best_state.state_dict())

        # the EMA modules should also be instantiated with best state.
        # the easiest way to do so is to reinitialize a new EMA with best state loaded.
        if self.ema is not None:
            self.logger.info("Re-initializing EMA from best state")
            self.initialize_ema()

        if self.cfg.fsdp.use:
            self.logger.info("Re-initializing best state after using FSDP best state.")
            for name in self.best_state.states.keys():
                state_source = self._get_state_source(name)
                self.best_state.update(name, state_source)

    return state
if self.ema is not None: self.logger.info("Re-initializing EMA from best state") self.initialize_ema() if self.cfg.fsdp.use: self.logger.info("Re-initializing best state after using FSDP best state.") for name in self.best_state.states.keys(): state_source = self._get_state_source(name) self.best_state.update(name, state_source) return state def restore(self, load_best: bool = False, replay_metrics: bool = False, ignore_state_keys: tp.List[str] = []) -> bool: """Restore the status of a solver for a given xp. Args: load_best (bool): if `True`, load the best state from the checkpoint. replay_metrics (bool): if `True`, logs all the metrics from past epochs. ignore_state_keys (list of str): list of sources to ignore when loading the state, e.g. `optimizer`. """ self.logger.info("Restoring weights and history.") restored_checkpoints = self.load_checkpoints(load_best, ignore_state_keys) self.logger.info("Model hash: %s", model_hash(self.model)) if replay_metrics and len(self.history) > 0: self.logger.info("Replaying past metrics...") for epoch, stages in enumerate(self.history): for stage_name, metrics in stages.items(): # We manually log the metrics summary to the result logger # as we don't want to add them to the pending metrics self.result_logger._log_summary(stage_name, metrics, step=epoch + 1, step_name='epoch', formatter=self.get_formatter(stage_name)) return restored_checkpoints is not None def commit(self, save_checkpoints: bool = True): """Commit metrics to dora and save checkpoints at the end of an epoch.""" # we override commit to introduce more complex checkpoint saving behaviors self.history.append(self._pending_metrics) # This will increase self.epoch if save_checkpoints: self.save_checkpoints() self._start_epoch() if flashy.distrib.is_rank_zero(): self.xp.link.update_history(self.history) def run_epoch(self): """Run a single epoch with all stages. Metrics for a given stage are stored in _pending_metrics and committed by the solver afterwards. 
def should_run_stage(self, stage_name) -> bool:
    """Check whether we want to run the specified stage at the current epoch.

    A stage runs on the last training epoch (`cfg.optim.epochs`), and otherwise every
    `every` epochs when the `every` option of the stage config section is set
    to a non-zero value.

    Args:
        stage_name (str): Name of the stage, also used as key of its config section.
    Returns:
        bool: True if the stage should be run for the current epoch.
    """
    stage_every = self.cfg[stage_name].get('every', None)
    is_last_epoch = self.epoch == self.cfg.optim.epochs
    # `stage_every` can be None or 0 when the stage is not scheduled periodically: coerce it
    # so that a real bool is returned instead of leaking None / 0 to the caller.
    is_epoch_every = bool(stage_every) and self.epoch % stage_every == 0
    return is_last_epoch or is_epoch_every
def common_train_valid(self, dataset_split: str, **kwargs: tp.Any):
    """Run one pass over a data split; shared implementation of the train and valid stages.

    Every batch goes through `run_step`. The metrics it returns are averaged over the
    epoch, reported to the progress logger, and finally averaged across workers.

    Args:
        dataset_split (str): Key of the dataloader to iterate over in `self.dataloaders`.
        **kwargs: Accepted for signature compatibility, not used here.
    Returns:
        dict: Metrics averaged over the epoch and across distributed workers.
    """
    self.model.train(self.is_training)

    loader = self.dataloaders[dataset_split]
    # get a different order for distributed training, otherwise this will get ignored
    is_distributed = flashy.distrib.world_size() > 1
    if is_distributed and isinstance(loader.sampler, torch.utils.data.distributed.DistributedSampler):
        loader.sampler.set_epoch(self.epoch)

    if self.is_training:
        updates_per_epoch = self.train_updates_per_epoch
    else:
        updates_per_epoch = len(loader)

    if self.cfg.benchmark_no_load:
        self.logger.warning("Fake loading for benchmarking: re-using first batch")
        first_batch = next(iter(loader))
        loader = [first_batch] * updates_per_epoch  # type: ignore

    lp = self.log_progress(self.current_stage, loader, total=updates_per_epoch, updates=self.log_updates)
    average = flashy.averager()  # epoch wise average
    instant_average = flashy.averager()  # average between two logging
    epoch_metrics: dict = {}

    with self.profiler, self.deadlock_detect:  # profiler will only run for the first 20 updates.
        for idx, batch in enumerate(lp):
            self.deadlock_detect.update('batch')
            if idx >= updates_per_epoch:
                break
            step_metrics = self.run_step(idx, batch, {})
            self.deadlock_detect.update('step')
            # run EMA step; the update period is only read when an EMA is actually in use
            if self.ema is not None and self.is_training:
                if (idx + 1) % self.cfg.optim.ema.updates == 0:
                    self.logger.debug("EMA model step")
                    self.ema.step()
            self.deadlock_detect.update('ema')
            self.profiler.step()
            instant_metrics = instant_average(step_metrics)
            if lp.update(**instant_metrics):
                instant_average = flashy.averager()  # reset averager between two logging
            epoch_metrics = average(step_metrics)  # epoch wise average
            self.deadlock_detect.update('end_batch')

    epoch_metrics = flashy.distrib.average_metrics(epoch_metrics, updates_per_epoch)
    return epoch_metrics


def train(self):
    """Train stage: one pass over the 'train' split."""
    return self.common_train_valid(dataset_split='train')


def valid(self):
    """Valid stage: one pass over the 'valid' split."""
    return self.common_train_valid(dataset_split='valid')


@abstractmethod
def evaluate(self):
    """Evaluate stage, to be implemented by each concrete solver."""
    ...
@staticmethod
def get_eval_solver_from_sig(sig: str, dtype: tp.Optional[str] = None,
                             device: tp.Optional[str] = None, autocast: bool = True,
                             batch_size: tp.Optional[int] = None,
                             override_cfg: tp.Optional[tp.Union[dict, omegaconf.DictConfig]] = None,
                             **kwargs):
    """Build a solver ready for evaluation from an experiment signature.

    Thin wrapper around `audiocraft.train.get_solver_from_sig`: EMA is turned off in the
    config, FSDP is disabled, the best state is loaded (skipping the 'optimizer' and 'ema'
    state keys) and the model is switched to eval mode.

    Args:
        sig (str): signature to load.
        dtype (str or None): potential dtype, as a string, i.e. 'float16'.
        device (str or None): potential device, as a string, i.e. 'cuda'.
        autocast (bool): value written to the `autocast` config entry.
        batch_size (int or None): if given, written to `dataset.batch_size` in the config.
        override_cfg (dict or omegaconf.DictConfig or None): extra config overrides, merged
            with (and listed before) the overrides set by this function.
        **kwargs: forwarded as is to `audiocraft.train.get_solver_from_sig`.
    Returns:
        The solver returned by `audiocraft.train.get_solver_from_sig`, with its model in eval mode.
    """
    from audiocraft import train

    forced_cfg: tp.Dict[str, tp.Any] = {'optim': {'ema': {'use': False}}, 'autocast': autocast}
    optional_entries = (
        ('dtype', dtype),
        ('device', device),
        ('dataset', None if batch_size is None else {'batch_size': batch_size}),
    )
    for key, value in optional_entries:
        if value is not None:
            forced_cfg[key] = value

    user_cfg = {} if override_cfg is None else override_cfg
    merged_cfg = omegaconf.OmegaConf.merge(
        omegaconf.DictConfig(user_cfg), omegaconf.DictConfig(forced_cfg))  # type: ignore
    solver = train.get_solver_from_sig(
        sig, override_cfg=merged_cfg,
        load_best=True, disable_fsdp=True,
        ignore_state_keys=['optimizer', 'ema'], **kwargs)
    solver.model.eval()
    return solver
def get_optimizer(params: tp.Union[nn.Module, tp.Iterable[torch.Tensor]], cfg: omegaconf.DictConfig) -> Optimizer:
    """Build torch optimizer from config and set of parameters.
    Supported optimizers: Adam ('adam'), AdamW ('adamw'), DAdaptAdam ('dadam').

    Args:
        params (nn.Module or iterable of torch.Tensor): Parameters to optimize. When a module
            is given, its parameter groups are built with `get_optim_parameter_groups`.
        cfg (DictConfig): Optimization-related configuration. Reads `optimizer`, `lr` and `adam`
            (the latter is expanded as keyword arguments for all supported optimizers).
    Returns:
        torch.optim.Optimizer.
    Raises:
        KeyError: If `cfg` has no `optimizer` entry.
        ValueError: If `cfg.optimizer` is not one of the supported optimizers.
    """
    if 'optimizer' not in cfg:
        if getattr(cfg, 'optim', None) is not None:
            raise KeyError("Optimizer not found in config. Try instantiating optimizer from cfg.optim?")
        raise KeyError("Optimizer not found in config.")

    parameters = get_optim_parameter_groups(params) if isinstance(params, nn.Module) else params
    optimizer: torch.optim.Optimizer
    if cfg.optimizer == 'adam':
        optimizer = torch.optim.Adam(parameters, lr=cfg.lr, **cfg.adam)
    elif cfg.optimizer == 'adamw':
        optimizer = torch.optim.AdamW(parameters, lr=cfg.lr, **cfg.adam)
    elif cfg.optimizer == 'dadam':
        optimizer = optim.DAdaptAdam(parameters, lr=cfg.lr, **cfg.adam)
    else:
        # Report the offending optimizer name: the message used to mention the LR scheduler and
        # read `cfg.lr_scheduler`, which is not guaranteed to exist in an optimizer config.
        raise ValueError(f"Unsupported Optimizer: {cfg.optimizer}")
    return optimizer
""" if 'lr_scheduler' not in cfg: raise KeyError("LR Scheduler not found in config") lr_sched: tp.Optional[LRScheduler] = None if cfg.lr_scheduler == 'step': lr_sched = torch.optim.lr_scheduler.StepLR(optimizer, **cfg.step) elif cfg.lr_scheduler == 'exponential': lr_sched = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=cfg.exponential) elif cfg.lr_scheduler == 'cosine': kwargs = dict_from_config(cfg.cosine) warmup_steps = kwargs.pop('warmup') lr_sched = optim.CosineLRScheduler( optimizer, warmup_steps=warmup_steps, total_steps=total_updates, **kwargs) elif cfg.lr_scheduler == 'polynomial_decay': kwargs = dict_from_config(cfg.polynomial_decay) warmup_steps = kwargs.pop('warmup') lr_sched = optim.PolynomialDecayLRScheduler( optimizer, warmup_steps=warmup_steps, total_steps=total_updates, **kwargs) elif cfg.lr_scheduler == 'inverse_sqrt': kwargs = dict_from_config(cfg.inverse_sqrt) warmup_steps = kwargs.pop('warmup') lr_sched = optim.InverseSquareRootLRScheduler(optimizer, warmup_steps=warmup_steps, **kwargs) elif cfg.lr_scheduler == 'linear_warmup': kwargs = dict_from_config(cfg.linear_warmup) warmup_steps = kwargs.pop('warmup') lr_sched = optim.LinearWarmupLRScheduler(optimizer, warmup_steps=warmup_steps, **kwargs) elif cfg.lr_scheduler is not None: raise ValueError(f"Unsupported LR Scheduler: {cfg.lr_scheduler}") return lr_sched def get_ema(module_dict: nn.ModuleDict, cfg: omegaconf.DictConfig) -> tp.Optional[optim.ModuleDictEMA]: """Initialize Exponential Moving Average. Args: module_dict (nn.ModuleDict): ModuleDict for which to compute the EMA. cfg (omegaconf.DictConfig): Optim EMA configuration. Returns: optim.ModuleDictEMA: EMA version of the ModuleDict. 
""" kw: tp.Dict[str, tp.Any] = dict(cfg) use = kw.pop('use', False) decay = kw.pop('decay', None) device = kw.pop('device', None) if not use: return None if len(module_dict) == 0: raise ValueError("Trying to build EMA but an empty module_dict source is provided!") ema_module = optim.ModuleDictEMA(module_dict, decay=decay, device=device) return ema_module def get_loss(loss_name: str, cfg: omegaconf.DictConfig): """Instantiate loss from configuration.""" klass = { 'l1': torch.nn.L1Loss, 'l2': torch.nn.MSELoss, 'mel': losses.MelSpectrogramL1Loss, 'mrstft': losses.MRSTFTLoss, 'msspec': losses.MultiScaleMelSpectrogramLoss, 'sisnr': losses.SISNR, }[loss_name] kwargs = dict(getattr(cfg, loss_name)) return klass(**kwargs) def get_balancer(loss_weights: tp.Dict[str, float], cfg: omegaconf.DictConfig) -> losses.Balancer: """Instantiate loss balancer from configuration for the provided weights.""" kwargs: tp.Dict[str, tp.Any] = dict_from_config(cfg) return losses.Balancer(loss_weights, **kwargs) def get_adversary(name: str, cfg: omegaconf.DictConfig) -> nn.Module: """Initialize adversary from config.""" klass = { 'msd': adversarial.MultiScaleDiscriminator, 'mpd': adversarial.MultiPeriodDiscriminator, 'msstftd': adversarial.MultiScaleSTFTDiscriminator, }[name] adv_cfg: tp.Dict[str, tp.Any] = dict(getattr(cfg, name)) return klass(**adv_cfg) def get_adversarial_losses(cfg) -> nn.ModuleDict: """Initialize dict of adversarial losses from config.""" device = cfg.device adv_cfg = getattr(cfg, 'adversarial') adversaries = adv_cfg.get('adversaries', []) adv_loss_name = adv_cfg['adv_loss'] feat_loss_name = adv_cfg.get('feat_loss') normalize = adv_cfg.get('normalize', True) feat_loss: tp.Optional[adversarial.FeatureMatchingLoss] = None if feat_loss_name: assert feat_loss_name in ['l1', 'l2'], f"Feature loss only support L1 or L2 but {feat_loss_name} found." 
def get_visqol(cfg: omegaconf.DictConfig) -> metrics.ViSQOL:
    """Build the ViSQOL metric, using the whole config section as keyword arguments."""
    return metrics.ViSQOL(**dict_from_config(cfg))


def get_fad(cfg: omegaconf.DictConfig) -> metrics.FrechetAudioDistanceMetric:
    """Build the Frechet Audio Distance metric from the `tf` section of the config.

    The `log_folder` argument is always set to the folder of the current dora xp.
    """
    fad_kwargs = dict_from_config(cfg.tf)
    fad_kwargs['log_folder'] = dora.get_xp().folder
    return metrics.FrechetAudioDistanceMetric(**fad_kwargs)


def get_kldiv(cfg: omegaconf.DictConfig) -> metrics.KLDivergenceMetric:
    """Build the KL-Divergence metric selected by `cfg.model`.

    The keyword arguments are read from the config section named after the model.
    """
    metric_cls = {
        'passt': metrics.PasstKLDivergenceMetric,
    }[cfg.model]
    return metric_cls(**dict_from_config(cfg.get(cfg.model)))


def get_text_consistency(cfg: omegaconf.DictConfig) -> metrics.TextConsistencyMetric:
    """Build the text consistency metric selected by `cfg.model`.

    The keyword arguments are read from the config section named after the model.
    """
    metric_cls = {
        'clap': metrics.CLAPTextConsistencyMetric
    }[cfg.model]
    return metric_cls(**dict_from_config(cfg.get(cfg.model)))
def get_audio_datasets(cfg: omegaconf.DictConfig,
                       dataset_type: DatasetType = DatasetType.AUDIO) -> tp.Dict[str, torch.utils.data.DataLoader]:
    """Create one dataloader per data split declared in the configuration.

    Each string-valued entry of `cfg.datasource` is treated as a split whose value is the
    path handed to the dataset `from_meta` constructor. Dataset arguments come from
    `cfg.dataset`, with the split-specific section overriding the shared values.

    Args:
        cfg (omegaconf.DictConfig): Configuration.
        dataset_type: The type of dataset to create.
    Returns:
        dict[str, torch.utils.data.DataLoader]: Map of dataloader for each data split.
    Raises:
        ValueError: If `dataset_type` is not a supported dataset type.
    """
    sample_rate = cfg.sample_rate
    channels = cfg.channels
    seed = cfg.seed
    max_sample_rate = cfg.datasource.max_sample_rate
    max_channels = cfg.datasource.max_channels

    assert cfg.dataset is not None, "Could not find dataset definition in config"

    # shared dataset arguments, with the per-split sections taken out of them
    dataset_cfg = dict_from_config(cfg.dataset)
    splits_cfg: dict = {name: dataset_cfg.pop(name) for name in ('train', 'valid', 'evaluate', 'generate')}
    execute_only_stage = cfg.get('execute_only', None)

    dataloaders: dict = {}
    for split, path in cfg.datasource.items():
        if not isinstance(path, str):
            continue  # skipping this as not a path
        if execute_only_stage is not None and split != execute_only_stage:
            continue
        logger.info(f"Loading audio data split {split}: {str(path)}")
        assert (
            cfg.sample_rate <= max_sample_rate
        ), f"Expecting a max sample rate of {max_sample_rate} for datasource but {sample_rate} found."
        assert (
            cfg.channels <= max_channels
        ), f"Expecting a max number of channels of {max_channels} for datasource but {channels} found."

        kwargs = dict(dataset_cfg)
        kwargs.update(splits_cfg[split].items())  # split kwargs overrides default dataset_cfg
        kwargs.update(sample_rate=sample_rate, channels=channels)

        if kwargs.get('permutation_on_files') and cfg.optim.updates_per_epoch:
            total_batch_size = flashy.distrib.world_size() * cfg.dataset.batch_size
            kwargs['num_samples'] = total_batch_size * cfg.optim.updates_per_epoch

        # read (but keep) the values shared by the dataset and the loader
        num_samples = kwargs['num_samples']
        shuffle = kwargs['shuffle']
        # loader-only options are removed before building the dataset
        return_info = kwargs.pop('return_info')
        batch_size = kwargs.pop('batch_size', None)
        num_workers = kwargs.pop('num_workers')

        if dataset_type == DatasetType.AUDIO:
            dataset = data.info_audio_dataset.InfoAudioDataset.from_meta(path, return_info=return_info, **kwargs)
        elif dataset_type == DatasetType.MUSIC:
            dataset = data.music_dataset.MusicDataset.from_meta(path, **kwargs)
        elif dataset_type == DatasetType.SOUND:
            dataset = data.sound_dataset.SoundDataset.from_meta(path, **kwargs)
        else:
            raise ValueError(f"Dataset type is unsupported: {dataset_type}")

        dataloaders[split] = get_loader(
            dataset,
            num_samples,
            batch_size=batch_size,
            num_workers=num_workers,
            seed=seed,
            collate_fn=dataset.collater if return_info else None,
            shuffle=shuffle,
        )
    return dataloaders
import models, quantization from ..utils import checkpoint from ..utils.samples.manager import SampleManager from ..utils.utils import get_pool_executor logger = logging.getLogger(__name__) class CompressionSolver(base.StandardSolver): """Solver for compression task. The compression task combines a set of perceptual and objective losses to train an EncodecModel (composed of an encoder-decoder and a quantizer) to perform high fidelity audio reconstruction. """ def __init__(self, cfg: omegaconf.DictConfig): super().__init__(cfg) self.rng: torch.Generator # set at each epoch self.adv_losses = builders.get_adversarial_losses(self.cfg) self.aux_losses = nn.ModuleDict() self.info_losses = nn.ModuleDict() assert not cfg.fsdp.use, "FSDP not supported by CompressionSolver." loss_weights = dict() for loss_name, weight in self.cfg.losses.items(): if loss_name in ['adv', 'feat']: for adv_name, _ in self.adv_losses.items(): loss_weights[f'{loss_name}_{adv_name}'] = weight elif weight > 0: self.aux_losses[loss_name] = builders.get_loss(loss_name, self.cfg) loss_weights[loss_name] = weight else: self.info_losses[loss_name] = builders.get_loss(loss_name, self.cfg) self.balancer = builders.get_balancer(loss_weights, self.cfg.balancer) self.register_stateful('adv_losses') @property def best_metric_name(self) -> tp.Optional[str]: # best model is the last for the compression model return None def build_model(self): """Instantiate model and optimizer.""" # Model and optimizer self.model = models.builders.get_compression_model(self.cfg).to(self.device) self.optimizer = builders.get_optimizer(self.model.parameters(), self.cfg.optim) self.register_stateful('model', 'optimizer') self.register_best_state('model') self.register_ema('model') def build_dataloaders(self): """Instantiate audio dataloaders for each stage.""" self.dataloaders = builders.get_audio_datasets(self.cfg) def show(self): """Show the compression model and employed adversarial loss.""" self.logger.info(f"Compression model 
with {self.model.quantizer.total_codebooks} codebooks:") self.log_model_summary(self.model) self.logger.info("Adversarial loss:") self.log_model_summary(self.adv_losses) self.logger.info("Auxiliary losses:") self.logger.info(self.aux_losses) self.logger.info("Info losses:") self.logger.info(self.info_losses) def run_step(self, idx: int, batch: torch.Tensor, metrics: dict): """Perform one training or valid step on a given batch.""" x = batch.to(self.device) y = x.clone() qres = self.model(x) assert isinstance(qres, quantization.QuantizedResult) y_pred = qres.x # Log bandwidth in kb/s metrics['bandwidth'] = qres.bandwidth.mean() if self.is_training: d_losses: dict = {} if len(self.adv_losses) > 0 and torch.rand(1, generator=self.rng).item() <= 1 / self.cfg.adversarial.every: for adv_name, adversary in self.adv_losses.items(): disc_loss = adversary.train_adv(y_pred, y) d_losses[f'd_{adv_name}'] = disc_loss metrics['d_loss'] = torch.sum(torch.stack(list(d_losses.values()))) metrics.update(d_losses) balanced_losses: dict = {} other_losses: dict = {} # penalty from quantization if qres.penalty is not None and qres.penalty.requires_grad: other_losses['penalty'] = qres.penalty # penalty term from the quantizer # adversarial losses for adv_name, adversary in self.adv_losses.items(): adv_loss, feat_loss = adversary(y_pred, y) balanced_losses[f'adv_{adv_name}'] = adv_loss balanced_losses[f'feat_{adv_name}'] = feat_loss # auxiliary losses for loss_name, criterion in self.aux_losses.items(): loss = criterion(y_pred, y) balanced_losses[loss_name] = loss # weighted losses metrics.update(balanced_losses) metrics.update(other_losses) metrics.update(qres.metrics) if self.is_training: # backprop losses that are not handled by balancer other_loss = torch.tensor(0., device=self.device) if 'penalty' in other_losses: other_loss += other_losses['penalty'] if other_loss.requires_grad: other_loss.backward(retain_graph=True) ratio1 = sum(p.grad.data.norm(p=2).pow(2) for p in 
def evaluate(self):
    """Evaluate stage. Runs audio reconstruction evaluation.

    Each batch is passed through the model, then the reconstruction metrics are computed
    by `evaluate_audio_reconstruction` in a pool of 'spawn' worker processes.

    Returns:
        dict: Metrics averaged over the batches and across distributed workers.
    """
    self.model.eval()
    evaluate_stage_name = str(self.current_stage)

    loader = self.dataloaders['evaluate']
    updates = len(loader)
    lp = self.log_progress(f'{evaluate_stage_name} inference', loader, total=updates, updates=self.log_updates)
    average = flashy.averager()

    # Defined up front so that an evaluate split yielding no batch does not end up
    # with an unbound `metrics` when averaging across workers below.
    metrics: dict = {}
    pendings = []
    ctx = multiprocessing.get_context('spawn')
    with get_pool_executor(self.cfg.evaluate.num_workers, mp_context=ctx) as pool:
        for batch in lp:
            x = batch.to(self.device)
            with torch.no_grad():
                qres = self.model(x)

            y_pred = qres.x.cpu()
            y = batch.cpu()  # should already be on CPU but just in case
            pendings.append(pool.submit(evaluate_audio_reconstruction, y_pred, y, self.cfg))

        metrics_lp = self.log_progress(f'{evaluate_stage_name} metrics', pendings, updates=self.log_updates)
        for pending in metrics_lp:
            # the averager returns the running average over all results seen so far
            metrics = average(pending.result())

    metrics = flashy.distrib.average_metrics(metrics, updates)
    return metrics
@staticmethod
def model_from_checkpoint(checkpoint_path: tp.Union[Path, str],
                          device: tp.Union[torch.device, str] = 'cpu') -> models.CompressionModel:
    """Load a CompressionModel from a checkpoint path, a dora sig or a pretrained name.

    Convenience entry point for other solvers that need a compression model.

    Args:
        checkpoint_path (Path or str): Path to checkpoint or dora sig from where the checkpoint
            is resolved. A value of the form //pretrained/NAME is loaded through
            `models.CompressionModel.get_pretrained` instead.
        device (torch.device or str): Device on which the model is loaded.
    Returns:
        models.CompressionModel: The loaded model; checkpoint-based models are set to eval
            mode with the 'model' entry of the checkpoint best state loaded.
    """
    ckpt_ref = str(checkpoint_path)
    if ckpt_ref.startswith('//pretrained/'):
        # everything after the '//pretrained/' prefix is the pretrained model name
        pretrained_name = ckpt_ref.split('/', 3)[-1]
        return models.CompressionModel.get_pretrained(pretrained_name, device)

    log = logging.getLogger(__name__)
    log.info(f"Loading compression model from checkpoint: {ckpt_ref}")
    resolved_path = checkpoint.resolve_checkpoint_path(ckpt_ref, use_fsdp=False)
    assert resolved_path is not None, f"Could not resolve compression model checkpoint path: {ckpt_ref}"
    state = checkpoint.load_checkpoint(resolved_path)
    assert state is not None and 'xp.cfg' in state, f"Could not load compression model from ckpt: {ckpt_ref}"

    # rebuild the model from the config stored in the checkpoint, on the requested device
    cfg = state['xp.cfg']
    cfg.device = device
    model = models.builders.get_compression_model(cfg).to(device)
    assert model.sample_rate == cfg.sample_rate, "Compression model sample rate should match"

    assert 'best_state' in state and state['best_state'] != {}
    assert 'exported' not in state, "When loading an exported checkpoint, use the //pretrained/ prefix."
    model.load_state_dict(state['best_state']['model'])
    model.eval()
    log.info("Compression model loaded!")
    return model
""" compression_model = CompressionSolver.model_from_checkpoint(checkpoint_path, device) compression_model = models.builders.get_wrapped_compression_model(compression_model, cfg) return compression_model def evaluate_audio_reconstruction(y_pred: torch.Tensor, y: torch.Tensor, cfg: omegaconf.DictConfig) -> dict: """Audio reconstruction evaluation method that can be conveniently pickled.""" metrics = {} if cfg.evaluate.metrics.visqol: visqol = builders.get_visqol(cfg.metrics.visqol) metrics['visqol'] = visqol(y_pred, y, cfg.sample_rate) sisnr = builders.get_loss('sisnr', cfg) metrics['sisnr'] = sisnr(y_pred, y) return metrics ================================================ FILE: audiocraft/solvers/diffusion.py ================================================ # Copyright (c) Meta Platforms, Inc. and affiliates. # All rights reserved. # # This source code is licensed under the license found in the # LICENSE file in the root directory of this source tree. import typing as tp import flashy import julius import omegaconf import torch import torch.nn.functional as F from . import builders from . import base from .. import models from ..modules.diffusion_schedule import NoiseSchedule from ..metrics import RelativeVolumeMel from ..models.builders import get_processor from ..utils.samples.manager import SampleManager from ..solvers.compression import CompressionSolver class PerStageMetrics: """Handle prompting the metrics per stage. It outputs the metrics per range of diffusion states. e.g. 
avg loss when t in [250, 500] """ def __init__(self, num_steps: int, num_stages: int = 4): self.num_steps = num_steps self.num_stages = num_stages def __call__(self, losses: dict, step: tp.Union[int, torch.Tensor]): if type(step) is int: stage = int((step / self.num_steps) * self.num_stages) return {f"{name}_{stage}": loss for name, loss in losses.items()} elif type(step) is torch.Tensor: stage_tensor = ((step / self.num_steps) * self.num_stages).long() out: tp.Dict[str, float] = {} for stage_idx in range(self.num_stages): mask = (stage_tensor == stage_idx) N = mask.sum() stage_out = {} if N > 0: # pass if no elements in the stage for name, loss in losses.items(): stage_loss = (mask * loss).sum() / N stage_out[f"{name}_{stage_idx}"] = stage_loss out = {**out, **stage_out} return out class DataProcess: """Apply filtering or resampling. Args: initial_sr (int): Initial sample rate. target_sr (int): Target sample rate. use_resampling: Whether to use resampling or not. use_filter (bool): n_bands (int): Number of bands to consider. idx_band (int): device (torch.device or str): cutoffs (): boost (bool): """ def __init__(self, initial_sr: int = 24000, target_sr: int = 16000, use_resampling: bool = False, use_filter: bool = False, n_bands: int = 4, idx_band: int = 0, device: torch.device = torch.device('cpu'), cutoffs=None, boost=False): """Apply filtering or resampling Args: initial_sr (int): sample rate of the dataset target_sr (int): sample rate after resampling use_resampling (bool): whether or not performs resampling use_filter (bool): when True filter the data to keep only one frequency band n_bands (int): Number of bands used cuts (none or list): The cutoff frequencies of the band filtering if None then we use mel scale bands. idx_band (int): index of the frequency band. 0 are lows ... (n_bands - 1) highs boost (bool): make the data scale match our music dataset. 
""" assert idx_band < n_bands self.idx_band = idx_band if use_filter: if cutoffs is not None: self.filter = julius.SplitBands(sample_rate=initial_sr, cutoffs=cutoffs).to(device) else: self.filter = julius.SplitBands(sample_rate=initial_sr, n_bands=n_bands).to(device) self.use_filter = use_filter self.use_resampling = use_resampling self.target_sr = target_sr self.initial_sr = initial_sr self.boost = boost def process_data(self, x, metric=False): if x is None: return None if self.boost: x /= torch.clamp(x.std(dim=(1, 2), keepdim=True), min=1e-4) x * 0.22 if self.use_filter and not metric: x = self.filter(x)[self.idx_band] if self.use_resampling: x = julius.resample_frac(x, old_sr=self.initial_sr, new_sr=self.target_sr) return x def inverse_process(self, x): """Upsampling only.""" if self.use_resampling: x = julius.resample_frac(x, old_sr=self.target_sr, new_sr=self.target_sr) return x class DiffusionSolver(base.StandardSolver): """Solver for compression task. The diffusion task allows for MultiBand diffusion model training. Args: cfg (DictConfig): Configuration. """ def __init__(self, cfg: omegaconf.DictConfig): super().__init__(cfg) self.cfg = cfg self.device = cfg.device self.sample_rate: int = self.cfg.sample_rate self.codec_model = CompressionSolver.model_from_checkpoint( cfg.compression_model_checkpoint, device=self.device) self.codec_model.set_num_codebooks(cfg.n_q) assert self.codec_model.sample_rate == self.cfg.sample_rate, ( f"Codec model sample rate is {self.codec_model.sample_rate} but " f"Solver sample rate is {self.cfg.sample_rate}." ) assert self.codec_model.sample_rate == self.sample_rate, \ f"Sample rate of solver {self.sample_rate} and codec {self.codec_model.sample_rate} " \ "don't match." 
self.sample_processor = get_processor(cfg.processor, sample_rate=self.sample_rate) self.register_stateful('sample_processor') self.sample_processor.to(self.device) self.schedule = NoiseSchedule( **cfg.schedule, device=self.device, sample_processor=self.sample_processor) self.eval_metric: tp.Optional[torch.nn.Module] = None self.rvm = RelativeVolumeMel() self.data_processor = DataProcess(initial_sr=self.sample_rate, target_sr=cfg.resampling.target_sr, use_resampling=cfg.resampling.use, cutoffs=cfg.filter.cutoffs, use_filter=cfg.filter.use, n_bands=cfg.filter.n_bands, idx_band=cfg.filter.idx_band, device=self.device) @property def best_metric_name(self) -> tp.Optional[str]: if self._current_stage == "evaluate": return 'rvm' else: return 'loss' @torch.no_grad() def get_condition(self, wav: torch.Tensor) -> torch.Tensor: codes, scale = self.codec_model.encode(wav) assert scale is None, "Scaled compression models not supported." emb = self.codec_model.decode_latent(codes) return emb def build_model(self): """Build model and optimizer as well as optional Exponential Moving Average of the model. 
""" # Model and optimizer self.model = models.builders.get_diffusion_model(self.cfg).to(self.device) self.optimizer = builders.get_optimizer(self.model.parameters(), self.cfg.optim) self.register_stateful('model', 'optimizer') self.register_best_state('model') self.register_ema('model') def build_dataloaders(self): """Build audio dataloaders for each stage.""" self.dataloaders = builders.get_audio_datasets(self.cfg) def show(self): # TODO raise NotImplementedError() def run_step(self, idx: int, batch: torch.Tensor, metrics: dict): """Perform one training or valid step on a given batch.""" x = batch.to(self.device) loss_fun = F.mse_loss if self.cfg.loss.kind == 'mse' else F.l1_loss condition = self.get_condition(x) # [bs, 128, T/hop, n_emb] sample = self.data_processor.process_data(x) input_, target, step = self.schedule.get_training_item(sample, tensor_step=self.cfg.schedule.variable_step_batch) out = self.model(input_, step, condition=condition).sample base_loss = loss_fun(out, target, reduction='none').mean(dim=(1, 2)) reference_loss = loss_fun(input_, target, reduction='none').mean(dim=(1, 2)) loss = base_loss / reference_loss ** self.cfg.loss.norm_power if self.is_training: loss.mean().backward() flashy.distrib.sync_model(self.model) self.optimizer.step() self.optimizer.zero_grad() metrics = { 'loss': loss.mean(), 'normed_loss': (base_loss / reference_loss).mean(), } metrics.update(self.per_stage({'loss': loss, 'normed_loss': base_loss / reference_loss}, step)) metrics.update({ 'std_in': input_.std(), 'std_out': out.std()}) return metrics def run_epoch(self): # reset random seed at the beginning of the epoch self.rng = torch.Generator() self.rng.manual_seed(1234 + self.epoch) self.per_stage = PerStageMetrics(self.schedule.num_steps, self.cfg.metrics.num_stage) # run epoch super().run_epoch() def evaluate(self): """Evaluate stage. Runs audio reconstruction evaluation. 
""" self.model.eval() evaluate_stage_name = f'{self.current_stage}' loader = self.dataloaders['evaluate'] updates = len(loader) lp = self.log_progress(f'{evaluate_stage_name} estimate', loader, total=updates, updates=self.log_updates) metrics = {} n = 1 for idx, batch in enumerate(lp): x = batch.to(self.device) with torch.no_grad(): y_pred = self.regenerate(x) y_pred = y_pred.cpu() y = batch.cpu() # should already be on CPU but just in case rvm = self.rvm(y_pred, y) lp.update(**rvm) if len(metrics) == 0: metrics = rvm else: for key in rvm.keys(): metrics[key] = (metrics[key] * n + rvm[key]) / (n + 1) metrics = flashy.distrib.average_metrics(metrics) return metrics @torch.no_grad() def regenerate(self, wav: torch.Tensor, step_list: tp.Optional[list] = None): """Regenerate the given waveform.""" condition = self.get_condition(wav) initial = self.schedule.get_initial_noise(self.data_processor.process_data(wav)) # sampling rate changes. result = self.schedule.generate_subsampled(self.model, initial=initial, condition=condition, step_list=step_list) result = self.data_processor.inverse_process(result) return result def generate(self): """Generate stage.""" sample_manager = SampleManager(self.xp) self.model.eval() generate_stage_name = f'{self.current_stage}' loader = self.dataloaders['generate'] updates = len(loader) lp = self.log_progress(generate_stage_name, loader, total=updates, updates=self.log_updates) for batch in lp: reference, _ = batch reference = reference.to(self.device) estimate = self.regenerate(reference) reference = reference.cpu() estimate = estimate.cpu() sample_manager.add_samples(estimate, self.epoch, ground_truth_wavs=reference) flashy.distrib.barrier() ================================================ FILE: audiocraft/solvers/musicgen.py ================================================ # Copyright (c) Meta Platforms, Inc. and affiliates. # All rights reserved. 
# # This source code is licensed under the license found in the # LICENSE file in the root directory of this source tree. from pathlib import Path import time import typing as tp import flashy import math import omegaconf import torch from torch.nn import functional as F from . import base, builders from .compression import CompressionSolver from .. import metrics as eval_metrics from .. import models from ..data.audio_dataset import AudioDataset from ..data.music_dataset import MusicDataset, MusicInfo, AudioInfo from ..data.audio_utils import normalize_audio from ..modules.conditioners import JointEmbedCondition, SegmentWithAttributes, WavCondition from ..utils.cache import CachedBatchWriter, CachedBatchLoader from ..utils.samples.manager import SampleManager from ..utils.utils import get_dataset_from_loader, is_jsonable, warn_once class MusicGenSolver(base.StandardSolver): """Solver for MusicGen training task. Used in: https://arxiv.org/abs/2306.05284 """ DATASET_TYPE: builders.DatasetType = builders.DatasetType.MUSIC def __init__(self, cfg: omegaconf.DictConfig): super().__init__(cfg) # easier access to sampling parameters self.generation_params = { 'use_sampling': self.cfg.generate.lm.use_sampling, 'temp': self.cfg.generate.lm.temp, 'top_k': self.cfg.generate.lm.top_k, 'top_p': self.cfg.generate.lm.top_p, } self._best_metric_name: tp.Optional[str] = 'ce' self._cached_batch_writer = None self._cached_batch_loader = None if cfg.cache.path: if cfg.cache.write: self._cached_batch_writer = CachedBatchWriter(Path(cfg.cache.path)) if self.cfg.cache.write_num_shards: self.logger.warning("Multiple shard cache, best_metric_name will be set to None.") self._best_metric_name = None else: self._cached_batch_loader = CachedBatchLoader( Path(cfg.cache.path), cfg.dataset.batch_size, cfg.dataset.num_workers, min_length=self.cfg.optim.updates_per_epoch or 1) self.dataloaders['original_train'] = self.dataloaders['train'] self.dataloaders['train'] = self._cached_batch_loader # 
type: ignore @staticmethod def get_eval_solver_from_sig(sig: str, dtype: tp.Optional[str] = None, device: tp.Optional[str] = None, autocast: bool = True, batch_size: tp.Optional[int] = None, override_cfg: tp.Optional[tp.Union[dict, omegaconf.DictConfig]] = None, **kwargs): """Mostly a convenience function around magma.train.get_solver_from_sig, populating all the proper param, deactivating EMA, FSDP, loading the best state, basically all you need to get a solver ready to "play" with in single GPU mode and with minimal memory overhead. Args: sig (str): signature to load. dtype (str or None): potential dtype, as a string, i.e. 'float16'. device (str or None): potential device, as a string, i.e. 'cuda'. override_cfg (dict or omegaconf.DictConfig or None): potential device, as a string, i.e. 'cuda'. """ from audiocraft import train our_override_cfg: tp.Dict[str, tp.Any] = {'optim': {'ema': {'use': False}}} our_override_cfg['autocast'] = autocast if dtype is not None: our_override_cfg['dtype'] = dtype if device is not None: our_override_cfg['device'] = device if batch_size is not None: our_override_cfg['dataset'] = {'batch_size': batch_size} if override_cfg is None: override_cfg = {} override_cfg = omegaconf.OmegaConf.merge( omegaconf.DictConfig(override_cfg), omegaconf.DictConfig(our_override_cfg)) # type: ignore solver = train.get_solver_from_sig( sig, override_cfg=override_cfg, load_best=True, disable_fsdp=True, ignore_state_keys=['optimizer', 'ema'], **kwargs) solver.model.eval() return solver def get_formatter(self, stage_name: str) -> flashy.Formatter: return flashy.Formatter({ 'lr': '.2E', 'ce': '.3f', 'ppl': '.3f', 'grad_norm': '.3E', }, exclude_keys=['ce_q*', 'ppl_q*']) @property def best_metric_name(self) -> tp.Optional[str]: return self._best_metric_name def build_model(self) -> None: """Instantiate models and optimizer.""" # we can potentially not use all quantizers with which the EnCodec model was trained # (e.g. 
we trained the model with quantizers dropout) self.compression_model = CompressionSolver.wrapped_model_from_checkpoint( self.cfg, self.cfg.compression_model_checkpoint, device=self.device) assert self.compression_model.sample_rate == self.cfg.sample_rate, ( f"Compression model sample rate is {self.compression_model.sample_rate} but " f"Solver sample rate is {self.cfg.sample_rate}." ) # ensure we have matching configuration between LM and compression model assert self.cfg.transformer_lm.card == self.compression_model.cardinality, ( "Cardinalities of the LM and compression model don't match: ", f"LM cardinality is {self.cfg.transformer_lm.card} vs ", f"compression model cardinality is {self.compression_model.cardinality}" ) assert self.cfg.transformer_lm.n_q == self.compression_model.num_codebooks, ( "Numbers of codebooks of the LM and compression models don't match: ", f"LM number of codebooks is {self.cfg.transformer_lm.n_q} vs ", f"compression model numer of codebooks is {self.compression_model.num_codebooks}" ) self.logger.info("Compression model has %d codebooks with %d cardinality, and a framerate of %d", self.compression_model.num_codebooks, self.compression_model.cardinality, self.compression_model.frame_rate) # instantiate LM model self.model: models.LMModel = models.builders.get_lm_model(self.cfg).to(self.device) if self.cfg.fsdp.use: assert not self.cfg.autocast, "Cannot use autocast with fsdp" self.model = self.wrap_with_fsdp(self.model) self.register_ema('model') # initialize optimization self.optimizer = builders.get_optimizer(builders.get_optim_parameter_groups(self.model), self.cfg.optim) self.lr_scheduler = builders.get_lr_scheduler(self.optimizer, self.cfg.schedule, self.total_updates) self.register_stateful('compression_model', 'model', 'optimizer', 'lr_scheduler') self.register_best_state('model') self.autocast_dtype = { 'float16': torch.float16, 'bfloat16': torch.bfloat16 }[self.cfg.autocast_dtype] self.scaler: 
tp.Optional[torch.cuda.amp.GradScaler] = None if self.cfg.fsdp.use: need_scaler = self.cfg.fsdp.param_dtype == 'float16' else: need_scaler = self.cfg.autocast and self.autocast_dtype is torch.float16 if need_scaler: if self.cfg.fsdp.use: from torch.distributed.fsdp.sharded_grad_scaler import ShardedGradScaler self.scaler = ShardedGradScaler() # type: ignore else: self.scaler = torch.cuda.amp.GradScaler() self.register_stateful('scaler') def build_dataloaders(self) -> None: """Instantiate audio dataloaders for each stage.""" self.dataloaders = builders.get_audio_datasets(self.cfg, dataset_type=self.DATASET_TYPE) def show(self) -> None: """Show the compression model and LM model.""" self.logger.info("Compression model:") self.log_model_summary(self.compression_model) self.logger.info("LM model:") self.log_model_summary(self.model) def load_state_dict(self, state: dict) -> None: if 'condition_provider' in state: model_state = state['model'] condition_provider_state = state.pop('condition_provider') prefix = 'condition_provider.' for key, value in condition_provider_state.items(): key = prefix + key assert key not in model_state model_state[key] = value super().load_state_dict(state) def load_from_pretrained(self, name: str): # TODO: support native HF versions of MusicGen. lm_pkg = models.loaders.load_lm_model_ckpt(name) state: dict = { 'best_state': { 'model': lm_pkg['best_state'], }, } return state def _compute_cross_entropy( self, logits: torch.Tensor, targets: torch.Tensor, mask: torch.Tensor ) -> tp.Tuple[torch.Tensor, tp.List[torch.Tensor]]: """Compute cross entropy between multi-codebook targets and model's logits. The cross entropy is computed per codebook to provide codebook-level cross entropy. Valid timesteps for each of the codebook are pulled from the mask, where invalid timesteps are set to 0. Args: logits (torch.Tensor): Model's logits of shape [B, K, T, card]. targets (torch.Tensor): Target codes, of shape [B, K, T]. 
mask (torch.Tensor): Mask for valid target codes, of shape [B, K, T]. Returns: ce (torch.Tensor): Cross entropy averaged over the codebooks ce_per_codebook (list of torch.Tensor): Cross entropy per codebook (detached). """ B, K, T = targets.shape assert logits.shape[:-1] == targets.shape assert mask.shape == targets.shape ce = torch.zeros([], device=targets.device) ce_per_codebook: tp.List[torch.Tensor] = [] for k in range(K): logits_k = logits[:, k, ...].contiguous().view(-1, logits.size(-1)) # [B x T, card] targets_k = targets[:, k, ...].contiguous().view(-1) # [B x T] mask_k = mask[:, k, ...].contiguous().view(-1) # [B x T] ce_targets = targets_k[mask_k] ce_logits = logits_k[mask_k] q_ce = F.cross_entropy(ce_logits, ce_targets) ce += q_ce ce_per_codebook.append(q_ce.detach()) # average cross entropy across codebooks ce = ce / K return ce, ce_per_codebook @torch.no_grad() def _prepare_tokens_and_attributes( self, batch: tp.Tuple[torch.Tensor, tp.List[SegmentWithAttributes]], check_synchronization_points: bool = False ) -> tp.Tuple[dict, torch.Tensor, torch.Tensor]: """Prepare input batchs for language model training. Args: batch (tuple[torch.Tensor, list[SegmentWithAttributes]]): Input batch with audio tensor of shape [B, C, T] and corresponding metadata as SegmentWithAttributes (with B items). check_synchronization_points (bool): Whether to check for synchronization points slowing down training. Returns: Condition tensors (dict[str, any]): Preprocessed condition attributes. Tokens (torch.Tensor): Audio tokens from compression model, of shape [B, K, T_s], with B the batch size, K the number of codebooks, T_s the token timesteps. Padding mask (torch.Tensor): Mask with valid positions in the tokens tensor, of shape [B, K, T_s]. 
""" if self._cached_batch_loader is None or self.current_stage != "train": audio, infos = batch audio = audio.to(self.device) audio_tokens = None assert audio.size(0) == len(infos), ( f"Mismatch between number of items in audio batch ({audio.size(0)})", f" and in metadata ({len(infos)})" ) else: audio = None # In that case the batch will be a tuple coming from the _cached_batch_writer bit below. infos, = batch # type: ignore assert all([isinstance(info, AudioInfo) for info in infos]) assert all([info.audio_tokens is not None for info in infos]) # type: ignore audio_tokens = torch.stack([info.audio_tokens for info in infos]).to(self.device) # type: ignore audio_tokens = audio_tokens.long() for info in infos: if isinstance(info, MusicInfo): # Careful here, if you want to use this condition_wav (e.b. chroma conditioning), # then you must be using the chroma cache! otherwise the code will try # to use this segment and fail (by that I mean you will see NaN everywhere). info.self_wav = WavCondition( torch.full([1, info.channels, info.total_frames], float('NaN')), length=torch.tensor([info.n_frames]), sample_rate=[info.sample_rate], path=[info.meta.path], seek_time=[info.seek_time]) dataset = get_dataset_from_loader(self.dataloaders['original_train']) assert isinstance(dataset, MusicDataset), type(dataset) if dataset.paraphraser is not None and info.description is not None: # Hackingly reapplying paraphraser when using cache. info.description = dataset.paraphraser.sample_paraphrase( info.meta.path, info.description) # prepare attributes attributes = [info.to_condition_attributes() for info in infos] attributes = self.model.cfg_dropout(attributes) attributes = self.model.att_dropout(attributes) tokenized = self.model.condition_provider.tokenize(attributes) # Now we should be synchronization free. 
if self.device == "cuda" and check_synchronization_points: torch.cuda.set_sync_debug_mode("warn") if audio_tokens is None: with torch.no_grad(): audio_tokens, scale = self.compression_model.encode(audio) assert scale is None, "Scaled compression model not supported with LM." with self.autocast: condition_tensors = self.model.condition_provider(tokenized) # create a padding mask to hold valid vs invalid positions padding_mask = torch.ones_like(audio_tokens, dtype=torch.bool, device=audio_tokens.device) # replace encodec tokens from padded audio with special_token_id if self.cfg.tokens.padding_with_special_token: audio_tokens = audio_tokens.clone() padding_mask = padding_mask.clone() token_sample_rate = self.compression_model.frame_rate B, K, T_s = audio_tokens.shape for i in range(B): n_samples = infos[i].n_frames audio_sample_rate = infos[i].sample_rate # take the last token generated from actual audio frames (non-padded audio) valid_tokens = math.floor(float(n_samples) / audio_sample_rate * token_sample_rate) audio_tokens[i, :, valid_tokens:] = self.model.special_token_id padding_mask[i, :, valid_tokens:] = 0 if self.device == "cuda" and check_synchronization_points: torch.cuda.set_sync_debug_mode("default") if self._cached_batch_writer is not None and self.current_stage == 'train': assert self._cached_batch_loader is None assert audio_tokens is not None for info, one_audio_tokens in zip(infos, audio_tokens): assert isinstance(info, AudioInfo) if isinstance(info, MusicInfo): assert not info.joint_embed, "joint_embed and cache not supported yet." 
info.self_wav = None assert one_audio_tokens.max() < 2**15, one_audio_tokens.max().item() info.audio_tokens = one_audio_tokens.short().cpu() self._cached_batch_writer.save(infos) return condition_tensors, audio_tokens, padding_mask def run_step(self, idx: int, batch: tp.Tuple[torch.Tensor, tp.List[SegmentWithAttributes]], metrics: dict) -> dict: """Perform one training or valid step on a given batch.""" check_synchronization_points = idx == 1 and self.device == 'cuda' condition_tensors, audio_tokens, padding_mask = self._prepare_tokens_and_attributes( batch, check_synchronization_points) self.deadlock_detect.update('tokens_and_conditions') if check_synchronization_points: torch.cuda.set_sync_debug_mode('warn') with self.autocast: model_output = self.model.compute_predictions(audio_tokens, [], condition_tensors) # type: ignore logits = model_output.logits mask = padding_mask & model_output.mask ce, ce_per_codebook = self._compute_cross_entropy(logits, audio_tokens, mask) loss = ce self.deadlock_detect.update('loss') if check_synchronization_points: torch.cuda.set_sync_debug_mode('default') if self.is_training: metrics['lr'] = self.optimizer.param_groups[0]['lr'] if self.scaler is not None: loss = self.scaler.scale(loss) self.deadlock_detect.update('scale') if self.cfg.fsdp.use: loss.backward() flashy.distrib.average_tensors(self.model.buffers()) elif self.cfg.optim.eager_sync: with flashy.distrib.eager_sync_model(self.model): loss.backward() else: # this should always be slower but can be useful # for weird use cases like multiple backwards. 
loss.backward() flashy.distrib.sync_model(self.model) self.deadlock_detect.update('backward') if self.scaler is not None: self.scaler.unscale_(self.optimizer) if self.cfg.optim.max_norm: if self.cfg.fsdp.use: metrics['grad_norm'] = self.model.clip_grad_norm_(self.cfg.optim.max_norm) # type: ignore else: metrics['grad_norm'] = torch.nn.utils.clip_grad_norm_( self.model.parameters(), self.cfg.optim.max_norm ) if self.scaler is None: self.optimizer.step() else: self.scaler.step(self.optimizer) self.scaler.update() if self.lr_scheduler: self.lr_scheduler.step() self.optimizer.zero_grad() self.deadlock_detect.update('optim') if self.scaler is not None: scale = self.scaler.get_scale() metrics['grad_scale'] = scale if not loss.isfinite().all(): raise RuntimeError("Model probably diverged.") metrics['ce'] = ce metrics['ppl'] = torch.exp(ce) for k, ce_q in enumerate(ce_per_codebook): metrics[f'ce_q{k + 1}'] = ce_q metrics[f'ppl_q{k + 1}'] = torch.exp(ce_q) return metrics @torch.no_grad() def run_generate_step(self, batch: tp.Tuple[torch.Tensor, tp.List[SegmentWithAttributes]], gen_duration: float, prompt_duration: tp.Optional[float] = None, remove_prompt: bool = False, **generation_params) -> dict: """Run generate step on a batch of optional audio tensor and corresponding attributes. Args: batch (tuple[torch.Tensor, list[SegmentWithAttributes]]): use_prompt (bool): Whether to do audio continuation generation with prompt from audio batch. gen_duration (float): Target audio duration for the generation. prompt_duration (float, optional): Duration for the audio prompt to use for continuation. remove_prompt (bool, optional): Whether to remove the prompt from the generated audio. generation_params: Additional generation parameters. Returns: gen_outputs (dict): Generation outputs, consisting in audio, audio tokens from both the generation and the prompt along with additional information. 
""" bench_start = time.time() audio, meta = batch assert audio.size(0) == len(meta), ( f"Mismatch between number of items in audio batch ({audio.size(0)})", f" and in metadata ({len(meta)})" ) # prepare attributes attributes = [x.to_condition_attributes() for x in meta] # TODO: Add dropout for chroma? # prepare audio prompt if prompt_duration is None: prompt_audio = None else: assert prompt_duration < gen_duration, "Prompt duration must be lower than target generation duration" prompt_audio_frames = int(prompt_duration * self.compression_model.sample_rate) prompt_audio = audio[..., :prompt_audio_frames] # get audio tokens from compression model if prompt_audio is None or prompt_audio.nelement() == 0: num_samples = len(attributes) prompt_tokens = None else: num_samples = None prompt_audio = prompt_audio.to(self.device) prompt_tokens, scale = self.compression_model.encode(prompt_audio) assert scale is None, "Compression model in MusicGen should not require rescaling." # generate by sampling from the LM with self.autocast: total_gen_len = math.ceil(gen_duration * self.compression_model.frame_rate) gen_tokens = self.model.generate( prompt_tokens, attributes, max_gen_len=total_gen_len, num_samples=num_samples, **self.generation_params) # generate audio from tokens assert gen_tokens.dim() == 3 gen_audio = self.compression_model.decode(gen_tokens, None) bench_end = time.time() gen_outputs = { 'rtf': (bench_end - bench_start) / gen_duration, 'ref_audio': audio, 'gen_audio': gen_audio, 'gen_tokens': gen_tokens, 'prompt_audio': prompt_audio, 'prompt_tokens': prompt_tokens, } return gen_outputs def generate_audio(self) -> dict: """Audio generation stage.""" generate_stage_name = f'{self.current_stage}' sample_manager = SampleManager(self.xp) self.logger.info(f"Generating samples in {sample_manager.base_folder}") loader = self.dataloaders['generate'] updates = len(loader) lp = self.log_progress(generate_stage_name, loader, total=updates, updates=self.log_updates) dataset = 
get_dataset_from_loader(loader) dataset_duration = dataset.segment_duration assert dataset_duration is not None assert isinstance(dataset, AudioDataset) target_duration = self.cfg.generate.lm.gen_duration prompt_duration = self.cfg.generate.lm.prompt_duration if target_duration is None: target_duration = dataset_duration if prompt_duration is None: prompt_duration = dataset_duration / 4 assert prompt_duration < dataset_duration, ( f"Specified prompt duration ({prompt_duration}s) is longer", f" than reference audio duration ({dataset_duration}s)" ) def get_hydrated_conditions(meta: tp.List[SegmentWithAttributes]): hydrated_conditions = [] for sample in [x.to_condition_attributes() for x in meta]: cond_dict = {} for cond_type in sample.__annotations__.keys(): for cond_key, cond_val in getattr(sample, cond_type).items(): if cond_key not in self.model.condition_provider.conditioners.keys(): continue if is_jsonable(cond_val): cond_dict[cond_key] = cond_val elif isinstance(cond_val, WavCondition): cond_dict[cond_key] = cond_val.path elif isinstance(cond_val, JointEmbedCondition): cond_dict[cond_key] = cond_val.text # only support text at inference for now else: # if we reached this point, it is not clear how to log the condition # so we just log the type. cond_dict[cond_key] = str(type(cond_val)) continue hydrated_conditions.append(cond_dict) return hydrated_conditions metrics: dict = {} average = flashy.averager() for batch in lp: audio, meta = batch # metadata for sample manager hydrated_conditions = get_hydrated_conditions(meta) sample_generation_params = { **{f'classifier_free_guidance_{k}': v for k, v in self.cfg.classifier_free_guidance.items()}, **self.generation_params } if self.cfg.generate.lm.unprompted_samples: if self.cfg.generate.lm.gen_gt_samples: # get the ground truth instead of generation self.logger.warn( "Use ground truth instead of audio generation as generate.lm.gen_gt_samples=true") gen_unprompted_audio = audio rtf = 1. 
else: gen_unprompted_outputs = self.run_generate_step( batch, gen_duration=target_duration, prompt_duration=prompt_duration, **self.generation_params) gen_unprompted_audio = gen_unprompted_outputs['gen_audio'].cpu() rtf = gen_unprompted_outputs['rtf'] sample_manager.add_samples( gen_unprompted_audio, self.epoch, hydrated_conditions, ground_truth_wavs=audio, generation_args=sample_generation_params) if self.cfg.generate.lm.prompted_samples: gen_outputs = self.run_generate_step( batch, gen_duration=target_duration, prompt_duration=prompt_duration, **self.generation_params) gen_audio = gen_outputs['gen_audio'].cpu() prompt_audio = gen_outputs['prompt_audio'].cpu() sample_manager.add_samples( gen_audio, self.epoch, hydrated_conditions, prompt_wavs=prompt_audio, ground_truth_wavs=audio, generation_args=sample_generation_params) metrics['rtf'] = rtf metrics = average(metrics) flashy.distrib.barrier() return metrics def generate(self) -> dict: """Generate stage.""" self.model.eval() with torch.no_grad(): return self.generate_audio() def run_epoch(self): if self.cfg.cache.write: if ((self.epoch - 1) % self.cfg.cache.write_num_shards) != self.cfg.cache.write_shard: return super().run_epoch() def train(self): """Train stage. 
""" if self._cached_batch_writer is not None: self._cached_batch_writer.start_epoch(self.epoch) if self._cached_batch_loader is None: dataset = get_dataset_from_loader(self.dataloaders['train']) assert isinstance(dataset, AudioDataset) dataset.current_epoch = self.epoch else: self._cached_batch_loader.start_epoch(self.epoch) return super().train() def evaluate_audio_generation(self) -> dict: """Evaluate audio generation with off-the-shelf metrics.""" evaluate_stage_name = f'{self.current_stage}_generation' # instantiate evaluation metrics, if at least one metric is defined, run audio generation evaluation fad: tp.Optional[eval_metrics.FrechetAudioDistanceMetric] = None kldiv: tp.Optional[eval_metrics.KLDivergenceMetric] = None text_consistency: tp.Optional[eval_metrics.TextConsistencyMetric] = None chroma_cosine: tp.Optional[eval_metrics.ChromaCosineSimilarityMetric] = None should_run_eval = False eval_chroma_wavs: tp.Optional[torch.Tensor] = None if self.cfg.evaluate.metrics.fad: fad = builders.get_fad(self.cfg.metrics.fad).to(self.device) should_run_eval = True if self.cfg.evaluate.metrics.kld: kldiv = builders.get_kldiv(self.cfg.metrics.kld).to(self.device) should_run_eval = True if self.cfg.evaluate.metrics.text_consistency: text_consistency = builders.get_text_consistency(self.cfg.metrics.text_consistency).to(self.device) should_run_eval = True if self.cfg.evaluate.metrics.chroma_cosine: chroma_cosine = builders.get_chroma_cosine_similarity(self.cfg.metrics.chroma_cosine).to(self.device) # if we have predefind wavs for chroma we should purge them for computing the cosine metric has_predefined_eval_chromas = 'self_wav' in self.model.condition_provider.conditioners and \ self.model.condition_provider.conditioners['self_wav'].has_eval_wavs() if has_predefined_eval_chromas: warn_once(self.logger, "Attempting to run cosine eval for config with pre-defined eval chromas! 
" 'Resetting eval chromas to None for evaluation.') eval_chroma_wavs = self.model.condition_provider.conditioners.self_wav.eval_wavs # type: ignore self.model.condition_provider.conditioners.self_wav.reset_eval_wavs(None) # type: ignore should_run_eval = True def get_compressed_audio(audio: torch.Tensor) -> torch.Tensor: audio_tokens, scale = self.compression_model.encode(audio.to(self.device)) compressed_audio = self.compression_model.decode(audio_tokens, scale) return compressed_audio[..., :audio.shape[-1]] metrics: dict = {} if should_run_eval: loader = self.dataloaders['evaluate'] updates = len(loader) lp = self.log_progress(f'{evaluate_stage_name} inference', loader, total=updates, updates=self.log_updates) average = flashy.averager() dataset = get_dataset_from_loader(loader) assert isinstance(dataset, AudioDataset) self.logger.info(f"Computing evaluation metrics on {len(dataset)} samples") for idx, batch in enumerate(lp): audio, meta = batch assert all([self.cfg.sample_rate == m.sample_rate for m in meta]) target_duration = audio.shape[-1] / self.cfg.sample_rate if self.cfg.evaluate.fixed_generation_duration: target_duration = self.cfg.evaluate.fixed_generation_duration gen_outputs = self.run_generate_step( batch, gen_duration=target_duration, **self.generation_params ) y_pred = gen_outputs['gen_audio'].detach() y_pred = y_pred[..., :audio.shape[-1]] normalize_kwargs = dict(self.cfg.generate.audio) normalize_kwargs.pop('format', None) y_pred = torch.stack([normalize_audio(w, **normalize_kwargs) for w in y_pred], dim=0).cpu() y = audio.cpu() # should already be on CPU but just in case sizes = torch.tensor([m.n_frames for m in meta]) # actual sizes without padding sample_rates = torch.tensor([m.sample_rate for m in meta]) # sample rates for audio samples audio_stems = [Path(m.meta.path).stem + f"_{m.seek_time}" for m in meta] if fad is not None: if self.cfg.metrics.fad.use_gt: y_pred = get_compressed_audio(y).cpu() fad.update(y_pred, y, sizes, sample_rates, 
audio_stems) if kldiv is not None: if self.cfg.metrics.kld.use_gt: y_pred = get_compressed_audio(y).cpu() kldiv.update(y_pred, y, sizes, sample_rates) if text_consistency is not None: texts = [m.description for m in meta] if self.cfg.metrics.text_consistency.use_gt: y_pred = y text_consistency.update(y_pred, texts, sizes, sample_rates) if chroma_cosine is not None: if self.cfg.metrics.chroma_cosine.use_gt: y_pred = get_compressed_audio(y).cpu() chroma_cosine.update(y_pred, y, sizes, sample_rates) # restore chroma conditioner's eval chroma wavs if eval_chroma_wavs is not None: self.model.condition_provider.conditioners['self_wav'].reset_eval_wavs(eval_chroma_wavs) flashy.distrib.barrier() if fad is not None: metrics['fad'] = fad.compute() if kldiv is not None: kld_metrics = kldiv.compute() metrics.update(kld_metrics) if text_consistency is not None: metrics['text_consistency'] = text_consistency.compute() if chroma_cosine is not None: metrics['chroma_cosine'] = chroma_cosine.compute() metrics = average(metrics) metrics = flashy.distrib.average_metrics(metrics, len(loader)) return metrics def evaluate(self) -> dict: """Evaluate stage.""" self.model.eval() with torch.no_grad(): metrics: dict = {} if self.cfg.evaluate.metrics.base: metrics.update(self.common_train_valid('evaluate')) gen_metrics = self.evaluate_audio_generation() return {**metrics, **gen_metrics} ================================================ FILE: audiocraft/train.py ================================================ # Copyright (c) Meta Platforms, Inc. and affiliates. # All rights reserved. # # This source code is licensed under the license found in the # LICENSE file in the root directory of this source tree. """ Entry point for dora to launch solvers for running training loops. 
See more info on how to use dora: https://github.com/facebookresearch/dora """ import logging import multiprocessing import os import sys import typing as tp from dora import git_save, hydra_main, XP import flashy import hydra import omegaconf from .environment import AudioCraftEnvironment from .utils.cluster import get_slurm_parameters logger = logging.getLogger(__name__) def resolve_config_dset_paths(cfg): """Enable Dora to load manifest from git clone repository.""" # manifest files for the different splits for key, value in cfg.datasource.items(): if isinstance(value, str): cfg.datasource[key] = git_save.to_absolute_path(value) def get_solver(cfg): from . import solvers # Convert batch size to batch size for each GPU assert cfg.dataset.batch_size % flashy.distrib.world_size() == 0 cfg.dataset.batch_size //= flashy.distrib.world_size() for split in ['train', 'valid', 'evaluate', 'generate']: if hasattr(cfg.dataset, split) and hasattr(cfg.dataset[split], 'batch_size'): assert cfg.dataset[split].batch_size % flashy.distrib.world_size() == 0 cfg.dataset[split].batch_size //= flashy.distrib.world_size() resolve_config_dset_paths(cfg) solver = solvers.get_solver(cfg) return solver def get_solver_from_xp(xp: XP, override_cfg: tp.Optional[tp.Union[dict, omegaconf.DictConfig]] = None, restore: bool = True, load_best: bool = True, ignore_state_keys: tp.List[str] = [], disable_fsdp: bool = True): """Given a XP, return the Solver object. Args: xp (XP): Dora experiment for which to retrieve the solver. override_cfg (dict or None): If not None, should be a dict used to override some values in the config of `xp`. This will not impact the XP signature or folder. The format is different than the one used in Dora grids, nested keys should actually be nested dicts, not flattened, e.g. `{'optim': {'batch_size': 32}}`. restore (bool): If `True` (the default), restore state from the last checkpoint. load_best (bool): If `True` (the default), load the best state from the checkpoint. 
ignore_state_keys (list[str]): List of sources to ignore when loading the state, e.g. `optimizer`. disable_fsdp (bool): if True, disables FSDP entirely. This will also automatically skip loading the EMA. For solver specific state sources, like the optimizer, you might want to use along `ignore_state_keys=['optimizer']`. Must be used with `load_best=True`. """ logger.info(f"Loading solver from XP {xp.sig}. " f"Overrides used: {xp.argv}") cfg = xp.cfg if override_cfg is not None: cfg = omegaconf.OmegaConf.merge(cfg, omegaconf.DictConfig(override_cfg)) if disable_fsdp and cfg.fsdp.use: cfg.fsdp.use = False assert load_best is True # ignoring some keys that were FSDP sharded like model, ema, and best_state. # fsdp_best_state will be used in that case. When using a specific solver, # one is responsible for adding the relevant keys, e.g. 'optimizer'. # We could make something to automatically register those inside the solver, but that # seem overkill at this point. ignore_state_keys = ignore_state_keys + ['model', 'ema', 'best_state'] try: with xp.enter(): solver = get_solver(cfg) if restore: solver.restore(load_best=load_best, ignore_state_keys=ignore_state_keys) return solver finally: hydra.core.global_hydra.GlobalHydra.instance().clear() def get_solver_from_sig(sig: str, *args, **kwargs): """Return Solver object from Dora signature, i.e. to play with it from a notebook. See `get_solver_from_xp` for more information. 
""" xp = main.get_xp_from_sig(sig) return get_solver_from_xp(xp, *args, **kwargs) def init_seed_and_system(cfg): import numpy as np import torch import random from audiocraft.modules.transformer import set_efficient_attention_backend multiprocessing.set_start_method(cfg.mp_start_method) logger.debug('Setting mp start method to %s', cfg.mp_start_method) random.seed(cfg.seed) np.random.seed(cfg.seed) # torch also initialize cuda seed if available torch.manual_seed(cfg.seed) torch.set_num_threads(cfg.num_threads) os.environ['MKL_NUM_THREADS'] = str(cfg.num_threads) os.environ['OMP_NUM_THREADS'] = str(cfg.num_threads) logger.debug('Setting num threads to %d', cfg.num_threads) set_efficient_attention_backend(cfg.efficient_attention_backend) logger.debug('Setting efficient attention backend to %s', cfg.efficient_attention_backend) @hydra_main(config_path='../config', config_name='config', version_base='1.1') def main(cfg): init_seed_and_system(cfg) # Setup logging both to XP specific folder, and to stderr. log_name = '%s.log.{rank}' % cfg.execute_only if cfg.execute_only else 'solver.log.{rank}' flashy.setup_logging(level=str(cfg.logging.level).upper(), log_name=log_name) # Initialize distributed training, no need to specify anything when using Dora. flashy.distrib.init() solver = get_solver(cfg) if cfg.show: solver.show() return if cfg.execute_only: assert cfg.execute_inplace or cfg.continue_from is not None, \ "Please explicitly specify the checkpoint to continue from with continue_from= " + \ "when running with execute_only or set execute_inplace to True." 
solver.restore(replay_metrics=False) # load checkpoint solver.run_one_stage(cfg.execute_only) return return solver.run() main.dora.dir = AudioCraftEnvironment.get_dora_dir() main._base_cfg.slurm = get_slurm_parameters(main._base_cfg.slurm) if main.dora.shared is not None and not os.access(main.dora.shared, os.R_OK): print("No read permission on dora.shared folder, ignoring it.", file=sys.stderr) main.dora.shared = None if __name__ == '__main__': main() ================================================ FILE: audiocraft/utils/__init__.py ================================================ # Copyright (c) Meta Platforms, Inc. and affiliates. # All rights reserved. # # This source code is licensed under the license found in the # LICENSE file in the root directory of this source tree. """Utilities.""" ================================================ FILE: audiocraft/utils/autocast.py ================================================ # Copyright (c) Meta Platforms, Inc. and affiliates. # All rights reserved. # # This source code is licensed under the license found in the # LICENSE file in the root directory of this source tree. import torch class TorchAutocast: """TorchAutocast utility class. Allows you to enable and disable autocast. This is specially useful when dealing with different architectures and clusters with different levels of support. Args: enabled (bool): Whether to enable torch.autocast or not. args: Additional args for torch.autocast. 
kwargs: Additional kwargs for torch.autocast """ def __init__(self, enabled: bool, *args, **kwargs): self.autocast = torch.autocast(*args, **kwargs) if enabled else None def __enter__(self): if self.autocast is None: return try: self.autocast.__enter__() except RuntimeError: device = self.autocast.device dtype = self.autocast.fast_dtype raise RuntimeError( f"There was an error autocasting with dtype={dtype} device={device}\n" "If you are on the FAIR Cluster, you might need to use autocast_dtype=float16" ) def __exit__(self, *args, **kwargs): if self.autocast is None: return self.autocast.__exit__(*args, **kwargs) ================================================ FILE: audiocraft/utils/best_state.py ================================================ # Copyright (c) Meta Platforms, Inc. and affiliates. # All rights reserved. # # This source code is licensed under the license found in the # LICENSE file in the root directory of this source tree. from collections import defaultdict import logging import typing as tp import flashy import torch from ..optim import ModuleDictEMA from .utils import copy_state logger = logging.getLogger(__name__) class BestStateDictManager(flashy.state.StateDictSource): """BestStateDictManager maintains a copy of best state_dict() for registered sources. BestStateDictManager has two main attributes: states (dict): State dict of the registered StateDictSource. param_ids (dict): Dict of parameter ids for registered states from ModuleDictEMA and other sources. When registering new sources, the BestStateDictManager will ensure two conflicting sources between ModuleDictEMA and original modules are not both registered as it would otherwise create ambiguity about what to consider for best state. Args: device (torch.device or str): Device on which we keep the copy. dtype (torch.dtype): Data type for the state parameters. 
""" def __init__(self, device: tp.Union[torch.device, str] = 'cpu', dtype: tp.Optional[torch.dtype] = None): self.device = device self.states: dict = {} self.param_ids: dict = defaultdict(dict) self.dtype = dtype def _get_parameter_ids(self, state_dict): return {id(p): name for name, p in state_dict.items() if isinstance(p, torch.Tensor)} def _validate_no_parameter_ids_overlap(self, name: str, param_ids: dict): for registered_name, registered_param_ids in self.param_ids.items(): if registered_name != name: overlap = set.intersection(registered_param_ids.keys(), param_ids.keys()) assert len(overlap) == 0, f"Found {len(overlap)} / {len(param_ids.keys())} overlapping parameters" f" in {name} and already registered {registered_name}: {' '.join(overlap)}" def update(self, name: str, source: flashy.state.StateDictSource): if name not in self.states: raise ValueError(f"{name} missing from registered states.") self.states[name] = copy_state(source.state_dict(), device=self.device, dtype=self.dtype) def register(self, name: str, source: flashy.state.StateDictSource): if name in self.states: raise ValueError(f"{name} already present in states.") # Registering parameter ids for EMA and non-EMA states allows us to check that # there is no overlap that would create ambiguity about how to handle the best state param_ids = self._get_parameter_ids(source.state_dict()) if isinstance(source, ModuleDictEMA): logger.debug(f"Registering to best state: ModuleDictEMA '{name}' with {len(param_ids)} params") self._validate_no_parameter_ids_overlap(name, param_ids) self.param_ids[name] = param_ids else: logger.debug(f"Registering to best state: StateDictSource '{name}' with {len(param_ids)} params") self._validate_no_parameter_ids_overlap('base', param_ids) self.param_ids['base'].update(param_ids) # Register state self.states[name] = copy_state(source.state_dict(), device=self.device, dtype=self.dtype) def state_dict(self) -> flashy.state.StateDict: return self.states def 
load_state_dict(self, state: flashy.state.StateDict): for name, sub_state in state.items(): for k, v in sub_state.items(): self.states[name][k].copy_(v) ================================================ FILE: audiocraft/utils/cache.py ================================================ # Copyright (c) Meta Platforms, Inc. and affiliates. # All rights reserved. # # This source code is licensed under the license found in the # LICENSE file in the root directory of this source tree. from concurrent.futures import ThreadPoolExecutor from collections import deque from functools import partial from hashlib import sha1 import logging from pathlib import Path import sys import typing as tp import zipfile import flashy import torch logger = logging.getLogger(__name__) def get_full_embed(full_embed: torch.Tensor, x: tp.Any, idx: int, device: tp.Union[str, torch.device]) -> torch.Tensor: """Utility function for the EmbeddingCache, returning the full embedding without any chunking. This method can be used in case there is no need in extracting a chunk of the full embedding read from the cache. Args: full_embed (torch.Tensor): The full embedding. x (any): Batch object from which the full embedding is derived. idx (torch.Tensor): Index of object to consider in the batch object. Returns: full_embed (torch.Tensor): The full embedding """ return full_embed.to(device) class EmbeddingCache: """Cache around embeddings computation for faster execution. The EmbeddingCache is storing pre-computed embeddings on disk and provides a simple API to retrieve the pre-computed embeddings on full inputs and extract only a given chunk using a user-provided function. When the cache is warm (all embeddings are pre-computed), the EmbeddingCache allows for faster training as it removes the need of computing the embeddings. Additionally, it provides in-memory cache around the loaded embeddings to limit IO footprint and synchronization points in the forward calls. 
Args: cache_path (Path): Path to folder where all pre-computed embeddings are saved on disk. device (str or torch.device): Device on which the embedding is returned. compute_embed_fn (callable[[Path, any, int], torch.Tensor], optional): Function to compute the embedding from a given object and path. This user provided function can compute the embedding from the provided object or using the provided path as entry point. The last parameter specify the index corresponding to the current embedding in the object that can represent batch metadata. extract_embed_fn (callable[[torch.Tensor, any, int], torch.Tensor], optional): Function to extract the desired embedding chunk from the full embedding loaded from the cache. The last parameter specify the index corresponding to the current embedding in the object that can represent batch metadata. If not specified, will return the full embedding unmodified. """ def __init__(self, cache_path: tp.Union[Path], device: tp.Union[str, torch.device], compute_embed_fn: tp.Callable[[Path, tp.Any, int], torch.Tensor], extract_embed_fn: tp.Optional[tp.Callable[[torch.Tensor, tp.Any, int], torch.Tensor]] = None): self.cache_path = Path(cache_path) self.device = device self._compute_embed_fn = compute_embed_fn self._extract_embed_fn: tp.Callable[[torch.Tensor, tp.Any, int], torch.Tensor] if extract_embed_fn is not None: self._extract_embed_fn = extract_embed_fn else: self._extract_embed_fn = partial(get_full_embed, device=device) if self.cache_path is not None: self.cache_path.mkdir(exist_ok=True, parents=True) logger.info(f"Cache instantiated at: {self.cache_path}") self.pool = ThreadPoolExecutor(8) self.pool.__enter__() self._current_batch_cache: dict = {} self._memory_cache: dict = {} def _get_cache_path(self, path: tp.Union[Path, str]): """Get cache path for the given file path.""" sig = sha1(str(path).encode()).hexdigest() return self.cache_path / sig @staticmethod def _get_full_embed_from_cache(cache: Path): """Loads full pre-computed 
embedding from the cache.""" try: embed = torch.load(cache, 'cpu') except Exception as exc: logger.error("Error loading %s: %r", cache, exc) embed = None return embed def get_embed_from_cache(self, paths: tp.List[Path], x: tp.Any) -> torch.Tensor: """Get embedding from cache, computing and storing it to cache if not already cached. The EmbeddingCache first tries to load the embedding from the in-memory cache containing the pre-computed chunks populated through `populate_embed_cache`. If not found, the full embedding is computed and stored on disk to be later accessed to populate the in-memory cache, and the desired embedding chunk is extracted and returned. Args: paths (list[Path or str]): List of paths from where the embeddings can be loaded. x (any): Object from which the embedding is extracted. """ embeds = [] for idx, path in enumerate(paths): cache = self._get_cache_path(path) if cache in self._current_batch_cache: embed = self._current_batch_cache[cache] else: full_embed = self._compute_embed_fn(path, x, idx) try: with flashy.utils.write_and_rename(cache, pid=True) as f: torch.save(full_embed.cpu(), f) except Exception as exc: logger.error('Error saving embed %s (%s): %r', cache, full_embed.shape, exc) else: logger.info('New embed cache saved: %s (%s)', cache, full_embed.shape) embed = self._extract_embed_fn(full_embed, x, idx) embeds.append(embed) embed = torch.stack(embeds, dim=0) return embed def populate_embed_cache(self, paths: tp.List[Path], x: tp.Any) -> None: """Populate in-memory caches for embeddings reading from the embeddings stored on disk. The in-memory caches consist in a cache for the full embedding and another cache for the final embedding chunk. Such caches are used to limit the IO access when computing the actual embeddings and reduce the IO footprint and synchronization points during forward passes. Args: paths (list[Path]): List of paths from where the embeddings can be loaded. x (any): Object from which the embedding is extracted. 
""" self._current_batch_cache.clear() if self.cache_path is not None: futures: list = [] for path in paths: assert path is not None, "Path is required for computation from cache" cache = self._get_cache_path(path) if cache in self._memory_cache or not cache.exists(): futures.append(None) else: futures.append(self.pool.submit(EmbeddingCache._get_full_embed_from_cache, cache)) for idx, (path, future) in enumerate(zip(paths, futures)): assert path is not None cache = self._get_cache_path(path) full_embed = None if future is None: if cache in self._memory_cache: full_embed = self._memory_cache[cache] else: full_embed = future.result() if full_embed is not None: self._memory_cache[cache] = full_embed full_embed = full_embed.to(self.device) if full_embed is not None: embed = self._extract_embed_fn(full_embed, x, idx) self._current_batch_cache[cache] = embed class CachedBatchWriter: """Write pre computed caches for mini batches. This can make loading a lot more efficient depending on your filesystem. Args: cache_folder (Path): folder in which the cached minibatches will be stored. Inside cache folder, the structure is the following: `epoch_number / update_number.zip` And the zip file contains one entry per batch item. It is possible to use the cache with a batch size smaller than created with but obviously not larger. Make sure to call the `start_epoch(epoch)` method for indicating changes of epochs. See the grid `audiocraft/grids/musicgen/musicgen_warmup_cache.py` for an example of how to warmup the cache. """ def __init__(self, cache_folder: Path): self.cache_folder = cache_folder self._current_epoch: tp.Optional[int] = None self._current_index = 0 def start_epoch(self, epoch: int): """Call at the beginning of each epoch. 
""" self._current_epoch = epoch self._current_index = 0 self._zip_path.parent.mkdir(exist_ok=True, parents=True) @staticmethod def _get_zip_path(cache_folder: Path, epoch: int, index: int): return cache_folder / f"{epoch:05d}" / f"{index:06d}.zip" @property def _zip_path(self): assert self._current_epoch is not None return CachedBatchWriter._get_zip_path(self.cache_folder, self._current_epoch, self._current_index) def save(self, *content): """Save one mini batch. This function is distributed-aware and will automatically merge all the items from the different workers. """ all_contents = [] for rank in range(flashy.distrib.world_size()): their_content = flashy.distrib.broadcast_object(content, src=rank) all_contents.append(their_content) if flashy.distrib.is_rank_zero(): idx = 0 with flashy.utils.write_and_rename(self._zip_path) as tmp: with zipfile.ZipFile(tmp, 'w') as zf: for content in all_contents: for vals in zip(*content): with zf.open(f'{idx}', 'w') as f: # type: ignore torch.save(vals, f) idx += 1 flashy.distrib.barrier() self._current_index += 1 class CachedBatchLoader: """Loader for cached mini-batches dumped with `CachedBatchWriter`. Args: cache_folder (Path): folder in which the cached minibatches are stored. batch_size (int): batch size (per GPU) expected. num_workers (int): number of workers to use for loading. min_length (int): minimum expected length for each epoch. If some mini-batches are missing, and error is raised. This is iterable just like a regular DataLoader. 
""" def __init__(self, cache_folder: Path, batch_size: int, num_workers: int = 10, min_length: int = 1): self.cache_folder = cache_folder self.batch_size = batch_size self.num_workers = num_workers self.min_length = min_length self._current_epoch: tp.Optional[int] = None self.sampler = None # for compatibility with the regular DataLoader def __len__(self): path = CachedBatchWriter._get_zip_path(self.cache_folder, self._current_epoch or 0, 0).parent return len([p for p in path.iterdir() if p.suffix == ".zip"]) def start_epoch(self, epoch: int): """Call at the beginning of each epoch. """ self._current_epoch = epoch def _zip_path(self, index: int): assert self._current_epoch is not None return CachedBatchWriter._get_zip_path(self.cache_folder, self._current_epoch, index) def _load_one(self, index: int): zip_path = self._zip_path(index) if not zip_path.exists(): if index < self.min_length: raise RuntimeError(f"Cache should have at least {self.min_length} batches, but {index} doesn't exist") return None mode = "rb" if sys.version_info >= (3, 9) else "r" try: with zipfile.ZipFile(zip_path, 'r') as zf: rank = flashy.distrib.rank() world_size = flashy.distrib.world_size() root = zipfile.Path(zf) items = list(root.iterdir()) total_batch_size = self.batch_size * world_size if len(items) < total_batch_size: raise RuntimeError( f"The cache can handle a max batch size of {len(items)}, " f"but {total_batch_size} is needed.") start = rank * self.batch_size items = items[start: start + self.batch_size] assert len(items) == self.batch_size entries = [] entries = [torch.load(item.open(mode), 'cpu') for item in items] # type: ignore transposed = zip(*entries) out = [] for part in transposed: assert len(part) > 0 if isinstance(part[0], torch.Tensor): out.append(torch.stack(part)) else: out.append(part) return out except Exception: logger.error("Error when reading zip path %s", zip_path) raise def __iter__(self): """This will yields tuples, exactly as provided to the 
`CachedBatchWriter.save` method. """ pool = ThreadPoolExecutor(self.num_workers) next_index = 0 queue = deque() def _get_next(): nonlocal next_index r = queue.popleft().result() if r is None: return None else: queue.append(pool.submit(self._load_one, next_index)) next_index += 1 return r with pool: # fill the buffer of fetching jobs. for _ in range(2 * self.num_workers): queue.append(pool.submit(self._load_one, next_index)) next_index += 1 while True: batch = _get_next() if batch is None: return yield batch ================================================ FILE: audiocraft/utils/checkpoint.py ================================================ # Copyright (c) Meta Platforms, Inc. and affiliates. # All rights reserved. # # This source code is licensed under the license found in the # LICENSE file in the root directory of this source tree. from enum import Enum import logging from pathlib import Path import re import typing as tp import flashy import torch from ..environment import AudioCraftEnvironment logger = logging.getLogger(__name__) class CheckpointSource(Enum): CURRENT_XP = "current_xp" PRETRAINED = "pretrained" OTHER = "other" def checkpoint_name(name: tp.Optional[str] = None, rank: tp.Optional[int] = None, use_fsdp: bool = False) -> str: """Checkpoint name formatted for all use in AudioCraft codebase and has the following format: `checkpoint_.th(.)`. By convention, name is expected to be empty for last checkpoint, 'best' for the best checkpoint or the epoch number. Args: name (str, optional): Name suffix for the checkpoint file stem. rank (optional, int): Rank for distributed processing, retrieved with flashy if not provided. use_fsdp (bool): Whether the calling solver relies on FSDP. Returns: str: The checkpoint name. """ suffix = '' if rank is None: rank = flashy.distrib.rank() if rank > 0 and use_fsdp: suffix = '.' 
+ str(rank) name_part = '' if name is not None: name_part = f'_{name}' return f'checkpoint{name_part}.th{suffix}' def is_sharded_checkpoint(path: Path) -> bool: """Whether the checkpoint at the given path corresponds to a sharded checkpoint across rank.""" return re.search(r'\.th\.\d+$', path.name) is not None def resolve_checkpoint_path(sig_or_path: tp.Union[Path, str], name: tp.Optional[str] = None, use_fsdp: bool = False) -> tp.Optional[Path]: """Resolve a given checkpoint path for a provided dora sig or path. Args: sig_or_path (Path or str): Checkpoint path or dora signature. name (str, optional): Name suffix for the checkpoint file stem. rank (optional, int): Rank for distributed processing, retrieved with flashy if not provided. use_fsdp (bool): Whether the calling solver relies on FSDP. Returns: Path, optional: Resolved checkpoint path, if it exists. """ from audiocraft import train xps_root = train.main.dora.dir / 'xps' sig_or_path = str(sig_or_path) if sig_or_path.startswith('//sig/'): sig = sig_or_path[len('//sig/'):] path = xps_root / sig else: path = Path(sig_or_path) path = AudioCraftEnvironment.resolve_reference_path(path) if path.is_dir(): path = path / checkpoint_name(name, use_fsdp=use_fsdp) if path.exists(): return path else: return None def load_checkpoint(checkpoint_path: Path, is_sharded: bool = False) -> tp.Any: """Load state from checkpoints at the specified checkpoint path.""" if is_sharded: rank0_checkpoint_path = checkpoint_path.parent / checkpoint_name(use_fsdp=False) if rank0_checkpoint_path.exists(): check_sharded_checkpoint(checkpoint_path, rank0_checkpoint_path) state = torch.load(checkpoint_path, 'cpu') logger.info("Checkpoint loaded from %s", checkpoint_path) return state def save_checkpoint(state: tp.Any, checkpoint_path: Path, is_sharded: bool = False) -> None: """Save state to disk to the specified checkpoint_path.""" _safe_save_checkpoint(state, checkpoint_path, is_sharded) logger.info("Checkpoint saved to %s", checkpoint_path) 
def flush_stale_checkpoints(checkpoint_path: Path, keep_last: tp.Optional[int] = None) -> None: """Flush checkpoints to only keep last N checkpoints.""" if keep_last is None or keep_last <= 0: return checkpoint_dir = checkpoint_path.parent suffix = '' if flashy.distrib.rank() > 0: suffix = f'.{flashy.distrib.rank()}' checkpoint_files_with_epoch = [] for path in Path(checkpoint_dir).glob(f'checkpoint_*.th{suffix}'): epoch_part = path.name.split('.', 1)[0].split('_', 1)[1] if epoch_part.isdigit(): checkpoint_files_with_epoch.append((path, int(epoch_part))) checkpoint_files = [path for path, _ in list(sorted(checkpoint_files_with_epoch, key=lambda t: t[1]))] total_to_flush = max(0, len(checkpoint_files) - keep_last) files_to_flush = checkpoint_files[:total_to_flush] for path in files_to_flush: logger.debug("Removing checkpoint: %s", str(path)) path.unlink(missing_ok=True) def check_sharded_checkpoint(checkpoint_path: Path, rank0_checkpoint_path: Path) -> None: """Check sharded checkpoint state, ensuring the checkpoints are not corrupted.""" # Finish the work of a previous run that got interrupted while dumping. 
old_path = Path(str(checkpoint_path) + '.old') if old_path.exists(): raise RuntimeError( f"Old checkpoint {old_path} from previous version of this code exist, cannot safely proceed.") token = Path(str(rank0_checkpoint_path) + '.tmp.done') tmp_path = Path(str(checkpoint_path) + '.tmp') if token.exists(): if tmp_path.exists(): tmp_path.rename(checkpoint_path) flashy.distrib.barrier() if flashy.distrib.is_rank_zero() and token.exists(): token.unlink() def _safe_save_checkpoint(state: tp.Any, checkpoint_path: Path, is_sharded: bool = False) -> None: """Save checkpoints in a safe manner even with when sharded checkpoints across nodes.""" def _barrier_if_sharded(): if is_sharded: flashy.distrib.barrier() if flashy.distrib.is_rank_zero(): token = Path(str(checkpoint_path) + '.tmp.done') if token.exists(): token.unlink() _barrier_if_sharded() with flashy.utils.write_and_rename(checkpoint_path) as f: torch.save(state, f) _barrier_if_sharded() if flashy.distrib.is_rank_zero(): token.touch() _barrier_if_sharded() _barrier_if_sharded() if flashy.distrib.rank() == 0: token.unlink() ================================================ FILE: audiocraft/utils/cluster.py ================================================ # Copyright (c) Meta Platforms, Inc. and affiliates. # All rights reserved. # # This source code is licensed under the license found in the # LICENSE file in the root directory of this source tree. """ Utility functions for SLURM configuration and cluster settings. """ from enum import Enum import os import socket import typing as tp import omegaconf class ClusterType(Enum): AWS = "aws" FAIR = "fair" RSC = "rsc" LOCAL_DARWIN = "darwin" DEFAULT = "default" # used for any other cluster. 
def _guess_cluster_type() -> ClusterType: uname = os.uname() fqdn = socket.getfqdn() if uname.sysname == "Linux" and (uname.release.endswith("-aws") or ".ec2" in fqdn): return ClusterType.AWS if fqdn.endswith(".fair"): return ClusterType.FAIR if fqdn.endswith(".facebook.com"): return ClusterType.RSC if uname.sysname == "Darwin": return ClusterType.LOCAL_DARWIN return ClusterType.DEFAULT def get_cluster_type( cluster_type: tp.Optional[ClusterType] = None, ) -> tp.Optional[ClusterType]: if cluster_type is None: return _guess_cluster_type() return cluster_type def get_slurm_parameters( cfg: omegaconf.DictConfig, cluster_type: tp.Optional[ClusterType] = None ) -> omegaconf.DictConfig: """Update SLURM parameters in configuration based on cluster type. If the cluster type is not specify, it infers it automatically. """ from ..environment import AudioCraftEnvironment cluster_type = get_cluster_type(cluster_type) # apply cluster-specific adjustments if cluster_type == ClusterType.AWS: cfg["mem_per_gpu"] = None cfg["constraint"] = None cfg["setup"] = [] elif cluster_type == ClusterType.RSC: cfg["mem_per_gpu"] = None cfg["setup"] = [] cfg["constraint"] = None cfg["partition"] = "learn" slurm_exclude = AudioCraftEnvironment.get_slurm_exclude() if slurm_exclude is not None: cfg["exclude"] = slurm_exclude return cfg ================================================ FILE: audiocraft/utils/deadlock.py ================================================ # Copyright (c) Meta Platforms, Inc. and affiliates. # All rights reserved. # # This source code is licensed under the license found in the # LICENSE file in the root directory of this source tree. 
logger = logging.getLogger(__name__)


class DeadlockDetect:
    """Watchdog that kills the process when no progress is reported for `timeout` seconds.

    Used as a context manager; call `update(stage)` regularly inside the `with` body.

    Args:
        use (bool): When False every method is a no-op.
        timeout (float): Maximum number of seconds allowed between two `update` calls.
    """
    def __init__(self, use: bool = False, timeout: float = 120.):
        self.use = use
        self.timeout = timeout
        self._queue: Queue = Queue()

    def update(self, stage: str):
        """Report progress; `stage` is the label logged if a timeout happens later."""
        if self.use:
            self._queue.put(stage)

    def __enter__(self):
        if self.use:
            self._thread = threading.Thread(target=self._detector_thread)
            self._thread.start()
        # Return the detector so that `with DeadlockDetect(...) as detector:` is usable.
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        if self.use:
            # `None` is the sentinel asking the detector thread to stop.
            self._queue.put(None)
            self._thread.join()

    def _detector_thread(self):
        logger.debug("Deadlock detector started")
        last_stage = "init"
        while True:
            try:
                stage = self._queue.get(timeout=self.timeout)
            except Empty:
                break
            if stage is None:
                logger.debug("Exiting deadlock detector thread")
                return
            else:
                last_stage = stage
        logger.error("Deadlock detector timed out, last stage was %s", last_stage)
        # Take one snapshot: a thread may have exited (or have no ident yet) by the
        # time it is listed, in which case it has no frame to print.
        frames = sys._current_frames()
        for th in threading.enumerate():
            print(th, file=sys.stderr)
            frame = frames.get(th.ident)
            if frame is not None:
                traceback.print_stack(frame)
            print(file=sys.stderr)
        sys.stdout.flush()
        sys.stderr.flush()
        # SIGKILL does not exist on Windows; fall back to SIGTERM there.
        os.kill(os.getpid(), getattr(signal, "SIGKILL", signal.SIGTERM))


def export_encodec(checkpoint_path: tp.Union[Path, str], out_file: tp.Union[Path, str]):
    """Export only the best state from the given EnCodec checkpoint. This
    should be used if you trained your own EnCodec model.

    Args:
        checkpoint_path (Path or str): Training checkpoint to read.
        out_file (Path or str): Destination file; parent folders are created.
    Returns:
        The `out_file` argument, unchanged.
    """
    pkg = torch.load(checkpoint_path, 'cpu')
    new_pkg = {
        'best_state': pkg['best_state']['model'],
        'xp.cfg': OmegaConf.to_yaml(pkg['xp.cfg']),
        'version': __version__,
        'exported': True,
    }
    Path(out_file).parent.mkdir(exist_ok=True, parents=True)
    torch.save(new_pkg, out_file)
    return out_file


def export_pretrained_compression_model(pretrained_encodec: str, out_file: tp.Union[Path, str]):
    """Export a compression model (potentially EnCodec) from a pretrained model.
    This is required for packaging the audio tokenizer along a MusicGen or AudioGen model.
    Do not include the //pretrained/ prefix. For instance if you trained a model
    with `facebook/encodec_32khz`, just put that as a name. Same for `dac_44khz`.

    In that case, this will not actually include a copy of the model, simply the reference
    to the model used.

    Args:
        pretrained_encodec (str): Path of an already exported checkpoint, or name of a
            pretrained model when no such file exists.
        out_file (Path or str): Destination file; parent folders are created.
    Returns:
        The `out_file` argument, unchanged.
    """
    if Path(pretrained_encodec).exists():
        # Load on CPU like the other exporters, so that a checkpoint saved from a GPU
        # can be re-exported on a machine without one.
        pkg = torch.load(pretrained_encodec, 'cpu')
        assert 'best_state' in pkg
        assert 'xp.cfg' in pkg
        assert 'version' in pkg
        assert 'exported' in pkg
    else:
        pkg = {
            'pretrained': pretrained_encodec,
            'exported': True,
            'version': __version__,
        }
    Path(out_file).parent.mkdir(exist_ok=True, parents=True)
    torch.save(pkg, out_file)
    return out_file


def export_lm(checkpoint_path: tp.Union[Path, str], out_file: tp.Union[Path, str]):
    """Export only the best state from the given MusicGen or AudioGen checkpoint.

    Args:
        checkpoint_path (Path or str): Training checkpoint to read.
        out_file (Path or str): Destination file; parent folders are created.
    Returns:
        The `out_file` argument, unchanged.
    """
    pkg = torch.load(checkpoint_path, 'cpu')
    # The FSDP state wins when present and non-empty; a checkpoint without the key
    # falls back to the regular best state instead of failing with a KeyError.
    if pkg.get('fsdp_best_state'):
        best_state = pkg['fsdp_best_state']['model']
    else:
        assert pkg.get('best_state'), "Checkpoint contains no best state to export"
        best_state = pkg['best_state']['model']
    new_pkg = {
        'best_state': best_state,
        'xp.cfg': OmegaConf.to_yaml(pkg['xp.cfg']),
        'version': __version__,
        'exported': True,
    }
    Path(out_file).parent.mkdir(exist_ok=True, parents=True)
    torch.save(new_pkg, out_file)
    return out_file
def _clean_lm_cfg(cfg: DictConfig):
    """Patch a first-release LM configuration in place and return it.

    Forces the `card` and `n_q` values and removes the experimental transformer
    options that are no longer supported.
    """
    obsolete_keys = (
        'spectral_norm_attn_iters',
        'spectral_norm_ff_iters',
        'residual_balancer_attn',
        'residual_balancer_ff',
        'layer_drop',
    )
    # Struct mode is lifted only for the duration of the edits.
    OmegaConf.set_struct(cfg, False)
    # This used to be set automatically in the LM solver, need a more robust solution
    # for the future.
    cfg['transformer_lm']['card'] = 2048
    cfg['transformer_lm']['n_q'] = 4
    # Experimental params no longer supported.
    for key in obsolete_keys:
        del cfg['transformer_lm'][key]
    OmegaConf.set_struct(cfg, True)
    return cfg


def export_encodec(checkpoint_path: tp.Union[Path, str], out_folder: tp.Union[Path, str]):
    """Export the EMA model state of a legacy EnCodec checkpoint to `<out_folder>/<sig>.th`.

    The signature is the name of the folder containing the checkpoint and must be
    8 characters long.
    """
    signature = Path(checkpoint_path).parent.name
    assert len(signature) == 8, "Not a valid Dora signature"
    package = torch.load(checkpoint_path, 'cpu')
    exported = {
        'best_state': package['ema']['state']['model'],
        'xp.cfg': OmegaConf.to_yaml(package['xp.cfg']),
    }
    destination = Path(out_folder) / f'{signature}.th'
    torch.save(exported, destination)
    return destination


def export_lm(checkpoint_path: tp.Union[Path, str], out_folder: tp.Union[Path, str]):
    """Export the FSDP best state of a legacy LM checkpoint to `<out_folder>/<sig>.th`.

    The stored configuration goes through `_clean_lm_cfg` before being dumped to YAML.
    """
    signature = Path(checkpoint_path).parent.name
    assert len(signature) == 8, "Not a valid Dora signature"
    package = torch.load(checkpoint_path, 'cpu')
    exported = {
        'best_state': package['fsdp_best_state']['model'],
        'xp.cfg': OmegaConf.to_yaml(_clean_lm_cfg(package['xp.cfg']))
    }
    destination = Path(out_folder) / f'{signature}.th'
    torch.save(exported, destination)
    return destination
def display_audio(samples: torch.Tensor, sample_rate: int):
    """Renders an audio player for the given audio samples.

    Args:
        samples (torch.Tensor): a Tensor of decoded audio samples
            with shapes [B, C, T] or [C, T]
        sample_rate (int): sample rate audio should be displayed with.
    Raises:
        RuntimeError: if IPython is not installed.
    """
    assert samples.dim() == 2 or samples.dim() == 3

    # Resolved here rather than relying on a module-level name that is silently
    # left undefined when IPython is missing (which surfaced as a NameError).
    try:
        import IPython.display as ipd  # type: ignore
    except ImportError as err:
        raise RuntimeError("IPython is required to display audio, are you in a notebook?") from err

    samples = samples.detach().cpu()
    if samples.dim() == 2:
        samples = samples[None, ...]

    for audio in samples:
        ipd.display(ipd.Audio(audio, rate=sample_rate))


class Profiler:
    """Context manager wrapper for xformers profiler.

    Args:
        module (torch.nn.Module): Module handed to the xformers profiler.
        enabled (bool): When False no profiler is created and every method is a no-op.
    """
    def __init__(self, module: torch.nn.Module, enabled: bool = False):
        self.profiler: tp.Optional[tp.Any] = None
        if enabled:
            # xformers is only required when profiling is actually requested.
            from xformers.profiler import profile
            output_dir = dora.get_xp().folder / 'profiler_data'
            logger.info("Profiling activated, results will be saved to %s", output_dir)
            self.profiler = profile(output_dir=output_dir, module=module)

    def step(self):
        """Forward one step to the wrapped profiler, if any."""
        if self.profiler is not None:
            self.profiler.step()  # type: ignore

    def __enter__(self):
        # Returns whatever the wrapped profiler returns, or None when disabled.
        if self.profiler is not None:
            return self.profiler.__enter__()  # type: ignore

    def __exit__(self, exc_type, exc_value, exc_tb):
        if self.profiler is not None:
            return self.profiler.__exit__(exc_type, exc_value, exc_tb)  # type: ignore
@dataclass
class ReferenceSample:
    """Metadata of an audio file used as a prompt or as a ground-truth reference."""
    id: str
    path: str
    duration: float


@dataclass
class Sample:
    """Metadata of one generated sample, mirroring the JSON file stored next to its audio."""
    id: str
    path: str
    epoch: int
    duration: float
    conditioning: tp.Optional[tp.Dict[str, tp.Any]]
    prompt: tp.Optional[ReferenceSample]
    reference: tp.Optional[ReferenceSample]
    generation_args: tp.Optional[tp.Dict[str, tp.Any]]

    def __hash__(self):
        # Samples are collected in sets; the id alone identifies them.
        return hash(self.id)

    def audio(self) -> tp.Tuple[torch.Tensor, int]:
        """Read the generated audio from `path`."""
        return audio_read(self.path)

    def audio_prompt(self) -> tp.Optional[tp.Tuple[torch.Tensor, int]]:
        """Read the prompt audio, or return None when the sample has no prompt."""
        return audio_read(self.prompt.path) if self.prompt is not None else None

    def audio_reference(self) -> tp.Optional[tp.Tuple[torch.Tensor, int]]:
        """Read the reference audio, or return None when the sample has no reference."""
        return audio_read(self.reference.path) if self.reference is not None else None


class SampleManager:
    """Audio samples IO handling within a given dora xp.

    The sample manager handles the dumping and loading logic for generated and
    references samples across epochs for a given xp, providing a simple API to
    store, retrieve and compare audio samples.

    Args:
        xp (dora.XP): Dora experiment object. The XP contains information on the XP folder
            where all outputs are stored and the configuration of the experiment,
            which is useful to retrieve audio-related parameters.
        map_reference_to_sample_id (bool): Whether to use the sample_id for all reference samples
            instead of generating a dedicated hash id. This is useful to allow easier comparison
            with ground truth sample from the files directly without having to read the JSON metadata
            to do the mapping (at the cost of potentially dumping duplicate prompts/references
            depending on the task).
    """
    def __init__(self, xp: dora.XP, map_reference_to_sample_id: bool = False):
        self.xp = xp
        self.base_folder: Path = xp.folder / xp.cfg.generate.path
        self.reference_folder = self.base_folder / 'reference'
        self.map_reference_to_sample_id = map_reference_to_sample_id
        self.samples: tp.List[Sample] = []
        self._load_samples()

    @property
    def latest_epoch(self):
        """Latest epoch across all samples (0 when there is no sample)."""
        return max((sample.epoch for sample in self.samples), default=0)

    def _load_samples(self):
        """Scan the sample folder and load existing samples."""
        jsons = self.base_folder.glob('**/*.json')
        with ThreadPoolExecutor(6) as pool:
            self.samples = list(pool.map(self._load_sample, jsons))

    @staticmethod
    def _load_sample(json_file: Path) -> Sample:
        """Load the sample described by `json_file`.

        Results are cached per (file, modification time): a metadata file that was
        rewritten since the last load is parsed again, so a new manager sees it.
        """
        mtime_ns = Path(json_file).stat().st_mtime_ns
        return SampleManager._load_sample_cached(json_file, mtime_ns)

    @staticmethod
    @lru_cache(2**26)
    def _load_sample_cached(json_file: Path, mtime_ns: int) -> Sample:
        """Parse `json_file`; `mtime_ns` only takes part in the cache key."""
        with open(json_file, 'r') as f:
            data: tp.Dict[str, tp.Any] = json.load(f)
        # fetch prompt and reference data
        prompt = SampleManager._parse_reference(data.get('prompt'))
        reference = SampleManager._parse_reference(data.get('reference'))
        # build sample object
        return Sample(id=data['id'], path=data['path'], epoch=data['epoch'], duration=data['duration'],
                      prompt=prompt, conditioning=data.get('conditioning'), reference=reference,
                      generation_args=data.get('generation_args'))

    @staticmethod
    def _parse_reference(data: tp.Optional[tp.Dict[str, tp.Any]]) -> tp.Optional[ReferenceSample]:
        """Build a ReferenceSample from its JSON form, or None when it is missing/empty."""
        if not data:
            return None
        return ReferenceSample(id=data['id'], path=data['path'], duration=data['duration'])

    def _init_hash(self):
        return hashlib.sha1()

    @staticmethod
    def _tensor_bytes(tensor: torch.Tensor) -> bytes:
        """Raw row-major bytes of `tensor`.

        Also works for tensors that require grad, live on another device or are not
        contiguous; for a contiguous CPU tensor these are the bytes of its buffer.
        """
        return tensor.detach().cpu().contiguous().numpy().tobytes()

    def _get_tensor_id(self, tensor: torch.Tensor) -> str:
        """Hex digest identifying the content of `tensor`."""
        hash_id = self._init_hash()
        hash_id.update(self._tensor_bytes(tensor))
        return hash_id.hexdigest()

    def _get_sample_id(self, index: int, prompt_wav: tp.Optional[torch.Tensor],
                       conditions: tp.Optional[tp.Dict[str, str]]) -> str:
        """Computes an id for a sample given its input data.
        This id is deterministic if prompt and/or conditions are provided by using a sha1 hash on the input.
        Otherwise, a random id of the form "noinput_{uuid4().hex}" is returned.

        Args:
            index (int): Batch index, Helpful to differentiate samples from the same batch.
            prompt_wav (torch.Tensor): Prompt used during generation.
            conditions (dict[str, str]): Conditioning used during generation.
        """
        # For totally unconditioned generations we will just use a random UUID.
        # The function get_samples_for_xps will do a simple ordered match with a custom key.
        if prompt_wav is None and not conditions:
            return f"noinput_{uuid.uuid4().hex}"

        # Human readable portion
        hr_label = ""
        # Create a deterministic id using hashing
        hash_id = self._init_hash()
        hash_id.update(f"{index}".encode())
        if prompt_wav is not None:
            hash_id.update(self._tensor_bytes(prompt_wav))
            hr_label += "_prompted"
        else:
            hr_label += "_unprompted"
        if conditions:
            encoded_json = json.dumps(conditions, sort_keys=True).encode()
            hash_id.update(encoded_json)
            cond_str = "-".join([f"{key}={slugify(value)}"
                                 for key, value in sorted(conditions.items())])
            cond_str = cond_str[:100]  # some raw text might be too long to be a valid filename
            cond_str = cond_str if len(cond_str) > 0 else "unconditioned"
            hr_label += f"_{cond_str}"
        else:
            hr_label += "_unconditioned"

        return hash_id.hexdigest() + hr_label

    def _store_audio(self, wav: torch.Tensor, stem_path: Path, overwrite: bool = False) -> Path:
        """Stores the audio with the given stem path using the XP's configuration.

        Args:
            wav (torch.Tensor): Audio to store.
            stem_path (Path): Path in sample output directory with file stem to use.
            overwrite (bool): When False (default), skips storing an existing audio file.
        Returns:
            Path: The path at which the audio is stored.
        """
        # Any file sharing the stem counts as the audio, except the JSON metadata.
        existing_paths = [
            path for path in stem_path.parent.glob(stem_path.stem + '.*')
            if path.suffix != '.json'
        ]
        exists = len(existing_paths) > 0
        if exists and overwrite:
            logger.warning(f"Overwriting existing audio file with stem path {stem_path}")
        elif exists:
            return existing_paths[0]

        audio_path = audio_write(stem_path, wav, **self.xp.cfg.generate.audio)
        return audio_path

    def add_sample(self, sample_wav: torch.Tensor, epoch: int, index: int = 0,
                   conditions: tp.Optional[tp.Dict[str, str]] = None, prompt_wav: tp.Optional[torch.Tensor] = None,
                   ground_truth_wav: tp.Optional[torch.Tensor] = None,
                   generation_args: tp.Optional[tp.Dict[str, tp.Any]] = None) -> Sample:
        """Adds a single sample.
        The sample is stored in the XP's sample output directory, under a corresponding epoch folder.
        Each sample is assigned an id which is computed using the input data. In addition to the
        sample itself, a json file containing associated metadata is stored next to it.

        Args:
            sample_wav (torch.Tensor): sample audio to store. Tensor of shape [channels, shape].
            epoch (int): current training epoch.
            index (int): helpful to differentiate samples from the same batch.
            conditions (dict[str, str], optional): conditioning used during generation.
            prompt_wav (torch.Tensor, optional): prompt used during generation. Tensor of shape [channels, shape].
            ground_truth_wav (torch.Tensor, optional): reference audio where prompt was extracted from.
                Tensor of shape [channels, shape].
            generation_args (dict[str, any], optional): dictionary of other arguments used during generation.
        Returns:
            Sample: The saved sample.
        """
        sample_id = self._get_sample_id(index, prompt_wav, conditions)
        reuse_id = self.map_reference_to_sample_id
        prompt, ground_truth = None, None
        if prompt_wav is not None:
            prompt_id = sample_id if reuse_id else self._get_tensor_id(prompt_wav.sum(0, keepdim=True))
            prompt_duration = prompt_wav.shape[-1] / self.xp.cfg.sample_rate
            prompt_path = self._store_audio(prompt_wav, self.base_folder / str(epoch) / 'prompt' / prompt_id)
            prompt = ReferenceSample(prompt_id, str(prompt_path), prompt_duration)
        if ground_truth_wav is not None:
            ground_truth_id = sample_id if reuse_id else self._get_tensor_id(ground_truth_wav.sum(0, keepdim=True))
            ground_truth_duration = ground_truth_wav.shape[-1] / self.xp.cfg.sample_rate
            ground_truth_path = self._store_audio(ground_truth_wav, self.base_folder / 'reference' / ground_truth_id)
            ground_truth = ReferenceSample(ground_truth_id, str(ground_truth_path), ground_truth_duration)
        # The generated audio itself is always rewritten, unlike prompts and references.
        sample_path = self._store_audio(sample_wav, self.base_folder / str(epoch) / sample_id, overwrite=True)
        duration = sample_wav.shape[-1] / self.xp.cfg.sample_rate
        sample = Sample(sample_id, str(sample_path), epoch, duration, conditions, prompt, ground_truth, generation_args)
        self.samples.append(sample)
        with open(sample_path.with_suffix('.json'), 'w') as f:
            json.dump(asdict(sample), f, indent=2)
        return sample

    def add_samples(self, samples_wavs: torch.Tensor, epoch: int,
                    conditioning: tp.Optional[tp.List[tp.Dict[str, tp.Any]]] = None,
                    prompt_wavs: tp.Optional[torch.Tensor] = None,
                    ground_truth_wavs: tp.Optional[torch.Tensor] = None,
                    generation_args: tp.Optional[tp.Dict[str, tp.Any]] = None) -> tp.List[Sample]:
        """Adds a batch of samples.
        The samples are stored in the XP's sample output directory, under a corresponding
        epoch folder. Each sample is assigned an id which is computed using the input data and their batch index.
        In addition to the sample itself, a json file containing associated metadata is stored next to it.

        Args:
            samples_wavs (torch.Tensor): Batch of audio wavs to store. Tensor of shape [batch_size, channels, shape].
            epoch (int): Current training epoch.
            conditioning (list of dict[str, str], optional): List of conditions used during generation,
                one per sample in the batch.
            prompt_wavs (torch.Tensor, optional): Prompts used during generation. Tensor of shape
                [batch_size, channels, shape].
            ground_truth_wavs (torch.Tensor, optional): Reference audio where prompts were extracted from.
                Tensor of shape [batch_size, channels, shape].
            generation_args (dict[str, Any], optional): Dictionary of other arguments used during generation.
        Returns:
            samples (list of Sample): The saved audio samples with prompts, ground truth and metadata.
        """
        samples = []
        for idx, wav in enumerate(samples_wavs):
            prompt_wav = prompt_wavs[idx] if prompt_wavs is not None else None
            gt_wav = ground_truth_wavs[idx] if ground_truth_wavs is not None else None
            conditions = conditioning[idx] if conditioning is not None else None
            samples.append(self.add_sample(wav, epoch, idx, conditions, prompt_wav, gt_wav, generation_args))
        return samples

    def get_samples(self, epoch: int = -1, max_epoch: int = -1, exclude_prompted: bool = False,
                    exclude_unprompted: bool = False, exclude_conditioned: bool = False,
                    exclude_unconditioned: bool = False) -> tp.Set[Sample]:
        """Returns a set of samples for this XP. Optionally, you can filter which samples to obtain.
        Please note that existing samples are loaded during the manager's initialization, and added samples through this
        manager are also tracked. Any other external changes are not tracked automatically, so creating a new manager
        is the only way detect them.

        Args:
            epoch (int): If provided, only return samples corresponding to this epoch.
                Ignored when `max_epoch` is provided.
            max_epoch (int): If provided, only return samples corresponding to the latest epoch that is <= max_epoch.
                An empty set is returned when no sample is that old.
            exclude_prompted (bool): If True, does not include samples that used a prompt.
            exclude_unprompted (bool): If True, does not include samples that did not use a prompt.
            exclude_conditioned (bool): If True, excludes samples that used conditioning.
            exclude_unconditioned (bool): If True, excludes samples that did not use conditioning.
        Returns:
            Samples (set of Sample): The retrieved samples matching the provided filters.
        """
        if max_epoch >= 0:
            eligible_epochs = [sample.epoch for sample in self.samples if sample.epoch <= max_epoch]
            if not eligible_epochs:
                # Nothing stored at or before `max_epoch` (`max` would raise on an empty input).
                return set()
            samples_epoch = max(eligible_epochs)
        else:
            samples_epoch = self.latest_epoch if epoch < 0 else epoch
        samples = {
            sample
            for sample in self.samples
            if (
                (sample.epoch == samples_epoch) and
                (not exclude_prompted or sample.prompt is None) and
                (not exclude_unprompted or sample.prompt is not None) and
                (not exclude_conditioned or not sample.conditioning) and
                (not exclude_unconditioned or sample.conditioning)
            )
        }
        return samples


def slugify(value: tp.Any, allow_unicode: bool = False):
    """Process string for safer file naming.

    Taken from https://github.com/django/django/blob/master/django/utils/text.py

    Convert to ASCII if 'allow_unicode' is False. Convert spaces or repeated
    dashes to single dashes. Remove characters that aren't alphanumerics,
    underscores, or hyphens. Convert to lowercase. Also strip leading and
    trailing whitespace, dashes, and underscores.
    """
    value = str(value)
    if allow_unicode:
        value = unicodedata.normalize("NFKC", value)
    else:
        value = (
            unicodedata.normalize("NFKD", value)
            .encode("ascii", "ignore")
            .decode("ascii")
        )
    value = re.sub(r"[^\w\s-]", "", value.lower())
    return re.sub(r"[-\s]+", "-", value).strip("-_")


def _match_stable_samples(samples_per_xp: tp.List[tp.Set[Sample]]) -> tp.Dict[str, tp.List[Sample]]:
    """Match, by id, the prompted and/or conditioned samples present in every XP."""
    # Create a dictionary of stable id -> sample per XP
    stable_samples_per_xp = [{
        sample.id: sample for sample in samples
        if sample.prompt is not None or sample.conditioning
    } for samples in samples_per_xp]
    # Set of all stable ids
    stable_ids = {id for samples in stable_samples_per_xp for id in samples.keys()}
    # Dictionary of stable id -> list of samples. If an XP does not have it, assign None
    stable_samples = {id: [xp.get(id) for xp in stable_samples_per_xp] for id in stable_ids}
    # Filter out ids that contain None values (we only want matched samples after all)
    # cast is necessary to avoid mypy linter errors.
    return {id: tp.cast(tp.List[Sample], samples) for id, samples in stable_samples.items() if None not in samples}


def _match_unstable_samples(samples_per_xp: tp.List[tp.Set[Sample]]) -> tp.Dict[str, tp.List[Sample]]:
    """Match, by rank in id order, the samples generated without prompt nor conditioning."""
    # For unstable ids, we use a sorted list since we'll match them in order
    unstable_samples_per_xp = [[
        sample for sample in sorted(samples, key=lambda x: x.id)
        if sample.prompt is None and not sample.conditioning
    ] for samples in samples_per_xp]
    # Trim samples per xp so all samples can have a match (no XP at all means no match)
    min_len = min((len(samples) for samples in unstable_samples_per_xp), default=0)
    unstable_samples_per_xp = [samples[:min_len] for samples in unstable_samples_per_xp]
    # Dictionary of index -> list of matched samples
    return {
        f'noinput_{i}': [samples[i] for samples in unstable_samples_per_xp] for i in range(min_len)
    }
class ToolButton(gr.Button, gr.components.IOComponent):
    """Small button with single emoji as text, fits inside gradio forms"""

    def __init__(self, **kwargs):
        super().__init__(**kwargs)

    def get_block_name(self):
        return "button"


def create_refresh_button(refresh_component, refresh_method, refreshed_args, elem_class):
    """Create a small refresh button bound to `refresh_component`.

    Args:
        refresh_component: Component updated when the button is clicked.
        refresh_method: Callable invoked first on every click.
        refreshed_args: Dict of attributes to set on the component, or a callable
            returning such a dict (None is treated as an empty dict).
        elem_class: Value passed as `elem_classes` to the button.
    Returns:
        ToolButton: The created button.
    """
    def refresh():
        refresh_method()
        args = refreshed_args() if callable(refreshed_args) else refreshed_args
        # Normalize before iterating: the `or {}` guard used to come after `.items()`,
        # so a None result crashed before the guard could apply.
        args = args or {}

        for k, v in args.items():
            setattr(refresh_component, k, v)

        return gr.update(**args)

    refresh_button = ToolButton(value=refresh_symbol, elem_classes=elem_class, scale=1, size="sm", container=False)
    refresh_button.click(
        fn=refresh,
        inputs=[],
        outputs=[refresh_component]
    )
    return refresh_button
def model_hash(model: torch.nn.Module) -> str:
    """Return a model hash. This should allow us to track regressions in model init
    from the logs of past experiments.

    The hash is the SHA-1 digest of the raw bytes of every parameter, taken in the
    order given by `model.parameters()`.

    Args:
        model (torch.nn.Module): Model whose parameters are hashed.
    Returns:
        str: Hexadecimal SHA-1 digest.
    """
    hasher = hashlib.sha1()
    for p in model.parameters():
        # Reinterpret the parameter memory as bytes before handing it to NumPy:
        # this produces the same bytes as `p.data.cpu().numpy().tobytes()` but also
        # works for dtypes NumPy cannot represent (e.g. bfloat16).
        raw = p.data.cpu().contiguous().reshape(-1).view(torch.uint8)
        hasher.update(raw.numpy().tobytes())
    return hasher.hexdigest()


def dict_from_config(cfg: omegaconf.DictConfig) -> dict:
    """Convenience function to map an omegaconf configuration to a dictionary.

    Interpolations are resolved during the conversion (`resolve=True`).

    Args:
        cfg (omegaconf.DictConfig): Original configuration to map to dict.
    Returns:
        dict: Config as dictionary object.
    """
    dct = omegaconf.OmegaConf.to_container(cfg, resolve=True)
    assert isinstance(dct, dict)
    return dct


def random_subset(dataset, max_samples: int, seed: int = 42) -> torch.utils.data.Subset:
    """Return a reproducible random subset of at most `max_samples` items.

    Args:
        dataset: Dataset to sample from; must support `len()`.
        max_samples (int): Maximum number of items to keep.
        seed (int): Seed of the generator used to draw the permutation.
    Returns:
        The dataset itself when it already holds at most `max_samples` items,
        otherwise a `torch.utils.data.Subset` built from the first `max_samples`
        indices of a seeded random permutation.
    """
    if max_samples >= len(dataset):
        return dataset

    generator = torch.Generator().manual_seed(seed)
    perm = torch.randperm(len(dataset), generator=generator)
    return torch.utils.data.Subset(dataset, perm[:max_samples].tolist())
def get_dataset_from_loader(dataloader):
    """Return the dataset behind a dataloader.

    When the dataloader's dataset is a `torch.utils.data.Subset`, the dataset
    wrapped by that subset is returned instead of the subset itself.
    """
    dataset = dataloader.dataset
    if isinstance(dataset, torch.utils.data.Subset):
        return dataset.dataset
    else:
        return dataset


def multinomial(input: torch.Tensor, num_samples: int, replacement=False, *, generator=None):
    """torch.multinomial with arbitrary number of dimensions, and number of candidates on the last dimension.

    Args:
        input (torch.Tensor): The input tensor containing probabilities.
        num_samples (int): Number of samples to draw.
        replacement (bool): Whether to draw with replacement or not.
    Keywords args:
        generator (torch.Generator): A pseudorandom number generator for sampling.
    Returns:
        torch.Tensor: Last dimension contains num_samples indices
            sampled from the multinomial probability distribution
            located in the last dimension of tensor input.
    """
    # torch.multinomial is called on a 2D view: all leading dimensions are
    # flattened into one, then restored on the output.
    input_ = input.reshape(-1, input.shape[-1])
    output_ = torch.multinomial(input_, num_samples=num_samples, replacement=replacement, generator=generator)
    output = output_.reshape(*list(input.shape[:-1]), -1)
    return output


def sample_top_k(probs: torch.Tensor, k: int) -> torch.Tensor:
    """Sample next token from top K values along the last dimension of the input probs tensor.

    Entries smaller than the k-th largest value are zeroed and the remaining ones are
    renormalized before sampling; entries tied with the k-th largest value are kept.
    The input tensor is left untouched.

    Args:
        probs (torch.Tensor): Input probabilities with token candidates on the last dimension.
        k (int): The k in “top-k”.
    Returns:
        torch.Tensor: Sampled tokens.
    """
    top_k_value, _ = torch.topk(probs, k, dim=-1)
    min_value_top_k = top_k_value[..., [-1]]
    # Computed out of place so the caller's tensor is not silently overwritten;
    # the mask is cast to the dtype of `probs` so the result keeps that dtype.
    kept = probs * (probs >= min_value_top_k).to(probs.dtype)
    kept = kept / kept.sum(dim=-1, keepdim=True)
    next_token = multinomial(kept, num_samples=1)
    return next_token
class DummyPoolExecutor:
    """Stand-in for a pool executor when only one worker is needed
    (e.g. in place of ProcessPoolExecutor). Nothing is run in another process.
    """

    class DummyResult:
        """Deferred call: keeps a callable with its arguments and runs it in `result()`."""

        def __init__(self, func, *args, **kwargs):
            self.func, self.args, self.kwargs = func, args, kwargs

        def result(self):
            """Call the stored function with the stored arguments and return its value."""
            target = self.func
            return target(*self.args, **self.kwargs)

    def __init__(self, workers, mp_context=None):
        """Accept the pool-executor arguments; both are ignored."""

    def submit(self, func, *args, **kwargs):
        """Wrap the call in a `DummyResult`; the function is not executed here."""
        deferred = DummyPoolExecutor.DummyResult(func, *args, **kwargs)
        return deferred

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, exc_tb):
        # Nothing to release; returning None lets any exception propagate.
        return None


def get_pool_executor(num_workers: int, mp_context=None):
    """Return a `ProcessPoolExecutor` when more than one worker is requested,
    otherwise a `DummyPoolExecutor`.

    Args:
        num_workers (int): Number of workers requested.
        mp_context: Forwarded to `ProcessPoolExecutor`; unused for the dummy executor.
    """
    if not num_workers > 1:
        return DummyPoolExecutor(1)
    return ProcessPoolExecutor(num_workers, mp_context)
def hash_trick(word: str, vocab_size: int) -> int:
    """Hash trick to pair each word with an index.

    The index is the SHA-256 digest of the UTF-8 encoded word, read as an
    integer, modulo `vocab_size`.

    Args:
        word (str): word we wish to convert to an index
        vocab_size (int): size of the vocabulary
    Returns:
        int: index of the word in the embedding LUT
    """
    hex_digest = hashlib.sha256(word.encode("utf-8")).hexdigest()
    return int(hex_digest, 16) % vocab_size


def with_rank_rng(base_seed: int = 1234):
    """Decorator for a function so that the function will use a Random Number Generator
    whose state depend on the GPU rank. The original RNG state is restored upon returning.

    The seed is `base_seed` XOR the rank reported by `flashy.distrib.rank()`.

    Args:
        base_seed (int): Random seed.
    """
    # NOTE(review): only the state returned by torch.get_rng_state() is saved and
    # restored; whether other generators touched by torch.manual_seed need the same
    # treatment should be confirmed.
    def _wrap(fun: tp.Callable):
        @wraps(fun)
        def _seeded_call(*args, **kwargs):
            saved_state = torch.get_rng_state()
            rank_seed = base_seed ^ flashy.distrib.rank()
            torch.manual_seed(rank_seed)
            logger.debug('Rank dependent seed set to %d', rank_seed)
            try:
                return fun(*args, **kwargs)
            finally:
                # Runs on both normal return and exception.
                torch.set_rng_state(saved_state)
                logger.debug('RNG state restored.')
        return _seeded_call
    return _wrap
# TODO: Move to flashy?
def copy_state(state: tp.Any, device: tp.Union[torch.device, str] = 'cpu',
               dtype: tp.Optional[torch.dtype] = None) -> tp.Any:
    """Recursively copy a state (tensor, or dicts / lists of tensors) to a device.

    Args:
        state: Tensor, dict or list to copy. Dicts and lists are rebuilt recursively.
        device: Device on which the tensor copies are created.
        dtype: Optional dtype for the copies. It is only applied to floating point
            tensors; other tensors keep their own dtype.
    Returns:
        A structure with the same layout where every tensor is a detached copy.
        Any other leaf (int, float, str, None, ...) is returned as-is, not copied.
    """
    if isinstance(state, torch.Tensor):
        if dtype is None or not state.is_floating_point():
            dtype = state.dtype
        return state.detach().to(device=device, dtype=dtype, copy=True)
    elif isinstance(state, dict):
        return {k: copy_state(v, device, dtype) for k, v in state.items()}
    elif isinstance(state, list):
        return [copy_state(v, device, dtype) for v in state]
    # Non-tensor leaves used to fall through and come back as None, silently
    # dropping values such as step counters; hand them back unchanged instead.
    return state


# TODO: Move to flashy?
@contextmanager
def swap_state(model, state, **kwargs):
    """Context manager that temporarily loads `state` into `model`.

    A copy of the current `model.state_dict()` is taken first, `state` is loaded
    (extra keyword arguments go to `model.load_state_dict`), and the saved copy is
    loaded back when the context exits, even if the body raised.
    """
    old_state = copy_state(model.state_dict())
    model.load_state_dict(state, **kwargs)
    try:
        yield
    finally:
        model.load_state_dict(old_state)


@lru_cache(None)
def warn_once(logger, msg):
    """Warn about a given message only once."""
    # The unbounded cache keys on (logger, msg): a repeated pair never reaches this line.
    logger.warning(msg)


def is_jsonable(x: tp.Any):
    """Check if an object can be serialized into a json.

    Returns:
        bool: True when `json.dumps(x)` succeeds, False when it raises TypeError,
        OverflowError or ValueError (the latter is raised e.g. for circular references).
    """
    try:
        json.dumps(x)
    except (TypeError, OverflowError, ValueError):
        return False
    return True
See: https://github.com/LAION-AI/CLAP/issues/118 """ from clap_module.factory import load_state_dict # type: ignore pkg = load_state_dict(path) pkg.pop('text_branch.embeddings.position_ids', None) clap_model.model.load_state_dict(pkg) ================================================ FILE: config/conditioner/chroma2music.yaml ================================================ # @package __global__ classifier_free_guidance: training_dropout: 0.2 inference_coef: 3.0 attribute_dropout: args: active_on_eval: false text: {} wav: self_wav: 0.5 fuser: cross_attention_pos_emb: false cross_attention_pos_emb_scale: 1 sum: [] prepend: [self_wav, description] cross: [] input_interpolate: [] conditioners: self_wav: model: chroma_stem chroma_stem: sample_rate: ${sample_rate} n_chroma: 12 radix2_exp: 14 argmax: true match_len_on_eval: false eval_wavs: null n_eval_wavs: 100 cache_path: null description: model: t5 t5: name: t5-base finetune: false word_dropout: 0.2 normalize_text: false dataset: train: merge_text_p: 0.25 drop_desc_p: 0.5 drop_other_p: 0.5 ================================================ FILE: config/conditioner/clapemb2music.yaml ================================================ # @package __global__ classifier_free_guidance: training_dropout: 0.3 inference_coef: 3.0 attribute_dropout: text: {} wav: {} fuser: cross_attention_pos_emb: false cross_attention_pos_emb_scale: 1 sum: [] prepend: [] cross: [description] input_interpolate: [] conditioners: description: model: clap clap: checkpoint: //reference/clap/music_audioset_epoch_15_esc_90.14.pt model_arch: 'HTSAT-base' enable_fusion: false sample_rate: 44100 max_audio_length: 10 audio_stride: 1 dim: 512 attribute: description normalize: true quantize: true # use RVQ quantization n_q: 12 bins: 1024 kmeans_iters: 50 text_p: 0. 
# probability of using text embed at train time cache_path: null dataset: joint_embed_attributes: [description] train: merge_text_p: 0.25 drop_desc_p: 0.5 drop_other_p: 0.5 ================================================ FILE: config/conditioner/none.yaml ================================================ # @package __global__ # No conditioning classifier_free_guidance: training_dropout: 0 inference_coef: 1 attribute_dropout: text: {} wav: {} fuser: sum: [] prepend: [] cross: [] input_interpolate: [] conditioners: null ================================================ FILE: config/conditioner/text2music.yaml ================================================ # @package __global__ classifier_free_guidance: training_dropout: 0.3 inference_coef: 3.0 attribute_dropout: {} fuser: cross_attention_pos_emb: false cross_attention_pos_emb_scale: 1 sum: [] prepend: [] cross: [description] input_interpolate: [] conditioners: description: model: t5 t5: name: t5-base finetune: false word_dropout: 0.3 normalize_text: false dataset: train: merge_text_p: 0.25 drop_desc_p: 0.5 drop_other_p: 0.5 ================================================ FILE: config/conditioner/text2sound.yaml ================================================ # @package __global__ classifier_free_guidance: training_dropout: 0.1 inference_coef: 3.0 attribute_dropout: {} fuser: cross_attention_pos_emb: false cross_attention_pos_emb_scale: 1 sum: [] prepend: [] cross: [description] input_interpolate: [] conditioners: description: model: t5 t5: name: t5-large finetune: false word_dropout: 0. normalize_text: false ================================================ FILE: config/config.yaml ================================================ # WARNING: This is the base configuration file shared across ALL solvers in AudioCraft # Please don't update this file directly. Instead use distinct configuration files # to override the below configuration. 
defaults: - _self_ - dset: default - solver: default device: cuda dtype: float32 autocast: false autocast_dtype: bfloat16 seed: 2036 show: false # just show the model and its size and exit continue_from: # continue from a given sig or path execute_only: # can be set to generate/evaluate/valid to run that stage execute_inplace: false # don't enforce continue_from to be set # to enable inplace execution of the stage. This assume # that you know what you are doing and execute stage # preserving the original xp sig. benchmark_no_load: false # if set to true, will repeat the same batch instead of loading them efficient_attention_backend: torch # can be torch or xformers. num_threads: 1 # called with torch.set_num_thread. mp_start_method: forkserver # multiprocessing method (spawn, fork or fork_server). label: # use this if you want twice the same exp, with a name. # logging parameters logging: level: INFO log_updates: 10 log_tensorboard: false log_wandb: false tensorboard: with_media_logging: false name: # optional name for the experiment sub_dir: # optional sub directory to store tensorboard data wandb: with_media_logging: true project: # project name name: # optional name for the experiment group: # optional group # SLURM launcher configuration. slurm: gpus: 4 # convenience parameter, number of GPUs to use. mem_per_gpu: 40 # in GB, total mem is automatically scaled with `gpus`. time: 3600 constraint: partition: comment: setup: [] exclude: '' # dora parameters dora: # Output folder for all artifacts of an experiment. dir: /checkpoint/${oc.env:USER}/experiments/audiocraft/outputs # The following entries will be ignored by dora when computing the unique XP signature. # Note that slurm.* and dora.* are automatically ignored. 
exclude: [ 'device', 'wandb.*', 'tensorboard.*', 'logging.*', 'dataset.num_workers', 'eval.num_workers', 'special.*', 'metrics.visqol.bin', 'metrics.fad.bin', 'execute_only', 'execute_best', 'generate.every', 'optim.eager_sync', 'profiler.*', 'deadlock.*', 'efficient_attention_backend', 'num_threads', 'mp_start_method', ] use_rendezvous: false # for grids, always run from a clean repo, allowing reliable runs and storing # the exact commit. Your repo must be absolutely pristine clean. # Local `dora run` are not impacted for easier debugging. git_save: true ================================================ FILE: config/dset/audio/audiocaps_16khz.yaml ================================================ # @package __global__ # AudioCaps dataset datasource: max_sample_rate: 16000 max_channels: 1 train: null # only evaluation set valid: null # only evaluation set evaluate: egs/audiocaps/audiocaps_16khz generate: egs/audiocaps/audiocaps_16khz # identical to evaluate ================================================ FILE: config/dset/audio/default.yaml ================================================ # @package __global__ datasource: max_sample_rate: ??? max_channels: ??? train: ??? valid: ??? evaluate: ??? 
generate: null ================================================ FILE: config/dset/audio/example.yaml ================================================ # @package __global__ datasource: max_sample_rate: 44100 max_channels: 2 train: egs/example valid: egs/example evaluate: egs/example generate: egs/example ================================================ FILE: config/dset/audio/musiccaps_32khz.yaml ================================================ # @package __global__ # total samples obtained from MusicCaps = 5469 # (out of 5521 due to AudioSet corrupted samples) datasource: max_sample_rate: 32000 max_channels: 2 train: null # only evaluation set valid: null # only evaluation set evaluate: egs/musiccaps/musiccaps_32khz generate: egs/musiccaps/musiccaps_32khz # identical to evaluate ================================================ FILE: config/dset/default.yaml ================================================ # @package __global__ # WARNING: This is a base configuration file shared across ALL solvers in AudioCraft # Please don't update this file directly. Instead use distinct configuration files # to override the below configuration. datasource: train: ??? valid: ??? evaluate: ??? generate: ??? 
================================================ FILE: config/dset/internal/music_10k_32khz.yaml ================================================ # @package __global__ # high quality music dataset with no artist overlap between splits datasource: max_sample_rate: 32000 max_channels: 1 train: egs/music/music_10k_32khz/train valid: egs/music/music_10k_32khz/valid evaluate: egs/music/music_10k_32khz/test generate: egs/music/music_10k_32khz/test # identical to evaluate ================================================ FILE: config/dset/internal/music_400k_32khz.yaml ================================================ # @package __global__ datasource: max_sample_rate: 32000 max_channels: 1 train: egs/music/music_400k_32khz/train valid: egs/music/music_400k_32khz/valid evaluate: egs/music/music_400k_32khz/test generate: egs/music/music_400k_32khz/test # identical to evaluate ================================================ FILE: config/dset/internal/sounds_16khz.yaml ================================================ # @package __global__ # environmental sounds dataset compiling all datasets # with applied filters on tags datasource: max_sample_rate: 16000 max_channels: 1 train: egs/sound/sounds_16khz/train valid: egs/sound/sounds_16khz/valid evaluate: egs/sound/sounds_16khz/test generate: egs/sound/sounds_16khz/test # identical to evaluate ================================================ FILE: config/model/encodec/default.yaml ================================================ # @package __global__ compression_model: encodec encodec: autoencoder: seanet quantizer: rvq sample_rate: ${sample_rate} channels: ${channels} causal: false renormalize: false seanet: dimension: 128 channels: ${channels} causal: ${encodec.causal} n_filters: 32 n_residual_layers: 1 ratios: [8, 5, 4, 2] activation: ELU activation_params: {"alpha": 1.} norm: weight_norm norm_params: {} kernel_size: 7 residual_kernel_size: 3 last_kernel_size: 7 dilation_base: 2 pad_mode: constant true_skip: true compress: 2 
lstm: 2 disable_norm_outer_blocks: 0 # Specific encoder or decoder params. # You can also override any param for the encoder or decoder only # by using Hydra `+param=` syntax, i.e.` # `+seanet.decoder.n_filters=64`. decoder: trim_right_ratio: 1.0 final_activation: null final_activation_params: null encoder: {} rvq: n_q: 8 q_dropout: false bins: 1024 decay: 0.99 kmeans_init: true kmeans_iters: 50 threshold_ema_dead_code: 2 orthogonal_reg_weight: 0.0 orthogonal_reg_active_codes_only: false no_quant: {} ================================================ FILE: config/model/encodec/encodec_base_causal.yaml ================================================ # @package __global__ defaults: - encodec/default encodec: causal: true rvq: n_q: 32 q_dropout: true ================================================ FILE: config/model/encodec/encodec_large_nq4_s320.yaml ================================================ # @package __global__ defaults: - encodec/default seanet: # default ratios are [8, 5, 4, 2] n_filters: 64 rvq: bins: 2048 n_q: 4 q_dropout: false ================================================ FILE: config/model/encodec/encodec_large_nq4_s640.yaml ================================================ # @package __global__ defaults: - encodec/default seanet: ratios: [8, 5, 4, 4] n_filters: 64 rvq: bins: 2048 n_q: 4 q_dropout: false ================================================ FILE: config/model/lm/audiogen_lm.yaml ================================================ # @package __global__ defaults: - lm/default - override /conditioner: text2sound - override /model/lm/model_scale: small # prefer this group to set model scale instead of transformer_lm keys directly lm_model: transformer_lm codebooks_pattern: modeling: delay delay: delays: [0, 1, 2, 3] flatten_first: 0 empty_initial: 0 unroll: flattening: [0, 1, 2, 3] delays: [0, 0, 0, 0] music_lm: group_by: 2 valle: delays: [0, 0, 0] transformer_lm: n_q: 4 card: 2048 memory_efficient: true bias_proj: false bias_ff: false 
bias_attn: false norm_first: true layer_scale: null weight_init: gaussian depthwise_init: current zero_bias_init: true attention_as_float32: false ================================================ FILE: config/model/lm/default.yaml ================================================ # @package __global__ defaults: - _self_ - /model/lm/model_scale: base # prefer this group to set model scale instead of transformer_lm keys directly lm_model: transformer_lm codebooks_pattern: modeling: parallel transformer_lm: dim: 512 num_heads: 8 num_layers: 8 hidden_scale: 4 n_q: 8 # number of streams to model card: 1024 dropout: 0. emb_lr: null activation: gelu norm_first: false # use pre-norm instead of post-norm bias_ff: true # use bias for the feedforward bias_attn: true # use bias for the attention bias_proj: true # use bias for the output projections past_context: null causal: true custom: false # use custom MHA implementation memory_efficient: false # use flash attention attention_as_float32: false # use float32 for the attention part, # recommended at the moment when memory_efficient is True. layer_scale: null positional_embedding: sin # positional embedding strategy (sin, rope, or sin_rope). xpos: false # apply xpos decay (rope only). checkpointing: none # layer checkpointing method, can be none, torch, xformers_default. # torch is the slowest but uses the least memory, # xformers_default is somewhere in between. weight_init: null # weight initialization (null, gaussian or uniform) depthwise_init: null # perform depthwise initialization (null, current, global) zero_bias_init: false # initialize bias to zero if bias in linears and # if a weight_init method is used. norm: layer_norm # normalization method to use in transformer. cross_attention: false qk_layer_norm: false qk_layer_norm_cross: false attention_dropout: null kv_repeat: 1 two_step_cfg: false # whether to do true 2 steps CFG, potentially resolving some padding issues or not... 
================================================ FILE: config/model/lm/model_scale/base.yaml ================================================ # @package __global__ # overrides nothing because default is already transformer base (~ 60M params) ================================================ FILE: config/model/lm/model_scale/large.yaml ================================================ # @package _global_ # gpt2 inspired, even bigger (~3.3B params) transformer_lm: dim: 2048 num_heads: 32 num_layers: 48 ================================================ FILE: config/model/lm/model_scale/medium.yaml ================================================ # @package _global_ # gpt2 like (~1.5B params) transformer_lm: dim: 1536 num_heads: 24 num_layers: 48 ================================================ FILE: config/model/lm/model_scale/small.yaml ================================================ # @package _global_ # 300M Param. transformer_lm: dim: 1024 num_heads: 16 num_layers: 24 ================================================ FILE: config/model/lm/model_scale/xsmall.yaml ================================================ # @package _global_ # just used for debugging or when we just want to populate the cache # and do not care about training. 
transformer_lm: dim: 64 num_heads: 2 num_layers: 2 ================================================ FILE: config/model/lm/musicgen_lm.yaml ================================================ # @package __global__ defaults: - lm/default - override /conditioner: text2music - override /model/lm/model_scale: small # prefer this group to set model scale instead of transformer_lm keys directly lm_model: transformer_lm codebooks_pattern: modeling: delay delay: delays: [0, 1, 2, 3] flatten_first: 0 empty_initial: 0 unroll: flattening: [0, 1, 2, 3] delays: [0, 0, 0, 0] music_lm: group_by: 2 valle: delays: [0, 0, 0] transformer_lm: n_q: 4 card: 2048 memory_efficient: true bias_proj: false bias_ff: false bias_attn: false norm_first: true layer_scale: null weight_init: gaussian depthwise_init: current zero_bias_init: true attention_as_float32: false ================================================ FILE: config/model/none.yaml ================================================ # @package __global__ # This file exist so that model is recognized as a config group # by Hydra, and Dora. A bit weird we might need a better fix someday. ================================================ FILE: config/model/score/basic.yaml ================================================ # @package _global_ diffusion_unet: hidden: 48 depth: 4 res_blocks: 1 norm_groups: 4 kernel: 8 stride: 4 growth: 4 max_channels: 10_000 dropout: 0. 
emb_all_layers: true bilstm: false codec_dim: null transformer: false cross_attention: false ================================================ FILE: config/solver/audiogen/audiogen_base_16khz.yaml ================================================ # @package __global__ # This is the training loop solver # for the base AudioGen model (text-to-sound) # on monophonic audio sampled at 16 kHz # using a similar EnCodec+LM setup to MusicGen defaults: - audiogen/default - /model: lm/audiogen_lm - override /dset: audio/default - _self_ autocast: true autocast_dtype: float16 # EnCodec large trained on mono-channel music audio sampled at 16khz # with a total stride of 320 leading to 50 frames/s. # rvq.n_q=4, rvq.bins=2048, no quantization dropout # (transformer_lm card and n_q must be compatible) compression_model_checkpoint: //reference/bd44a852/checkpoint.th channels: 1 sample_rate: 16000 deadlock: use: true # deadlock detection dataset: batch_size: 128 # matching AudioGen paper setup (256 * mix_p=0.5 = 128) num_workers: 10 segment_duration: 10 min_segment_ratio: 1.0 sample_on_weight: false # Uniform sampling all the way sample_on_duration: false # Uniform sampling all the way external_metadata_source: null # sample mixing augmentation at train time train: batch_size: 256 # matching AudioGen paper setup aug_p: 0.5 # perform audio mixing 50% of the time mix_p: 0.5 # proportion of batch items mixed together # important: note that this will reduce the # actual batch size used at train time # which will be equal to mix_p * batch_size mix_snr_low: -5 mix_snr_high: 5 mix_min_overlap: 0.5 generate: lm: use_sampling: true top_k: 250 top_p: 0.0 optim: epochs: 100 optimizer: adamw lr: 5e-4 ema: use: true updates: 10 device: cuda logging: log_tensorboard: true schedule: lr_scheduler: inverse_sqrt inverse_sqrt: warmup: 3000 warmup_init_lr: 0.0 ================================================ FILE: config/solver/audiogen/debug.yaml ================================================ # 
@package __global__ # This is a minimal debugging configuration # for the AudioGen training solver defaults: - audiogen/default - /model: lm/audiogen_lm - override /model/lm/model_scale: xsmall - override /dset: audio/example - _self_ autocast: false compression_model_checkpoint: null codebooks_pattern: modeling: parallel channels: 1 sample_rate: 16000 deadlock: use: false # deadlock detection dataset: batch_size: 4 segment_duration: 5 sample_on_weight: false # Uniform sampling all the way sample_on_duration: false # Uniform sampling all the way generate: audio: strategy: peak lm: use_sampling: false top_k: 0 top_p: 0.0 checkpoint: save_every: 0 keep_last: 0 optim: epochs: 2 updates_per_epoch: 10 optimizer: adamw lr: 1e-4 logging: log_tensorboard: true schedule: lr_scheduler: null ================================================ FILE: config/solver/audiogen/default.yaml ================================================ # @package __global__ defaults: - /solver/musicgen/default - _self_ - /solver/audiogen/evaluation: none - override /dset: audio/default # See config/solver/musicgen/default.yaml for a list of possible values. # We only keep the most important here. autocast: true autocast_dtype: float16 solver: audiogen sample_rate: ??? channels: ??? compression_model_checkpoint: ??? tokens: padding_with_special_token: false dataset: batch_size: 128 segment_duration: 10 min_segment_ratio: 1.0 # lower values such as 0.5 result in generations with a lot of silence. 
optim: epochs: 100 updates_per_epoch: 2000 lr: 1e-4 optimizer: adamw max_norm: 1.0 adam: betas: [0.9, 0.95] weight_decay: 0.1 eps: 1e-8 schedule: lr_scheduler: null ================================================ FILE: config/solver/audiogen/evaluation/none.yaml ================================================ # @package __global__ dataset: evaluate: num_samples: 10000 ================================================ FILE: config/solver/audiogen/evaluation/objective_eval.yaml ================================================ # @package __global__ # Setup for execute only on audiocaps for audio generation # evaluation with objective metrics # execute_only=evaluate dataset: max_audio_duration: null # ensure the proper values are broadcast here for evaluate evaluate: min_audio_duration: 1. # some metrics require a minimum audio length max_audio_duration: null # all samples from audiocaps should be ~10s num_samples: null segment_duration: null generate: min_audio_duration: 1. max_audio_duration: null num_samples: 500 evaluate: metrics: fad: true kld: true text_consistency: true metrics: kld: passt: pretrained_length: 10 # similarly to reported results in AudioGen paper ================================================ FILE: config/solver/compression/debug.yaml ================================================ # @package __global__ defaults: - compression/default - /model: encodec/encodec_base_causal - override /dset: audio/example - _self_ channels: 1 sample_rate: 16000 # debug config uses just L1 losses: adv: 0. feat: 0. l1: 1. mel: 0. msspec: 0. # no balancer balancer: balance_grads: false ema_decay: 1. total_norm: 1. 
per_batch_item: false # no adversaries adversarial: adversaries: [] adv_loss: hinge feat_loss: l1 # faster model for local dev seanet: dimension: 16 n_filters: 4 # very small dataset dataset: batch_size: 8 num_workers: 10 num_samples: 100 segment_duration: 1 evaluate: batch_size: 32 generate: batch_size: 1 num_samples: 5 segment_duration: 10 # limited training evaluate: every: 5 generate: every: 5 optim: epochs: 50 ================================================ FILE: config/solver/compression/default.yaml ================================================ # @package __global__ defaults: - ../default - override /dset: audio/default - _self_ solver: compression sample_rate: ??? channels: ??? # loss balancing losses: adv: 4. feat: 4. l1: 0.1 mel: 0. msspec: 2. sisnr: 0. balancer: balance_grads: true ema_decay: 0.999 per_batch_item: true total_norm: 1. adversarial: every: 1 adversaries: [msstftd] adv_loss: hinge feat_loss: l1 # losses hyperparameters l1: {} l2: {} mrstft: factor_sc: .5 factor_mag: .5 normalized: false mel: sample_rate: ${sample_rate} n_fft: 1024 hop_length: 256 win_length: 1024 n_mels: 64 f_min: 64 f_max: null normalized: false floor_level: 1e-5 sisnr: sample_rate: ${sample_rate} segment: 5. 
msspec: sample_rate: ${sample_rate} range_start: 6 range_end: 11 n_mels: 64 f_min: 64 f_max: null normalized: true alphas: false floor_level: 1e-5 # metrics metrics: visqol: mode: audio bin: null # path to visqol install model: tcdaudio14_aacvopus_coresv_svrnsim_n.68_g.01_c1.model # visqol v3 # adversaries hyperparameters msstftd: in_channels: 1 out_channels: 1 filters: 32 norm: weight_norm n_ffts: [1024, 2048, 512, 256, 128] hop_lengths: [256, 512, 128, 64, 32] win_lengths: [1024, 2048, 512, 256, 128] activation: LeakyReLU activation_params: {negative_slope: 0.3} msd: in_channels: 1 out_channels: 1 scale_norms: [spectral_norm, weight_norm, weight_norm] kernel_sizes: [5, 3] filters: 16 max_filters: 1024 downsample_scales: [4, 4, 4, 4] inner_kernel_sizes: null groups: [4, 4, 4, 4] strides: null paddings: null activation: LeakyReLU activation_params: {negative_slope: 0.3} mpd: in_channels: 1 out_channels: 1 periods: [2, 3, 5, 7, 11] n_layers: 5 kernel_size: 5 stride: 3 filters: 8 filter_scales: 4 max_filters: 1024 activation: LeakyReLU activation_params: {negative_slope: 0.3} norm: weight_norm # data hyperparameters dataset: batch_size: 64 num_workers: 10 segment_duration: 1 train: num_samples: 500000 valid: num_samples: 10000 evaluate: batch_size: 32 num_samples: 10000 generate: batch_size: 32 num_samples: 50 segment_duration: 10 # solver hyperparameters evaluate: every: 25 num_workers: 5 metrics: visqol: false sisnr: true generate: every: 25 num_workers: 5 audio: sample_rate: ${sample_rate} # checkpointing schedule checkpoint: save_last: true save_every: 25 keep_last: 10 keep_every_states: null # optimization hyperparameters optim: epochs: 200 updates_per_epoch: 2000 lr: 3e-4 max_norm: 0. optimizer: adam adam: betas: [0.5, 0.9] weight_decay: 0. 
ema: use: true # whether to use EMA or not updates: 1 # update at every step device: ${device} # device for EMA, can be put on GPU if more frequent updates decay: 0.99 # EMA decay value, if null, no EMA is used ================================================ FILE: config/solver/compression/encodec_audiogen_16khz.yaml ================================================ # @package __global__ defaults: - compression/default - /model: encodec/encodec_large_nq4_s320 - override /dset: audio/default - _self_ channels: 1 sample_rate: 16000 ================================================ FILE: config/solver/compression/encodec_base_24khz.yaml ================================================ # @package __global__ defaults: - compression/default - /model: encodec/encodec_base_causal - override /dset: audio/default - _self_ channels: 1 sample_rate: 24000 ================================================ FILE: config/solver/compression/encodec_musicgen_32khz.yaml ================================================ # @package __global__ defaults: - compression/default - /model: encodec/encodec_large_nq4_s640 - override /dset: audio/default - _self_ channels: 1 sample_rate: 32000 ================================================ FILE: config/solver/default.yaml ================================================ # @package __global__ # WARNING: This is a base configuration file shared across ALL solvers in AudioCraft # Please don't update this file directly. Instead use distinct configuration files # to override the below configuration. solver: ??? fsdp: use: false # should we use FSDP. param_dtype: float16 # equivalent to autocast_dtype for FSDP. reduce_dtype: float32 # gradient averaging dtype, float32 will give max stability. buffer_dtype: float32 # dtype used for buffers, we don't have many buffers, so let's leave it. sharding_strategy: shard_grad_op # can be shard_grad_op or full_shard. # full_shard will use less memory but slower ?? per_block: true # If True, uses nested FSDP. 
profiler: enabled: false deadlock: use: false timeout: 600 dataset: batch_size: ??? num_workers: 10 segment_duration: null num_samples: null return_info: false shuffle: false sample_on_duration: true sample_on_weight: true min_segment_ratio: 0.5 train: num_samples: null shuffle: true shuffle_seed: 0 # if you want to sample the data differently. permutation_on_files: false valid: num_samples: null evaluate: num_samples: null generate: num_samples: null return_info: true checkpoint: save_last: true save_every: null keep_last: null keep_every_states: null generate: every: null path: 'samples' audio: format: 'mp3' strategy: 'clip' sample_rate: null lm: use_sampling: false temp: 1.0 top_k: 0 top_p: 0.0 evaluate: every: null num_workers: 5 truncate_audio: null fixed_generation_duration: null # in secs metrics: base: true # run default evaluation (e.g. like train/valid stage) optim: epochs: ??? updates_per_epoch: null lr: ??? optimizer: ??? adam: betas: [0.9, 0.999] weight_decay: 0. ema: use: false # whether to use EMA or not updates: ${optim.updates_per_epoch} # frequency of updates of the EMA device: cpu # device for EMA, can be put on GPU if more frequent updates decay: 0.99 # EMA decay value, if null, no EMA is used schedule: lr_scheduler: null step: step_size: null gamma: null exponential: lr_decay: null cosine: warmup: null lr_min_ratio: 0.0 cycle_length: 1.0 polynomial_decay: warmup: null zero_lr_warmup_steps: 0 end_lr: 0.0 power: 1 inverse_sqrt: warmup: null warmup_init_lr: 0.0 linear_warmup: warmup: null warmup_init_lr: 0.0 ================================================ FILE: config/solver/diffusion/debug.yaml ================================================ # @package __global__ defaults: - /solver/default - /model: score/basic - override /dset: audio/default - _self_ solver: diffusion sample_rate: 16000 channels: 1 compression_model_checkpoint: //sig/5091833e n_q: 2 # number of codebooks to keep dataset: batch_size: 8 num_workers: 10 segment_duration: 1 
train: num_samples: 100 valid: num_samples: 100 evaluate: batch_size: 8 num_samples: 10 generate: batch_size: 8 num_samples: 10 segment_duration: 10 loss: kind: mse norm_power: 0. valid: every: 1 evaluate: every: 5 num_workers: 5 metrics: visqol: false sisnr: false rvm: true generate: every: 5 num_workers: 5 audio: sample_rate: ${sample_rate} checkpoint: save_last: true save_every: 25 keep_last: 10 keep_every_states: null optim: epochs: 50 updates_per_epoch: 2000 lr: 2e-4 max_norm: 0 optimizer: adam adam: betas: [0.9, 0.999] weight_decay: 0. ema: use: true # whether to use EMA or not updates: 1 # update at every step device: ${device} # device for EMA, can be put on GPU if more frequent updates decay: 0.99 # EMA decay value, if null, no EMA is used processor: name: multi_band_processor use: false n_bands: 8 num_samples: 10_000 power_std: 1. resampling: use: false target_sr: 16000 filter: use: false n_bands: 4 idx_band: 0 cutoffs: null schedule: repartition: "power" variable_step_batch: true beta_t0: 1.0e-5 beta_t1: 2.9e-2 beta_exp: 7.5 num_steps: 1000 variance: 'beta' clip: 5. rescale: 1. n_bands: null noise_scale: 1.0 metrics: num_stage: 4 ================================================ FILE: config/solver/diffusion/default.yaml ================================================ # @package __global__ defaults: - /solver/default - /model: score/basic - override /dset: audio/default - _self_ solver: diffusion sample_rate: ??? channels: ??? compression_model_checkpoint: ??? n_q: ??? # number of codebooks to keep dataset: batch_size: 128 num_workers: 10 segment_duration: 1 train: num_samples: 500000 valid: num_samples: 10000 evaluate: batch_size: 16 num_samples: 10000 generate: batch_size: 32 num_samples: 50 segment_duration: 10 audio: sample_rate: ${sample_rate} loss: kind: mse norm_power: 0. 
valid: every: 1 evaluate: every: 20 num_workers: 5 metrics: visqol: false sisnr: false rvm: true generate: every: 25 num_workers: 5 checkpoint: save_last: true save_every: 25 keep_last: 10 keep_every_states: null optim: epochs: 20000 updates_per_epoch: 2000 lr: 2e-4 max_norm: 0 optimizer: adam adam: betas: [0.9, 0.999] weight_decay: 0. ema: use: true # whether to use EMA or not updates: 1 # update at every step device: ${device} # device for EMA, can be put on GPU if more frequent updates decay: 0.99 # EMA decay value, if null, no EMA is used processor: name: multi_band_processor use: false n_bands: 8 num_samples: 10_000 power_std: 1. resampling: use: false target_sr: 16000 filter: use: false n_bands: 4 idx_band: 0 cutoffs: null schedule: repartition: "power" variable_step_batch: true beta_t0: 1.0e-5 beta_t1: 2.9e-2 beta_exp: 7.5 num_steps: 1000 variance: 'beta' clip: 5. rescale: 1. n_bands: null noise_scale: 1.0 metrics: num_stage: 4 ================================================ FILE: config/solver/diffusion/encodec_24khz.yaml ================================================ # @package __global__ defaults: - diffusion/default - _self_ sample_rate: 24000 channels: 1 compression_model_checkpoint: //pretrained/facebook/encodec_24khz n_q: 4 # num quantizers, 3kbps ================================================ FILE: config/solver/musicgen/debug.yaml ================================================ # @package __global__ # This is a minimal debugging configuration # for MusicGen training solver defaults: - musicgen/default - /model: lm/musicgen_lm - override /model/lm/model_scale: xsmall - override /dset: audio/example - _self_ autocast: false compression_model_checkpoint: //pretrained/debug_compression_model transformer_lm: n_q: 4 card: 400 codebooks_pattern: modeling: parallel channels: 1 sample_rate: 32000 deadlock: use: false # deadlock detection dataset: batch_size: 4 segment_duration: 5 sample_on_weight: false # Uniform sampling all the way 
sample_on_duration: false # Uniform sampling all the way generate: audio: strategy: peak lm: use_sampling: false top_k: 0 top_p: 0.0 checkpoint: save_every: 0 keep_last: 0 optim: epochs: 2 updates_per_epoch: 10 optimizer: adamw lr: 1e-4 logging: log_tensorboard: true schedule: lr_scheduler: null ================================================ FILE: config/solver/musicgen/default.yaml ================================================ # @package __global__ defaults: - /solver/default - /conditioner: none - _self_ - /solver/musicgen/evaluation: none - override /dset: audio/default autocast: true autocast_dtype: float16 solver: musicgen sample_rate: ??? channels: ??? compression_model_checkpoint: ??? tokens: padding_with_special_token: false cache: path: write: false write_shard: 0 write_num_shards: 1 dataset: batch_size: 128 num_workers: 10 segment_duration: 30 min_segment_ratio: 0.8 # lower values such as 0.5 result in generations with a lot of silence. return_info: true train: num_samples: 1000000 # need a randomly large number here for AudioDataset valid: num_samples: 10000 generate: num_samples: 50 metrics: fad: use_gt: false model: tf tf: bin: null # path to local frechet_audio_distance code model_path: //reference/fad/vggish_model.ckpt kld: use_gt: false model: passt passt: pretrained_length: 20 text_consistency: use_gt: false model: clap clap: model_path: //reference/clap/music_audioset_epoch_15_esc_90.14.pt model_arch: 'HTSAT-base' enable_fusion: false chroma_cosine: use_gt: false model: chroma_base chroma_base: sample_rate: ${sample_rate} n_chroma: 12 radix2_exp: 14 argmax: true generate: every: 25 num_workers: 5 path: samples audio: format: wav strategy: loudness sample_rate: ${sample_rate} loudness_headroom_db: 14 lm: prompted_samples: true unprompted_samples: true gen_gt_samples: false prompt_duration: null # if not set, will use dataset.generate.segment_duration / 4 gen_duration: null # if not set, will use dataset.generate.segment_duration 
remove_prompts: false # generation params use_sampling: false temp: 1.0 top_k: 0 top_p: 0.0 evaluate: every: 25 num_workers: 5 metrics: base: false fad: false kld: false text_consistency: false chroma_cosine: false checkpoint: save_last: true save_every: 50 keep_last: 10 keep_every_states: null optim: epochs: 200 updates_per_epoch: 2000 lr: 1e-4 optimizer: adamw max_norm: 1.0 eager_sync: true adam: betas: [0.9, 0.95] weight_decay: 0.1 eps: 1e-8 schedule: lr_scheduler: null ================================================ FILE: config/solver/musicgen/evaluation/none.yaml ================================================ # @package __global__ dataset: evaluate: num_samples: 10000 ================================================ FILE: config/solver/musicgen/evaluation/objective_eval.yaml ================================================ # @package __global__ # Setup for execute only on musiccaps for audio generation # evaluation with objective metrics # execute_only=evaluate dataset: max_audio_duration: null # ensure the proper values are broadcast here for evaluate evaluate: min_audio_duration: 1. # some metrics require a minimum audio length max_audio_duration: null # all samples from musiccaps should be < 20s num_samples: null segment_duration: null generate: min_audio_duration: 1. max_audio_duration: null num_samples: 500 evaluate: metrics: fad: true kld: true text_consistency: true ================================================ FILE: config/solver/musicgen/musicgen_base_32khz.yaml ================================================ # @package __global__ # This is the training loop solver # for the base MusicGen model (text-to-music) # on monophonic audio sampled at 32 kHz defaults: - musicgen/default - /model: lm/musicgen_lm - override /dset: audio/default - _self_ autocast: true autocast_dtype: float16 # EnCodec large trained on mono-channel music audio sampled at 32khz # with a total stride of 640 leading to 50 frames/s. 
# rvq.n_q=4, rvq.bins=2048, no quantization dropout # (transformer_lm card and n_q must be compatible) compression_model_checkpoint: //pretrained/facebook/encodec_32khz channels: 1 sample_rate: 32000 deadlock: use: true # deadlock detection dataset: batch_size: 192 # 32 GPUs sample_on_weight: false # Uniform sampling all the way sample_on_duration: false # Uniform sampling all the way generate: lm: use_sampling: true top_k: 250 top_p: 0.0 optim: epochs: 500 optimizer: dadam lr: 1 ema: use: true updates: 10 device: cuda logging: log_tensorboard: true schedule: lr_scheduler: cosine cosine: warmup: 4000 lr_min_ratio: 0.0 cycle_length: 1.0 ================================================ FILE: config/solver/musicgen/musicgen_melody_32khz.yaml ================================================ # @package __global__ # This is the training loop solver # for the melody MusicGen model (text+chroma to music) # on monophonic audio sampled at 32 kHz defaults: - musicgen/default - /model: lm/musicgen_lm - override /conditioner: chroma2music - override /dset: audio/default - _self_ autocast: true autocast_dtype: float16 # EnCodec large trained on mono-channel music audio sampled at 32khz # with a total stride of 640 leading to 50 frames/s. 
# rvq.n_q=4, rvq.bins=2048, no quantization dropout # (transformer_lm card and n_q must be compatible) compression_model_checkpoint: //pretrained/facebook/encodec_32khz channels: 1 sample_rate: 32000 deadlock: use: true # deadlock detection dataset: batch_size: 192 # 32 GPUs sample_on_weight: false # Uniform sampling all the way sample_on_duration: false # Uniform sampling all the way generate: lm: use_sampling: true top_k: 250 top_p: 0.0 optim: epochs: 500 optimizer: dadam lr: 1 ema: use: true updates: 10 device: cuda logging: log_tensorboard: true schedule: lr_scheduler: cosine cosine: warmup: 4000 lr_min_ratio: 0.0 cycle_length: 1.0 ================================================ FILE: config/teams/default.yaml ================================================ default: dora_dir: /tmp/audiocraft_${oc.env:USER} partitions: global: debug team: debug reference_dir: /tmp darwin: # if we detect we are on a Mac, then most likely we are doing unit testing etc. dora_dir: /tmp/audiocraft_${oc.env:USER} partitions: global: debug team: debug reference_dir: /tmp ================================================ FILE: config/teams/labs.yaml ================================================ aws: dora_dir: /fsx-audio-craft-llm/${oc.env:USER}/experiments/audiocraft/outputs partitions: global: learnlab team: learnlab reference_dir: /fsx-audio-craft-llm/shared/audiocraft/reference dataset_mappers: "^/checkpoint/[a-z]+": "/fsx-audio-craft-llm" fair: dora_dir: /checkpoint/${oc.env:USER}/experiments/audiocraft/outputs partitions: global: learnlab team: learnlab reference_dir: /large_experiments/audiocraft/reference dataset_mappers: "^/datasets01/datasets01": "/datasets01" darwin: dora_dir: /tmp/audiocraft_${oc.env:USER} partitions: global: debug team: debug reference_dir: /tmp rsc: dora_dir: /checkpoint/audiocraft/${oc.env:USER}/experiments/audiocraft/outputs partitions: global: learn team: learn reference_dir: /checkpoint/audiocraft/shared/reference 
================================================ FILE: dataset/example/electro_1.json ================================================ {"key": "", "artist": "Voyager I", "sample_rate": 48000, "file_extension": "mp3", "description": "A cool song from Voyager.", "keywords": "bright, pulsing, cool", "duration": 15.0, "bpm": "", "genre": "electronic", "title": "Enracinement", "name": "electro_1", "instrument": "Mix", "moods": ["uplifting", "motivational"]} ================================================ FILE: dataset/example/electro_2.json ================================================ {"key": "", "artist": "Voyager I", "sample_rate": 44100, "file_extension": "mp3", "description": "This is an electronic song sending positive vibes.", "keywords": "", "duration": 20.0, "bpm": "", "genre": "electronic", "title": "Untitled song", "name": "electro_2", "instrument": "Mix", "moods": []} ================================================ FILE: demos/audiogen_demo.ipynb ================================================ { "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# AudioGen\n", "Welcome to AudioGen's demo jupyter notebook. Here you will find a series of self-contained examples of how to use AudioGen in different settings.\n", "\n", "First, we start by initializing AudioGen. For now, we provide only a medium sized model for AudioGen: `facebook/audiogen-medium` - 1.5B transformer decoder. \n", "\n", "**Important note:** This variant is different from the original AudioGen model presented at [\"AudioGen: Textually-guided audio generation\"](https://arxiv.org/abs/2209.15352) as the model architecture is similar to MusicGen with a smaller frame rate and multiple streams of tokens, allowing to reduce generation time." 
] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "from audiocraft.models import AudioGen\n", "\n", "model = AudioGen.get_pretrained('facebook/audiogen-medium')" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Next, let us configure the generation parameters. Specifically, you can control the following:\n", "* `use_sampling` (bool, optional): use sampling if True, else do argmax decoding. Defaults to True.\n", "* `top_k` (int, optional): top_k used for sampling. Defaults to 250.\n", "* `top_p` (float, optional): top_p used for sampling, when set to 0 top_k is used. Defaults to 0.0.\n", "* `temperature` (float, optional): softmax temperature parameter. Defaults to 1.0.\n", "* `duration` (float, optional): duration of the generated waveform. Defaults to 10.0.\n", "* `cfg_coef` (float, optional): coefficient used for classifier free guidance. Defaults to 3.0.\n", "\n", "When left unchanged, AudioGen will revert to its default parameters." 
] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "model.set_generation_params(\n", " use_sampling=True,\n", " top_k=250,\n", " duration=5\n", ")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Next, we can go ahead and start generating sound using one of the following modes:\n", "* Audio continuation using `model.generate_continuation`\n", "* Text-conditional samples using `model.generate`" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Audio Continuation" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "import math\n", "import torchaudio\n", "import torch\n", "from audiocraft.utils.notebook import display_audio\n", "\n", "def get_bip_bip(bip_duration=0.125, frequency=440,\n", " duration=0.5, sample_rate=16000, device=\"cuda\"):\n", " \"\"\"Generates a series of bip bip at the given frequency.\"\"\"\n", " t = torch.arange(\n", " int(duration * sample_rate), device=\"cuda\", dtype=torch.float) / sample_rate\n", " wav = torch.cos(2 * math.pi * 440 * t)[None]\n", " tp = (t % (2 * bip_duration)) / (2 * bip_duration)\n", " envelope = (tp >= 0.5).float()\n", " return wav * envelope" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# Here we use a synthetic signal to prompt the generated audio.\n", "res = model.generate_continuation(\n", " get_bip_bip(0.125).expand(2, -1, -1), \n", " 16000, ['Whistling with wind blowing', \n", " 'Typing on a typewriter'], \n", " progress=True)\n", "display_audio(res, 16000)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# You can also use any audio from a file. 
Make sure to trim the file if it is too long!\n", "prompt_waveform, prompt_sr = torchaudio.load(\"../assets/sirens_and_a_humming_engine_approach_and_pass.mp3\")\n", "prompt_duration = 2\n", "prompt_waveform = prompt_waveform[..., :int(prompt_duration * prompt_sr)]\n", "output = model.generate_continuation(prompt_waveform, prompt_sample_rate=prompt_sr, progress=True)\n", "display_audio(output, sample_rate=16000)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Text-conditional Generation" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "from audiocraft.utils.notebook import display_audio\n", "\n", "output = model.generate(\n", " descriptions=[\n", " 'Subway train blowing its horn',\n", " 'A cat meowing',\n", " ],\n", " progress=True\n", ")\n", "display_audio(output, sample_rate=16000)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [] } ], "metadata": { "kernelspec": { "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.9.7" } }, "nbformat": 4, "nbformat_minor": 2 } ================================================ FILE: demos/musicgen_app.py ================================================ # Copyright (c) Meta Platforms, Inc. and affiliates. # All rights reserved. # This source code is licensed under the license found in the # LICENSE file in the root directory of this source tree. # Updated to account for UI changes from https://github.com/rkfg/audiocraft/blob/long/app.py # also released under the MIT license. 
import argparse from concurrent.futures import ProcessPoolExecutor import os from pathlib import Path import subprocess as sp from tempfile import NamedTemporaryFile import time import typing as tp import warnings import torch import gradio as gr from audiocraft.data.audio_utils import convert_audio from audiocraft.data.audio import audio_write from audiocraft.models import MusicGen, MultiBandDiffusion MODEL = None # Last used model IS_BATCHED = "facebook/MusicGen" in os.environ.get('SPACE_ID', '') print(IS_BATCHED) MAX_BATCH_SIZE = 12 BATCHED_DURATION = 15 INTERRUPTING = False MBD = None # We have to wrap subprocess call to clean a bit the log when using gr.make_waveform _old_call = sp.call def _call_nostderr(*args, **kwargs): # Avoid ffmpeg vomiting on the logs. kwargs['stderr'] = sp.DEVNULL kwargs['stdout'] = sp.DEVNULL _old_call(*args, **kwargs) sp.call = _call_nostderr # Preallocating the pool of processes. pool = ProcessPoolExecutor(4) pool.__enter__() def interrupt(): global INTERRUPTING INTERRUPTING = True class FileCleaner: def __init__(self, file_lifetime: float = 3600): self.file_lifetime = file_lifetime self.files = [] def add(self, path: tp.Union[str, Path]): self._cleanup() self.files.append((time.time(), Path(path))) def _cleanup(self): now = time.time() for time_added, path in list(self.files): if now - time_added > self.file_lifetime: if path.exists(): path.unlink() self.files.pop(0) else: break file_cleaner = FileCleaner() def make_waveform(*args, **kwargs): # Further remove some warnings. 
be = time.time() with warnings.catch_warnings(): warnings.simplefilter('ignore') out = gr.make_waveform(*args, **kwargs) print("Make a video took", time.time() - be) return out def load_model(version='facebook/musicgen-melody'): global MODEL print("Loading model", version) if MODEL is None or MODEL.name != version: MODEL = MusicGen.get_pretrained(version) def load_diffusion(): global MBD if MBD is None: print("loading MBD") MBD = MultiBandDiffusion.get_mbd_musicgen() def _do_predictions(texts, melodies, duration, progress=False, **gen_kwargs): MODEL.set_generation_params(duration=duration, **gen_kwargs) print("new batch", len(texts), texts, [None if m is None else (m[0], m[1].shape) for m in melodies]) be = time.time() processed_melodies = [] target_sr = 32000 target_ac = 1 for melody in melodies: if melody is None: processed_melodies.append(None) else: sr, melody = melody[0], torch.from_numpy(melody[1]).to(MODEL.device).float().t() if melody.dim() == 1: melody = melody[None] melody = melody[..., :int(sr * duration)] melody = convert_audio(melody, sr, target_sr, target_ac) processed_melodies.append(melody) if any(m is not None for m in processed_melodies): outputs = MODEL.generate_with_chroma( descriptions=texts, melody_wavs=processed_melodies, melody_sample_rate=target_sr, progress=progress, return_tokens=USE_DIFFUSION ) else: outputs = MODEL.generate(texts, progress=progress, return_tokens=USE_DIFFUSION) if USE_DIFFUSION: outputs_diffusion = MBD.tokens_to_wav(outputs[1]) outputs = torch.cat([outputs[0], outputs_diffusion], dim=0) outputs = outputs.detach().cpu().float() pending_videos = [] out_wavs = [] for output in outputs: with NamedTemporaryFile("wb", suffix=".wav", delete=False) as file: audio_write( file.name, output, MODEL.sample_rate, strategy="loudness", loudness_headroom_db=16, loudness_compressor=True, add_suffix=False) pending_videos.append(pool.submit(make_waveform, file.name)) out_wavs.append(file.name) file_cleaner.add(file.name) out_videos = 
[pending_video.result() for pending_video in pending_videos] for video in out_videos: file_cleaner.add(video) print("batch finished", len(texts), time.time() - be) print("Tempfiles currently stored: ", len(file_cleaner.files)) return out_videos, out_wavs def predict_batched(texts, melodies): max_text_length = 512 texts = [text[:max_text_length] for text in texts] load_model('facebook/musicgen-melody') res = _do_predictions(texts, melodies, BATCHED_DURATION) return res def predict_full(model, decoder, text, melody, duration, topk, topp, temperature, cfg_coef, progress=gr.Progress()): global INTERRUPTING global USE_DIFFUSION INTERRUPTING = False if temperature < 0: raise gr.Error("Temperature must be >= 0.") if topk < 0: raise gr.Error("Topk must be non-negative.") if topp < 0: raise gr.Error("Topp must be non-negative.") topk = int(topk) if decoder == "MultiBand_Diffusion": USE_DIFFUSION = True load_diffusion() else: USE_DIFFUSION = False load_model(model) def _progress(generated, to_generate): progress((min(generated, to_generate), to_generate)) if INTERRUPTING: raise gr.Error("Interrupted.") MODEL.set_custom_progress_callback(_progress) videos, wavs = _do_predictions( [text], [melody], duration, progress=True, top_k=topk, top_p=topp, temperature=temperature, cfg_coef=cfg_coef) if USE_DIFFUSION: return videos[0], wavs[0], videos[1], wavs[1] return videos[0], wavs[0], None, None def toggle_audio_src(choice): if choice == "mic": return gr.update(source="microphone", value=None, label="Microphone") else: return gr.update(source="upload", value=None, label="File") def toggle_diffusion(choice): if choice == "MultiBand_Diffusion": return [gr.update(visible=True)] * 2 else: return [gr.update(visible=False)] * 2 def ui_full(launch_kwargs): with gr.Blocks() as interface: gr.Markdown( """ # MusicGen This is your private demo for [MusicGen](https://github.com/facebookresearch/audiocraft), a simple and controllable model for music generation presented at: ["Simple and 
Controllable Music Generation"](https://huggingface.co/papers/2306.05284) """ ) with gr.Row(): with gr.Column(): with gr.Row(): text = gr.Text(label="Input Text", interactive=True) with gr.Column(): radio = gr.Radio(["file", "mic"], value="file", label="Condition on a melody (optional) File or Mic") melody = gr.Audio(source="upload", type="numpy", label="File", interactive=True, elem_id="melody-input") with gr.Row(): submit = gr.Button("Submit") # Adapted from https://github.com/rkfg/audiocraft/blob/long/app.py, MIT license. _ = gr.Button("Interrupt").click(fn=interrupt, queue=False) with gr.Row(): model = gr.Radio(["facebook/musicgen-melody", "facebook/musicgen-medium", "facebook/musicgen-small", "facebook/musicgen-large"], label="Model", value="facebook/musicgen-melody", interactive=True) with gr.Row(): decoder = gr.Radio(["Default", "MultiBand_Diffusion"], label="Decoder", value="Default", interactive=True) with gr.Row(): duration = gr.Slider(minimum=1, maximum=120, value=10, label="Duration", interactive=True) with gr.Row(): topk = gr.Number(label="Top-k", value=250, interactive=True) topp = gr.Number(label="Top-p", value=0, interactive=True) temperature = gr.Number(label="Temperature", value=1.0, interactive=True) cfg_coef = gr.Number(label="Classifier Free Guidance", value=3.0, interactive=True) with gr.Column(): output = gr.Video(label="Generated Music") audio_output = gr.Audio(label="Generated Music (wav)", type='filepath') diffusion_output = gr.Video(label="MultiBand Diffusion Decoder") audio_diffusion = gr.Audio(label="MultiBand Diffusion Decoder (wav)", type='filepath') submit.click(toggle_diffusion, decoder, [diffusion_output, audio_diffusion], queue=False, show_progress=False).then(predict_full, inputs=[model, decoder, text, melody, duration, topk, topp, temperature, cfg_coef], outputs=[output, audio_output, diffusion_output, audio_diffusion]) radio.change(toggle_audio_src, radio, [melody], queue=False, show_progress=False) gr.Examples( 
fn=predict_full, examples=[ [ "An 80s driving pop song with heavy drums and synth pads in the background", "./assets/bach.mp3", "facebook/musicgen-melody", "Default" ], [ "A cheerful country song with acoustic guitars", "./assets/bolero_ravel.mp3", "facebook/musicgen-melody", "Default" ], [ "90s rock song with electric guitar and heavy drums", None, "facebook/musicgen-medium", "Default" ], [ "a light and cheerly EDM track, with syncopated drums, aery pads, and strong emotions", "./assets/bach.mp3", "facebook/musicgen-melody", "Default" ], [ "lofi slow bpm electro chill with organic samples", None, "facebook/musicgen-medium", "Default" ], [ "Punk rock with loud drum and power guitar", None, "facebook/musicgen-medium", "MultiBand_Diffusion" ], ], inputs=[text, melody, model, decoder], outputs=[output] ) gr.Markdown( """ ### More details The model will generate a short music extract based on the description you provided. The model can generate up to 30 seconds of audio in one pass. It is now possible to extend the generation by feeding back the end of the previous chunk of audio. This can take a long time, and the model might lose consistency. The model might also decide at arbitrary positions that the song ends. **WARNING:** Choosing long durations will take a long time to generate (2min might take ~10min). An overlap of 12 seconds is kept with the previously generated chunk, and 18 "new" seconds are generated each time. We present 4 model variations: 1. facebook/musicgen-melody -- a music generation model capable of generating music condition on text and melody inputs. **Note**, you can also use text only. 2. facebook/musicgen-small -- a 300M transformer decoder conditioned on text only. 3. facebook/musicgen-medium -- a 1.5B transformer decoder conditioned on text only. 4. facebook/musicgen-large -- a 3.3B transformer decoder conditioned on text only. We also present two way of decoding the audio tokens 1. Use the default GAN based compression model 2. 
Use MultiBand Diffusion from (paper linknano ) When using `facebook/musicgen-melody`, you can optionally provide a reference audio from which a broad melody will be extracted. The model will then try to follow both the description and melody provided. You can also use your own GPU or a Google Colab by following the instructions on our repo. See [github.com/facebookresearch/audiocraft](https://github.com/facebookresearch/audiocraft) for more details. """ ) interface.queue().launch(**launch_kwargs) def ui_batched(launch_kwargs): with gr.Blocks() as demo: gr.Markdown( """ # MusicGen This is the demo for [MusicGen](https://github.com/facebookresearch/audiocraft), a simple and controllable model for music generation presented at: ["Simple and Controllable Music Generation"](https://huggingface.co/papers/2306.05284).
Duplicate Space for longer sequences, more control and no queue.

""" ) with gr.Row(): with gr.Column(): with gr.Row(): text = gr.Text(label="Describe your music", lines=2, interactive=True) with gr.Column(): radio = gr.Radio(["file", "mic"], value="file", label="Condition on a melody (optional) File or Mic") melody = gr.Audio(source="upload", type="numpy", label="File", interactive=True, elem_id="melody-input") with gr.Row(): submit = gr.Button("Generate") with gr.Column(): output = gr.Video(label="Generated Music") audio_output = gr.Audio(label="Generated Music (wav)", type='filepath') submit.click(predict_batched, inputs=[text, melody], outputs=[output, audio_output], batch=True, max_batch_size=MAX_BATCH_SIZE) radio.change(toggle_audio_src, radio, [melody], queue=False, show_progress=False) gr.Examples( fn=predict_batched, examples=[ [ "An 80s driving pop song with heavy drums and synth pads in the background", "./assets/bach.mp3", ], [ "A cheerful country song with acoustic guitars", "./assets/bolero_ravel.mp3", ], [ "90s rock song with electric guitar and heavy drums", None, ], [ "a light and cheerly EDM track, with syncopated drums, aery pads, and strong emotions bpm: 130", "./assets/bach.mp3", ], [ "lofi slow bpm electro chill with organic samples", None, ], ], inputs=[text, melody], outputs=[output] ) gr.Markdown(""" ### More details The model will generate 12 seconds of audio based on the description you provided. You can optionally provide a reference audio from which a broad melody will be extracted. The model will then try to follow both the description and melody provided. All samples are generated with the `melody` model. You can also use your own GPU or a Google Colab by following the instructions on our repo. See [github.com/facebookresearch/audiocraft](https://github.com/facebookresearch/audiocraft) for more details. 
""") demo.queue(max_size=8 * 4).launch(**launch_kwargs) if __name__ == "__main__": parser = argparse.ArgumentParser() parser.add_argument( '--listen', type=str, default='0.0.0.0' if 'SPACE_ID' in os.environ else '127.0.0.1', help='IP to listen on for connections to Gradio', ) parser.add_argument( '--username', type=str, default='', help='Username for authentication' ) parser.add_argument( '--password', type=str, default='', help='Password for authentication' ) parser.add_argument( '--server_port', type=int, default=0, help='Port to run the server listener on', ) parser.add_argument( '--inbrowser', action='store_true', help='Open in browser' ) parser.add_argument( '--share', action='store_true', help='Share the gradio UI' ) args = parser.parse_args() launch_kwargs = {} launch_kwargs['server_name'] = args.listen if args.username and args.password: launch_kwargs['auth'] = (args.username, args.password) if args.server_port: launch_kwargs['server_port'] = args.server_port if args.inbrowser: launch_kwargs['inbrowser'] = args.inbrowser if args.share: launch_kwargs['share'] = args.share # Show the interface if IS_BATCHED: global USE_DIFFUSION USE_DIFFUSION = False ui_batched(launch_kwargs) else: ui_full(launch_kwargs) ================================================ FILE: demos/musicgen_demo.ipynb ================================================ { "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# MusicGen\n", "Welcome to MusicGen's demo jupyter notebook. Here you will find a series of self-contained examples of how to use MusicGen in different settings.\n", "\n", "First, we start by initializing MusicGen, you can choose a model from the following selection:\n", "1. `facebook/musicgen-small` - 300M transformer decoder.\n", "2. `facebook/musicgen-medium` - 1.5B transformer decoder.\n", "3. `facebook/musicgen-melody` - 1.5B transformer decoder also supporting melody conditioning.\n", "4. 
`facebook/musicgen-large` - 3.3B transformer decoder.\n", "\n", "We will use the `facebook/musicgen-small` variant for the purpose of this demonstration." ] }, { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [], "source": [ "from audiocraft.models import MusicGen\n", "from audiocraft.models import MultiBandDiffusion\n", "import torch\n", "USE_DIFFUSION_DECODER = False\n", "# Using small model, better results would be obtained with `medium` or `large`.\n", "model = MusicGen.get_pretrained('facebook/musicgen-small')\n", "if USE_DIFFUSION_DECODER:\n", " mbd = MultiBandDiffusion.get_mbd_musicgen()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Next, let us configure the generation parameters. Specifically, you can control the following:\n", "* `use_sampling` (bool, optional): use sampling if True, else do argmax decoding. Defaults to True.\n", "* `top_k` (int, optional): top_k used for sampling. Defaults to 250.\n", "* `top_p` (float, optional): top_p used for sampling, when set to 0 top_k is used. Defaults to 0.0.\n", "* `temperature` (float, optional): softmax temperature parameter. Defaults to 1.0.\n", "* `duration` (float, optional): duration of the generated waveform. Defaults to 30.0.\n", "* `cfg_coef` (float, optional): coefficient used for classifier free guidance. Defaults to 3.0.\n", "\n", "When left unchanged, MusicGen will revert to its default parameters." 
] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "model.set_generation_params(\n", " use_sampling=True,\n", " top_k=250,\n", " duration=30\n", ")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Next, we can go ahead and start generating music using one of the following modes:\n", "* Unconditional samples using `model.generate_unconditional`\n", "* Music continuation using `model.generate_continuation`\n", "* Text-conditional samples using `model.generate`\n", "* Melody-conditional samples using `model.generate_with_chroma`" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Music Continuation" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "import math\n", "import torchaudio\n", "import torch\n", "from audiocraft.utils.notebook import display_audio\n", "\n", "def get_bip_bip(bip_duration=0.125, frequency=440,\n", " duration=0.5, sample_rate=32000, device=\"cuda\"):\n", " \"\"\"Generates a series of bip bip at the given frequency.\"\"\"\n", " t = torch.arange(\n", " int(duration * sample_rate), device=device, dtype=torch.float) / sample_rate\n", " wav = torch.cos(2 * math.pi * frequency * t)[None]\n", " tp = (t % (2 * bip_duration)) / (2 * bip_duration)\n", " envelope = (tp >= 0.5).float()\n", " return wav * envelope" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# Here we use a synthetic signal to prompt both the tonality and the BPM\n", "# of the generated audio.\n", "res = model.generate_continuation(\n", " get_bip_bip(0.125).expand(2, -1, -1), \n", " 32000, ['Jazz jazz and only jazz', \n", " 'Heartful EDM with beautiful synths and chords'], \n", " progress=True)\n", "display_audio(res, 32000)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# You can also use any audio from a file. 
Make sure to trim the file if it is too long!\n", "prompt_waveform, prompt_sr = torchaudio.load(\"../assets/bach.mp3\")\n", "prompt_duration = 2\n", "prompt_waveform = prompt_waveform[..., :int(prompt_duration * prompt_sr)]\n", "output = model.generate_continuation(prompt_waveform, prompt_sample_rate=prompt_sr, progress=True, return_tokens=True)\n", "display_audio(output[0], sample_rate=32000)\n", "if USE_DIFFUSION_DECODER:\n", " out_diffusion = mbd.tokens_to_wav(output[1])\n", " display_audio(out_diffusion, sample_rate=32000)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Text-conditional Generation" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "from audiocraft.utils.notebook import display_audio\n", "\n", "output = model.generate(\n", " descriptions=[\n", " #'80s pop track with bassy drums and synth',\n", " #'90s rock song with loud guitars and heavy drums',\n", " #'Progressive rock drum and bass solo',\n", " #'Punk Rock song with loud drum and power guitar',\n", " #'Bluesy guitar instrumental with soulful licks and a driving rhythm section',\n", " #'Jazz Funk song with slap bass and powerful saxophone',\n", " 'drum and bass beat with intense percussions'\n", " ],\n", " progress=True, return_tokens=True\n", ")\n", "display_audio(output[0], sample_rate=32000)\n", "if USE_DIFFUSION_DECODER:\n", " out_diffusion = mbd.tokens_to_wav(output[1])\n", " display_audio(out_diffusion, sample_rate=32000)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Melody-conditional Generation" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "import torchaudio\n", "from audiocraft.utils.notebook import display_audio\n", "\n", "model = MusicGen.get_pretrained('facebook/musicgen-melody')\n", "model.set_generation_params(duration=8)\n", "\n", "melody_waveform, sr = torchaudio.load(\"../assets/bach.mp3\")\n", "melody_waveform = 
melody_waveform.unsqueeze(0).repeat(2, 1, 1)\n", "output = model.generate_with_chroma(\n", " descriptions=[\n", " '80s pop track with bassy drums and synth',\n", " '90s rock song with loud guitars and heavy drums',\n", " ],\n", " melody_wavs=melody_waveform,\n", " melody_sample_rate=sr,\n", " progress=True, return_tokens=True\n", ")\n", "display_audio(output[0], sample_rate=32000)\n", "if USE_DIFFUSION_DECODER:\n", " out_diffusion = mbd.tokens_to_wav(output[1])\n", " display_audio(out_diffusion, sample_rate=32000)" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.9.16" }, "vscode": { "interpreter": { "hash": "b02c911f9b3627d505ea4a19966a915ef21f28afb50dbf6b2115072d27c69103" } } }, "nbformat": 4, "nbformat_minor": 2 } ================================================ FILE: dockerignore ================================================ cache/ ================================================ FILE: docs/AUDIOGEN.md ================================================ # AudioGen: Textually-guided audio generation AudioCraft provides the code and a model re-implementing AudioGen, a [textually-guided audio generation][audiogen_arxiv] model that performs text-to-sound generation. The provided AudioGen reimplementation follows the LM model architecture introduced in [MusicGen][musicgen_arxiv] and is a single stage auto-regressive Transformer model trained over a 16kHz EnCodec tokenizer with 4 codebooks sampled at 50 Hz. This model variant reaches similar audio quality than the original implementation introduced in the AudioGen publication while providing faster generation speed given the smaller frame rate. 
**Important note:** The provided models are NOT the original models used to report numbers in the [AudioGen publication][audiogen_arxiv]. Refer to the model card to learn more about architectural changes. Listen to samples from the **original AudioGen implementation** in our [sample page][audiogen_samples]. ## Model Card See [the model card](../model_cards/AUDIOGEN_MODEL_CARD.md). ## Installation Please follow the AudioCraft installation instructions from the [README](../README.md). AudioCraft requires a GPU with at least 16 GB of memory for running inference with the medium-sized models (~1.5B parameters). ## API and usage We provide a simple API and one pre-trained model for AudioGen: `facebook/audiogen-medium`: 1.5B model, text to sound - [🤗 Hub](https://huggingface.co/facebook/audiogen-medium) You can play with AudioGen by running the jupyter notebook at [`demos/audiogen_demo.ipynb`](../demos/audiogen_demo.ipynb) locally (if you have a GPU). See below for a quick example of using the API. ```python import torchaudio from audiocraft.models import AudioGen from audiocraft.data.audio import audio_write model = AudioGen.get_pretrained('facebook/audiogen-medium') model.set_generation_params(duration=5) # generate 5 seconds. descriptions = ['dog barking', 'sirene of an emergency vehicle', 'footsteps in a corridor'] wav = model.generate(descriptions) # generates 3 samples. for idx, one_wav in enumerate(wav): # Will save under {idx}.wav, with loudness normalization at -14 db LUFS. audio_write(f'{idx}', one_wav.cpu(), model.sample_rate, strategy="loudness", loudness_compressor=True) ``` ## Training The [AudioGenSolver](../audiocraft/solvers/audiogen.py) implements AudioGen's training pipeline used to develop the released model. Note that this may not fully reproduce the results presented in the paper. 
Similarly to MusicGen, it defines an autoregressive language modeling task over multiple streams of discrete tokens extracted from a pre-trained EnCodec model (see [EnCodec documentation](./ENCODEC.md) for more details on how to train such model) with dataset-specific changes for environmental sound processing. Note that **we do NOT provide any of the datasets** used for training AudioGen. ### Example configurations and grids We provide configurations to reproduce the released models and our research. AudioGen solvers configuration are available in [config/solver/audiogen](../config/solver/audiogen). The base training configuration used for the released models is the following: [`solver=audiogen/audiogen_base_16khz`](../config/solver/audiogen/audiogen_base_16khz.yaml) Please find some example grids to train AudioGen at [audiocraft/grids/audiogen](../audiocraft/grids/audiogen/). ```shell # text-to-sound dora grid audiogen.audiogen_base_16khz ``` ### Sound dataset and metadata AudioGen's underlying dataset is an AudioDataset augmented with description metadata. The AudioGen dataset implementation expects the metadata to be available as `.json` files at the same location as the audio files or through specified external folder. Learn more in the [datasets section](./DATASETS.md). ### Evaluation stage By default, evaluation stage is also computing the cross-entropy and the perplexity over the evaluation dataset. Indeed the objective metrics used for evaluation can be costly to run or require some extra dependencies. Please refer to the [metrics documentation](./METRICS.md) for more details on the requirements for each metric. We provide an off-the-shelf configuration to enable running the objective metrics for audio generation in [config/solver/audiogen/evaluation/objective_eval](../config/solver/audiogen/evaluation/objective_eval.yaml). 
One can then activate evaluation the following way: ```shell # using the configuration dora run solver=audiogen/debug solver/audiogen/evaluation=objective_eval # specifying each of the fields, e.g. to activate KL computation dora run solver=audiogen/debug evaluate.metrics.kld=true ``` See [an example evaluation grid](../audiocraft/grids/audiogen/audiogen_pretrained_16khz_eval.py). ### Generation stage The generation stage allows to generate samples conditionally and/or unconditionally and to perform audio continuation (from a prompt). We currently support greedy sampling (argmax), sampling from softmax with a given temperature, top-K and top-P (nucleus) sampling. The number of samples generated and the batch size used are controlled by the `dataset.generate` configuration while the other generation parameters are defined in `generate.lm`. ```shell # control sampling parameters dora run solver=audiogen/debug generate.lm.gen_duration=5 generate.lm.use_sampling=true generate.lm.top_k=15 ``` ## More information Refer to [MusicGen's instructions](./MUSICGEN.md). ### Learn more Learn more about AudioCraft training pipelines in the [dedicated section](./TRAINING.md). ## Citation AudioGen ``` @article{kreuk2022audiogen, title={Audiogen: Textually guided audio generation}, author={Kreuk, Felix and Synnaeve, Gabriel and Polyak, Adam and Singer, Uriel and D{\'e}fossez, Alexandre and Copet, Jade and Parikh, Devi and Taigman, Yaniv and Adi, Yossi}, journal={arXiv preprint arXiv:2209.15352}, year={2022} } ``` MusicGen ``` @article{copet2023simple, title={Simple and Controllable Music Generation}, author={Jade Copet and Felix Kreuk and Itai Gat and Tal Remez and David Kant and Gabriel Synnaeve and Yossi Adi and Alexandre Défossez}, year={2023}, journal={arXiv preprint arXiv:2306.05284}, } ``` ## License See license information in the [model card](../model_cards/AUDIOGEN_MODEL_CARD.md). 
[audiogen_arxiv]: https://arxiv.org/abs/2209.15352 [musicgen_arxiv]: https://arxiv.org/abs/2306.05284 [audiogen_samples]: https://felixkreuk.github.io/audiogen/ ================================================ FILE: docs/CONDITIONING.md ================================================ # AudioCraft conditioning modules AudioCraft provides a [modular implementation of conditioning modules](../audiocraft/modules/conditioners.py) that can be used with the language model to condition the generation. The codebase was developed in order to easily extend the set of modules currently supported to easily develop new ways of controlling the generation. ## Conditioning methods For now, we support 3 main types of conditioning within AudioCraft: * Text-based conditioning methods * Waveform-based conditioning methods * Joint embedding conditioning methods for text and audio projected in a shared latent space. The Language Model relies on 2 core components that handle processing information: * The `ConditionProvider` class, that maps metadata to processed conditions leveraging all the defined conditioners for the given task. * The `ConditionFuser` class, that takes preprocessed conditions and properly fuse the conditioning embedding to the language model inputs following a given fusing strategy. Different conditioners (for text, waveform, joint embeddings...) are provided as torch modules in AudioCraft and are used internally in the language model to process the conditioning signals and feed them to the language model. ## Core concepts ### Conditioners The `BaseConditioner` torch module is the base implementation for all conditioners in audiocraft. Each conditioner is expected to implement 2 methods: * The `tokenize` method that is used as a preprocessing method that contains all processing that can lead to synchronization points (e.g. BPE tokenization with transfer to the GPU). The output of the tokenize method will then be used to feed the forward method. 
* The `forward` method that takes the output of the tokenize method and contains the core computation to obtain the conditioning embedding along with a mask indicating valid indices (e.g. padding tokens). ### ConditionProvider The ConditionProvider prepares and provides conditions given a dictionary of conditioners. Conditioners are specified as a dictionary of attributes and the corresponding conditioner providing the processing logic for the given attribute. Similarly to the conditioners, the condition provider works in two steps to avoid synchronization points: * A `tokenize` method that takes a list of conditioning attributes for the batch, and runs all tokenize steps for the set of conditioners. * A `forward` method that takes the output of the tokenize step and runs all the forward steps for the set of conditioners. The list of conditioning attributes is passed as a list of `ConditioningAttributes`, which is presented just below. ### ConditionFuser Once all conditioning signals have been extracted and processed by the `ConditionProvider` as dense embeddings, they still need to be passed to the language model along with the original language model inputs. The `ConditionFuser` specifically handles the logic of combining the different conditions with the actual model input, supporting different strategies to combine them. One can therefore define different strategies to combine or fuse the condition with the input, in particular: * Prepending the conditioning signal to the input with the `prepend` strategy, * Summing the conditioning signal to the input with the `sum` strategy, * Combining the conditioning relying on a cross-attention mechanism with the `cross` strategy, * Using input interpolation with the `input_interpolate` strategy. ### SegmentWithAttributes and ConditioningAttributes: From metadata to conditions The `ConditioningAttributes` dataclass is the base class for metadata containing all attributes used for conditioning the language model. 
It currently supports the following types of attributes: * Text conditioning attributes: Dictionary of textual attributes used for text-conditioning. * Wav conditioning attributes: Dictionary of waveform attributes used for waveform-based conditioning such as the chroma conditioning. * JointEmbed conditioning attributes: Dictionary of text and waveform attributes that are expected to be represented in a shared latent space. These different types of attributes are the attributes that are processed by the different conditioners. `ConditioningAttributes` are extracted from metadata loaded along with the audio in the datasets, provided that the metadata used by the dataset implements the `SegmentWithAttributes` abstraction. All metadata-enabled datasets used for conditioning in AudioCraft inherit from the [`audiocraft.data.info_audio_dataset.InfoAudioDataset`](../audiocraft/data/info_audio_dataset.py) class and the corresponding metadata inherits and implements the `SegmentWithAttributes` abstraction. Refer to the [`audiocraft.data.music_dataset.MusicAudioDataset`](../audiocraft/data/music_dataset.py) class as an example. ## Available conditioners ### Text conditioners All text conditioners are expected to inherit from the `TextConditioner` class. AudioCraft currently provides two text conditioners: * The `LUTConditioner` that relies on a look-up table of embeddings learned at train time, and on either no tokenizer or a spacy tokenizer. This conditioner is particularly useful for simple experiments and categorical labels. * The `T5Conditioner` that relies on a [pre-trained T5 model](https://huggingface.co/docs/transformers/model_doc/t5) frozen or fine-tuned at train time to extract the text embeddings. ### Waveform conditioners All waveform conditioners are expected to inherit from the `WaveformConditioner` class and consist of a conditioning method that takes a waveform as input. 
The waveform conditioner must implement the logic to extract the embedding from the waveform and define the downsampling factor from the waveform to the resulting embedding. The `ChromaStemConditioner` conditioner is a waveform conditioner for the chroma features conditioning used by MusicGen. It takes a given waveform, extracts the relevant stems for melody (namely all non drums and bass stems) using a [pre-trained Demucs model](https://github.com/facebookresearch/demucs) and then extracts the chromagram bins from the remaining mix of stems. ### Joint embeddings conditioners We finally provide support for conditioning based on joint text and audio embeddings through the `JointEmbeddingConditioner` class and the `CLAPEmbeddingConditioner` that implements such a conditioning method relying on a [pretrained CLAP model](https://github.com/LAION-AI/CLAP). ## Classifier Free Guidance We provide a Classifier Free Guidance implementation in AudioCraft. With the classifier free guidance dropout, all attributes are dropped with the same probability. ## Attribute Dropout We further provide an attribute dropout strategy. Unlike the classifier free guidance dropout, the attribute dropout drops given attributes with a defined probability, allowing the model not to expect all conditioning signals to be provided at once. ## Faster computation of conditions Conditioners that require some heavy computation on the waveform can be cached, in particular the `ChromaStemConditioner` or `CLAPEmbeddingConditioner`. You just need to provide the `cache_path` parameter to them. We recommend running dummy jobs for filling up the cache quickly. An example is provided in the [musicgen.musicgen_melody_32khz grid](../audiocraft/grids/musicgen/musicgen_melody_32khz.py).
================================================ FILE: docs/DATASETS.md ================================================ # AudioCraft datasets Our dataset manifest files consist in 1-json-per-line files, potentially gzipped, as `data.jsons` or `data.jsons.gz` files. This JSON contains the path to the audio file and associated metadata. The manifest files are then provided in the configuration, as `datasource` sub-configuration. A datasource contains the pointers to the paths of the manifest files for each AudioCraft stage (or split) along with additional information (eg. maximum sample rate to use against this dataset). All the datasources are under the `dset` group config, with a dedicated configuration file for each dataset. ## Getting started ### Example See the provided example in the directory that provides a manifest to use the example dataset provided under the [dataset folder](../dataset/example). The manifest files are stored in the [egs folder](../egs/example). ```shell egs/ example/data.json.gz ``` A datasource is defined in the configuration folder, in the dset group config for this dataset at [config/dset/audio/example](../config/dset/audio/example.yaml): ```shell # @package __global__ datasource: max_sample_rate: 44100 max_channels: 2 train: egs/example valid: egs/example evaluate: egs/example generate: egs/example ``` For proper dataset, one should create manifest for each of the splits and specify the correct path to the given manifest in the datasource for each split. 
Then, using a dataset through the configuration can be done pointing to the corresponding dataset configuration: ```shell dset= # should match the yaml file name # for example dset=audio/example ``` ### Creating manifest files Assuming you want to create manifest files to load with AudioCraft's AudioDataset, you can use the following command to create new manifest files from a given folder containing audio files: ```shell python -m audiocraft.data.audio_dataset egs/my_dataset/my_dataset_split/data.jsonl.gz # For example to generate the manifest for dset=audio/example # note: we don't use any split and we don't compress the jsonl file for this dummy example python -m audiocraft.data.audio_dataset dataset/example egs/example/data.jsonl # More info with: python -m audiocraft.data.audio_dataset --help ``` ## Additional information ### MusicDataset and metadata The MusicDataset is an AudioDataset with additional metadata. The MusicDataset expects the additional metadata to be stored in a JSON file that has the same path as the corresponding audio file, but with a `.json` extension. ### SoundDataset and metadata The SoundDataset is an AudioDataset with descriptions metadata. Similarly to the MusicDataset, the SoundDataset expects the additional metadata to be stored in a JSON file that has the same path as the corresponding audio file, but with a `.json` extension. Additionally, the SoundDataset supports an additional parameter pointing to an extra folder `external_metadata_source` containing all the JSON metadata files given they have the same filename as the audio file. ================================================ FILE: docs/ENCODEC.md ================================================ # EnCodec: High Fidelity Neural Audio Compression AudioCraft provides the training code for EnCodec, a state-of-the-art deep learning based audio codec supporting both mono stereo audio, presented in the [High Fidelity Neural Audio Compression][arxiv] paper. 
Check out our [sample page][encodec_samples]. ## Original EnCodec models The EnCodec models presented in High Fidelity Neural Audio Compression can be accessed and used with the [EnCodec repository](https://github.com/facebookresearch/encodec). **Note**: We do not guarantee compatibility between the AudioCraft and EnCodec codebases and released checkpoints at this stage. ## Installation Please follow the AudioCraft installation instructions from the [README](../README.md). ## Training The [CompressionSolver](../audiocraft/solvers/compression.py) implements the audio reconstruction task to train an EnCodec model. Specifically, it trains an encoder-decoder with a quantization bottleneck - a SEANet encoder-decoder with Residual Vector Quantization bottleneck for EnCodec - using a combination of objective and perceptual losses in the forms of discriminators. The default configuration matches a causal EnCodec training at a single bandwidth. ### Example configuration and grids We provide sample configuration and grids for training EnCodec models. The compression configurations are defined in [config/solver/compression](../config/solver/compression). The example grids are available at [audiocraft/grids/compression](../audiocraft/grids/compression). ```shell # base causal encodec on monophonic audio sampled at 24 khz dora grid compression.encodec_base_24khz # encodec model used for MusicGen on monophonic audio sampled at 32 khz dora grid compression.encodec_musicgen_32khz ``` ### Training and valid stages The model is trained using a combination of objective and perceptual losses. More specifically, EnCodec is trained with the MS-STFT discriminator along with objective losses through the use of a loss balancer to effectively weight the different losses, in an intuitive manner. ### Evaluation stage Evaluation metrics for audio generation: * SI-SNR: Scale-Invariant Signal-to-Noise Ratio. * ViSQOL: Virtual Speech Quality Objective Listener.
Note: Path to the ViSQOL binary (compiled with bazel) needs to be provided in order to run the ViSQOL metric on the reference and degraded signals. The metric is disabled by default. Please refer to the [metrics documentation](../METRICS.md) to learn more. ### Generation stage The generation stage consists in generating the reconstructed audio from samples with the current model. The number of samples generated and the batch size used are controlled by the `dataset.generate` configuration. The output path and audio formats are defined in the generate stage configuration. ```shell # generate samples every 5 epoch dora run solver=compression/encodec_base_24khz generate.every=5 # run with a different dset dora run solver=compression/encodec_base_24khz generate.path= # limit the number of samples or use a different batch size dora grid solver=compression/encodec_base_24khz dataset.generate.num_samples=10 dataset.generate.batch_size=4 ``` ### Playing with the model Once you have a model trained, it is possible to get the entire solver, or just the trained model with the following functions: ```python from audiocraft.solvers import CompressionSolver # If you trained a custom model with signature SIG. model = CompressionSolver.model_from_checkpoint('//sig/SIG') # If you want to get one of the pretrained models with the `//pretrained/` prefix. model = CompressionSolver.model_from_checkpoint('//pretrained/facebook/encodec_32khz') # Or load from a custom checkpoint path model = CompressionSolver.model_from_checkpoint('/my_checkpoints/foo/bar/checkpoint.th') # If you only want to use a pretrained model, you can also directly get it # from the CompressionModel base model class. from audiocraft.models import CompressionModel # Here do not put the `//pretrained/` prefix! model = CompressionModel.get_pretrained('facebook/encodec_32khz') model = CompressionModel.get_pretrained('dac_44khz') # Finally, you can also retrieve the full Solver object, with its dataloader etc. 
from audiocraft import train from pathlib import Path import logging import os import sys # uncomment the following line if you want some detailed logs when loading a Solver. logging.basicConfig(stream=sys.stderr, level=logging.INFO) # You must always run the following function from the root directory. os.chdir(Path(train.__file__).parent.parent) # You can also get the full solver (only for your own experiments). # You can provide some overrides to the parameters to make things more convenient. solver = train.get_solver_from_sig('SIG', {'device': 'cpu', 'dataset': {'batch_size': 8}}) solver.model solver.dataloaders ``` ### Importing / Exporting models At the moment we do not have a definitive workflow for exporting EnCodec models, for instance to Hugging Face (HF). We are working on supporting automatic conversion between AudioCraft and Hugging Face implementations. We still have some support for fine tuning an EnCodec model coming from HF in AudioCraft, using for instance `continue_from=//pretrained/facebook/encodec_32khz`. An AudioCraft checkpoint can be exported in a more compact format (excluding the optimizer etc.) using `audiocraft.utils.export.export_encodec`. For instance, you could run: ```python from audiocraft.utils import export from audiocraft import train xp = train.main.get_xp_from_sig('SIG') export.export_encodec( xp.folder / 'checkpoint.th', '/checkpoints/my_audio_lm/compression_state_dict.bin') from audiocraft.models import CompressionModel model = CompressionModel.get_pretrained('/checkpoints/my_audio_lm/compression_state_dict.bin') from audiocraft.solvers import CompressionSolver # The two are strictly equivalent, but this function supports also loading from non already exported models. model = CompressionSolver.model_from_checkpoint('//pretrained//checkpoints/my_audio_lm/compression_state_dict.bin') ``` We will see then how to use this model as a tokenizer for MusicGen/AudioGen in the [MusicGen documentation](./MUSICGEN.md).
### Learn more Learn more about AudioCraft training pipelines in the [dedicated section](./TRAINING.md). ## Citation ``` @article{defossez2022highfi, title={High Fidelity Neural Audio Compression}, author={Défossez, Alexandre and Copet, Jade and Synnaeve, Gabriel and Adi, Yossi}, journal={arXiv preprint arXiv:2210.13438}, year={2022} } ``` ## License See license information in the [README](../README.md). [arxiv]: https://arxiv.org/abs/2210.13438 [encodec_samples]: https://ai.honu.io/papers/encodec/samples.html ================================================ FILE: docs/MBD.md ================================================ # MultiBand Diffusion AudioCraft provides the code and models for MultiBand Diffusion, [From Discrete Tokens to High Fidelity Audio using MultiBand Diffusion][arxiv]. MultiBand diffusion is a collection of 4 models that can decode tokens from EnCodec tokenizer into waveform audio. You can listen to some examples on the sample page. Open In Colab
## Installation Please follow the AudioCraft installation instructions from the [README](../README.md). ## Usage We offer a number of ways to use MultiBand Diffusion: 1. The MusicGen demo includes a toggle to try diffusion decoder. You can use the demo locally by running [`python -m demos.musicgen_app --share`](../demos/musicgen_app.py), or through the [MusicGen Colab](https://colab.research.google.com/drive/1JlTOjB-G0A2Hz3h8PK63vLZk4xdCI5QB?usp=sharing). 2. You can play with MusicGen by running the jupyter notebook at [`demos/musicgen_demo.ipynb`](../demos/musicgen_demo.ipynb) locally (if you have a GPU). ## API We provide a simple API and pre-trained models for MusicGen and for EnCodec at 24 khz for 3 bitrates (1.5 kbps, 3 kbps and 6 kbps). See below for a quick example of using MultiBandDiffusion with the MusicGen API: ```python import torchaudio from audiocraft.models import MusicGen, MultiBandDiffusion from audiocraft.data.audio import audio_write model = MusicGen.get_pretrained('facebook/musicgen-melody') mbd = MultiBandDiffusion.get_mbd_musicgen() model.set_generation_params(duration=8) # generate 8 seconds. wav, tokens = model.generate_unconditional(4, return_tokens=True) # generates 4 unconditional audio samples and keep the tokens for MBD generation descriptions = ['happy rock', 'energetic EDM', 'sad jazz'] wav_diffusion = mbd.tokens_to_wav(tokens) wav, tokens = model.generate(descriptions, return_tokens=True) # generates 3 samples and keep the tokens. wav_diffusion = mbd.tokens_to_wav(tokens) melody, sr = torchaudio.load('./assets/bach.mp3') # Generates using the melody from the given audio and the provided descriptions, returns audio and audio tokens. wav, tokens = model.generate_with_chroma(descriptions, melody[None].expand(3, -1, -1), sr, return_tokens=True) wav_diffusion = mbd.tokens_to_wav(tokens) for idx, one_wav in enumerate(wav): # Will save under {idx}.wav and {idx}_diffusion.wav, with loudness normalization at -14 db LUFS for comparing the methods.
audio_write(f'{idx}', one_wav.cpu(), model.sample_rate, strategy="loudness", loudness_compressor=True) audio_write(f'{idx}_diffusion', wav_diffusion[idx].cpu(), model.sample_rate, strategy="loudness", loudness_compressor=True) ``` For the compression task (and to compare with [EnCodec](https://github.com/facebookresearch/encodec)): ```python import torch from audiocraft.models import MultiBandDiffusion from encodec import EncodecModel from audiocraft.data.audio import audio_read, audio_write bandwidth = 3.0 # 1.5, 3.0, 6.0 mbd = MultiBandDiffusion.get_mbd_24khz(bw=bandwidth) encodec = EncodecModel.encodec_model_24khz() somepath = '' wav, sr = audio_read(somepath) with torch.no_grad(): compressed_encodec = encodec(wav) compressed_diffusion = mbd.regenerate(wav, sample_rate=sr) audio_write('sample_encodec', compressed_encodec.squeeze(0).cpu(), mbd.sample_rate, strategy="loudness", loudness_compressor=True) audio_write('sample_diffusion', compressed_diffusion.squeeze(0).cpu(), mbd.sample_rate, strategy="loudness", loudness_compressor=True) ``` ## Training The [DiffusionSolver](../audiocraft/solvers/diffusion.py) implements our diffusion training pipeline. It generates waveform audio conditioned on the embeddings extracted from a pre-trained EnCodec model (see [EnCodec documentation](./ENCODEC.md) for more details on how to train such model). Note that **we do NOT provide any of the datasets** used for training our diffusion models. We provide a dummy dataset containing just a few examples for illustrative purposes. ### Example configurations and grids One can train diffusion models as described in the paper by using this [dora grid](../audiocraft/grids/diffusion/4_bands_base_32khz.py). ```shell # 4 bands MBD trainning dora grid diffusion.4_bands_base_32khz ``` ### Learn more Learn more about AudioCraft training pipelines in the [dedicated section](./TRAINING.md). 
## Citation ``` @article{sanroman2023fromdi, title={From Discrete Tokens to High-Fidelity Audio Using Multi-Band Diffusion}, author={San Roman, Robin and Adi, Yossi and Deleforge, Antoine and Serizel, Romain and Synnaeve, Gabriel and Défossez, Alexandre}, journal={arXiv preprint arXiv:}, year={2023} } ``` ## License See license information in the [README](../README.md). [arxiv]: https://dl.fbaipublicfiles.com/encodec/Diffusion/paper.pdf [mbd_samples]: https://ai.honu.io/papers/mbd/ ================================================ FILE: docs/METRICS.md ================================================ # AudioCraft objective metrics In addition to training losses, AudioCraft provides a set of objective metrics for audio synthesis and audio generation. As these metrics may require extra dependencies and can be costly to train, they are often disabled by default. This section provides guidance for setting up and using these metrics in the AudioCraft training pipelines. ## Available metrics ### Audio synthesis quality metrics #### SI-SNR We provide an implementation of the Scale-Invariant Signal-to-Noise Ratio in PyTorch. No specific requirement is needed for this metric. Please activate the metric at the evaluation stage with the appropriate flag: ```shell dora run <...> evaluate.metrics.sisnr=true ``` #### ViSQOL We provide a Python wrapper around the ViSQOL [official implementation](https://github.com/google/visqol) to conveniently run ViSQOL within the training pipelines. 
One must specify the path to the ViSQOL installation through the configuration in order to enable ViSQOL computations in AudioCraft: ```shell # the first parameter is used to activate visqol computation while the second specify # the path to visqol's library to be used by our python wrapper dora run <...> evaluate.metrics.visqol=true metrics.visqol.bin= ``` See an example grid: [Compression with ViSQOL](../audiocraft/grids/compression/encodec_musicgen_32khz.py) To learn more about ViSQOL and how to build ViSQOL binary using bazel, please refer to the instructions available in the [open source repository](https://github.com/google/visqol). ### Audio generation metrics #### Frechet Audio Distance Similarly to ViSQOL, we use a Python wrapper around the Frechet Audio Distance [official implementation](https://github.com/google-research/google-research/tree/master/frechet_audio_distance) in TensorFlow. Note that we had to make several changes to the actual code in order to make it work. Please refer to the [FrechetAudioDistanceMetric](../audiocraft/metrics/fad.py) class documentation for more details. We do not plan to provide further support in obtaining a working setup for the Frechet Audio Distance at this stage. ```shell # the first parameter is used to activate FAD metric computation while the second specify # the path to FAD library to be used by our python wrapper dora run <...> evaluate.metrics.fad=true metrics.fad.bin= ``` See an example grid: [Evaluation with FAD](../audiocraft/grids/musicgen/musicgen_pretrained_32khz_eval.py) #### Kullback-Leibler Divergence We provide a PyTorch implementation of the Kullback-Leibler Divergence computed over the probabilities of the labels obtained by a state-of-the-art audio classifier. We provide our implementation of the KLD using the [PaSST classifier](https://github.com/kkoutini/PaSST). 
In order to use the KLD metric over PaSST, you must install the PaSST library as an extra dependency: ```shell pip install 'git+https://github.com/kkoutini/passt_hear21@0.0.19#egg=hear21passt' ``` Then similarly, you can use the metric activating the corresponding flag: ```shell # one could extend the kld metric with additional audio classifier models that can then be picked through the configuration dora run <...> evaluate.metrics.kld=true metrics.kld.model=passt ``` #### Text consistency We provide a text-consistency metric, similarly to the MuLan Cycle Consistency from [MusicLM](https://arxiv.org/pdf/2301.11325.pdf) or the CLAP score used in [Make-An-Audio](https://arxiv.org/pdf/2301.12661v1.pdf). More specifically, we provide a PyTorch implementation of a Text consistency metric relying on a pre-trained [Contrastive Language-Audio Pretraining (CLAP)](https://github.com/LAION-AI/CLAP). Please install the CLAP library as an extra dependency prior to using the metric: ```shell pip install laion_clap ``` Then similarly, you can use the metric activating the corresponding flag: ```shell # one could extend the text consistency metric with additional audio classifier models that can then be picked through the configuration dora run ... evaluate.metrics.text_consistency=true metrics.text_consistency.model=clap ``` Note that the text consistency metric based on CLAP will require the CLAP checkpoint to be provided in the configuration. #### Chroma cosine similarity Finally, as introduced in MusicGen, we provide a Chroma Cosine Similarity metric in PyTorch. No specific requirement is needed for this metric. Please activate the metric at the evaluation stage with the appropriate flag: ```shell dora run ... 
evaluate.metrics.chroma_cosine=true ``` #### Comparing against reconstructed audio For all the above audio generation metrics, we offer the option to compute the metric on the reconstructed audio fed in EnCodec instead of the generated sample using the flag `.use_gt=true`. ## Example usage You will find example of configuration for the different metrics introduced above in: * The [musicgen's default solver](../config/solver/musicgen/default.yaml) for all audio generation metrics * The [compression's default solver](../config/solver/compression/default.yaml) for all audio synthesis metrics Similarly, we provide different examples in our grids: * [Evaluation with ViSQOL](../audiocraft/grids/compression/encodec_musicgen_32khz.py) * [Evaluation with FAD and others](../audiocraft/grids/musicgen/musicgen_pretrained_32khz_eval.py) ================================================ FILE: docs/MUSICGEN.md ================================================ # MusicGen: Simple and Controllable Music Generation AudioCraft provides the code and models for MusicGen, [a simple and controllable model for music generation][arxiv]. MusicGen is a single stage auto-regressive Transformer model trained over a 32kHz EnCodec tokenizer with 4 codebooks sampled at 50 Hz. Unlike existing methods like [MusicLM](https://arxiv.org/abs/2301.11325), MusicGen doesn't require a self-supervised semantic representation, and it generates all 4 codebooks in one pass. By introducing a small delay between the codebooks, we show we can predict them in parallel, thus having only 50 auto-regressive steps per second of audio. Check out our [sample page][musicgen_samples] or test the available demo! Open In Colab Open in HugginFace
We use 20K hours of licensed music to train MusicGen. Specifically, we rely on an internal dataset of 10K high-quality music tracks, and on the ShutterStock and Pond5 music data. ## Model Card See [the model card](../model_cards/MUSICGEN_MODEL_CARD.md). ## Installation Please follow the AudioCraft installation instructions from the [README](../README.md). AudioCraft requires a GPU with at least 16 GB of memory for running inference with the medium-sized models (~1.5B parameters). ## Usage We offer a number of ways to interact with MusicGen: 1. A demo is also available on the [`facebook/MusicGen` Hugging Face Space](https://huggingface.co/spaces/facebook/MusicGen) (huge thanks to all the HF team for their support). 2. You can run the extended demo on a Colab: [colab notebook](https://colab.research.google.com/drive/1JlTOjB-G0A2Hz3h8PK63vLZk4xdCI5QB?usp=sharing) 3. You can use the gradio demo locally by running [`python -m demos.musicgen_app --share`](../demos/musicgen_app.py). 4. You can play with MusicGen by running the jupyter notebook at [`demos/musicgen_demo.ipynb`](../demos/musicgen_demo.ipynb) locally (if you have a GPU). 5. Finally, check out [@camenduru Colab page](https://github.com/camenduru/MusicGen-colab) which is regularly updated with contributions from @camenduru and the community. ## API We provide a simple API and 4 pre-trained models.
The pre trained models are: - `facebook/musicgen-small`: 300M model, text to music only - [🤗 Hub](https://huggingface.co/facebook/musicgen-small) - `facebook/musicgen-medium`: 1.5B model, text to music only - [🤗 Hub](https://huggingface.co/facebook/musicgen-medium) - `facebook/musicgen-melody`: 1.5B model, text to music and text+melody to music - [🤗 Hub](https://huggingface.co/facebook/musicgen-melody) - `facebook/musicgen-large`: 3.3B model, text to music only - [🤗 Hub](https://huggingface.co/facebook/musicgen-large) We observe the best trade-off between quality and compute with the `facebook/musicgen-medium` or `facebook/musicgen-melody` model. In order to use MusicGen locally **you must have a GPU**. We recommend 16GB of memory, but smaller GPUs will be able to generate short sequences, or longer sequences with the `facebook/musicgen-small` model. See after a quick example for using the API. ```python import torchaudio from audiocraft.models import MusicGen from audiocraft.data.audio import audio_write model = MusicGen.get_pretrained('facebook/musicgen-melody') model.set_generation_params(duration=8) # generate 8 seconds. wav = model.generate_unconditional(4) # generates 4 unconditional audio samples descriptions = ['happy rock', 'energetic EDM', 'sad jazz'] wav = model.generate(descriptions) # generates 3 samples. melody, sr = torchaudio.load('./assets/bach.mp3') # generates using the melody from the given audio and the provided descriptions. wav = model.generate_with_chroma(descriptions, melody[None].expand(3, -1, -1), sr) for idx, one_wav in enumerate(wav): # Will save under {idx}.wav, with loudness normalization at -14 db LUFS. audio_write(f'{idx}', one_wav.cpu(), model.sample_rate, strategy="loudness", loudness_compressor=True) ``` ## 🤗 Transformers Usage MusicGen is available in the 🤗 Transformers library from version 4.31.0 onwards, requiring minimal dependencies and additional packages. Steps to get started: 1. 
First install the 🤗 [Transformers library](https://github.com/huggingface/transformers) from main: ```shell pip install git+https://github.com/huggingface/transformers.git ``` 2. Run the following Python code to generate text-conditional audio samples: ```py from transformers import AutoProcessor, MusicgenForConditionalGeneration processor = AutoProcessor.from_pretrained("facebook/musicgen-small") model = MusicgenForConditionalGeneration.from_pretrained("facebook/musicgen-small") inputs = processor( text=["80s pop track with bassy drums and synth", "90s rock song with loud guitars and heavy drums"], padding=True, return_tensors="pt", ) audio_values = model.generate(**inputs, max_new_tokens=256) ``` 3. Listen to the audio samples either in an ipynb notebook: ```py from IPython.display import Audio sampling_rate = model.config.audio_encoder.sampling_rate Audio(audio_values[0].numpy(), rate=sampling_rate) ``` Or save them as a `.wav` file using a third-party library, e.g. `scipy`: ```py import scipy sampling_rate = model.config.audio_encoder.sampling_rate scipy.io.wavfile.write("musicgen_out.wav", rate=sampling_rate, data=audio_values[0, 0].numpy()) ``` For more details on using the MusicGen model for inference using the 🤗 Transformers library, refer to the [MusicGen docs](https://huggingface.co/docs/transformers/main/en/model_doc/musicgen) or the hands-on [Google Colab](https://colab.research.google.com/github/sanchit-gandhi/notebooks/blob/main/MusicGen.ipynb). ## Training The [MusicGenSolver](../audiocraft/solvers/musicgen.py) implements MusicGen's training pipeline. It defines an autoregressive language modeling task over multiple streams of discrete tokens extracted from a pre-trained EnCodec model (see [EnCodec documentation](./ENCODEC.md) for more details on how to train such model). Note that **we do NOT provide any of the datasets** used for training MusicGen. We provide a dummy dataset containing just a few examples for illustrative purposes. 
Please read first the [TRAINING documentation](./TRAINING.md), in particular the Environment Setup section. ### Example configurations and grids We provide configurations to reproduce the released models and our research. MusicGen solvers configurations are available in [config/solver/musicgen](../config/solver/musicgen), in particular: * MusicGen base model for text-to-music: [`solver=musicgen/musicgen_base_32khz`](../config/solver/musicgen/musicgen_base_32khz.yaml) * MusicGen model with chromagram-conditioning support: [`solver=musicgen/musicgen_melody_32khz`](../config/solver/musicgen/musicgen_melody_32khz.yaml) We provide 3 different scales, e.g. `model/lm/model_scale=small` (300M), or `medium` (1.5B), and `large` (3.3B). Please find some example grids to train MusicGen at [audiocraft/grids/musicgen](../audiocraft/grids/musicgen/). ```shell # text-to-music dora grid musicgen.musicgen_base_32khz --dry_run --init # melody-guided music generation dora grid musicgen.musicgen_melody_32khz --dry_run --init # Remove the `--dry_run --init` flags to actually schedule the jobs once everything is setup. ``` ### Music dataset and metadata MusicGen's underlying dataset is an AudioDataset augmented with music-specific metadata. The MusicGen dataset implementation expects the metadata to be available as `.json` files at the same location as the audio files. Learn more in the [datasets section](./DATASETS.md). ### Audio tokenizers We support a number of audio tokenizers: either pretrained EnCodec models, [DAC](https://github.com/descriptinc/descript-audio-codec), or your own models. The tokenizer is controlled with the setting `compression_model_checkpoint`.
For instance, ```bash # Using the 32kHz EnCodec trained on music dora run solver=musicgen/debug \ compression_model_checkpoint=//pretrained/facebook/encodec_32khz \ transformer_lm.n_q=4 transformer_lm.card=2048 # Using DAC dora run solver=musicgen/debug \ compression_model_checkpoint=//pretrained/dac_44khz \ transformer_lm.n_q=9 transformer_lm.card=1024 \ 'codebooks_pattern.delay.delays=[0,1,2,3,4,5,6,7,8]' # Using your own model after export (see ENCODEC.md) dora run solver=musicgen/debug \ compression_model_checkpoint=//pretrained//checkpoints/my_audio_lm/compression_state_dict.bin \ transformer_lm.n_q=... transformer_lm.card=... # Using your own model from its training checkpoint. dora run solver=musicgen/debug \ compression_model_checkpoint=//sig/SIG \ # where SIG is the Dora signature of the EnCodec XP. transformer_lm.n_q=... transformer_lm.card=... ``` **Warning:** you are responsible for setting the proper value for `transformer_lm.n_q` and `transformer_lm.card` (cardinality of the codebooks). You also have to update the codebook_pattern to match `n_q` as shown in the example for using DAC. . ### Fine tuning existing models You can initialize your model to one of the pretrained models by using the `continue_from` argument, in particular ```bash # Using pretrained MusicGen model. dora run solver=musicgen/musicgen_base_32khz model/lm/model_scale=medium continue_from=//pretrained/facebook/musicgen-medium conditioner=text2music # Using another model you already trained with a Dora signature SIG. dora run solver=musicgen/musicgen_base_32khz model/lm/model_scale=medium continue_from=//sig/SIG conditioner=text2music # Or providing manually a path dora run solver=musicgen/musicgen_base_32khz model/lm/model_scale=medium continue_from=/checkpoints/my_other_xp/checkpoint.th ``` **Warning:** You are responsible for selecting the other parameters accordingly, in a way that make it compatible with the model you are fine tuning. 
Configuration is NOT automatically inherited from the model you continue from. In particular make sure to select the proper `conditioner` and `model/lm/model_scale`. **Warning:** We currently do not support fine tuning a model with slightly different layers. If you decide to change some parts, like the conditioning or some other parts of the model, you are responsible for manually crafting a checkpoint file from which we can safely run `load_state_dict`. If you decide to do so, make sure your checkpoint is saved with `torch.save` and contains a dict `{'best_state': {'model': model_state_dict_here}}`. Directly give the path to `continue_from` without a `//pretrained/` prefix. ### Caching of EnCodec tokens It is possible to precompute the EnCodec tokens and other metadata. An example of generating and using this cache is provided in the [musicgen.musicgen_base_cached_32khz grid](../audiocraft/grids/musicgen/musicgen_base_cached_32khz.py). ### Evaluation stage By default, evaluation stage is also computing the cross-entropy and the perplexity over the evaluation dataset. Indeed the objective metrics used for evaluation can be costly to run or require some extra dependencies. Please refer to the [metrics documentation](./METRICS.md) for more details on the requirements for each metric. We provide an off-the-shelf configuration to enable running the objective metrics for audio generation in [config/solver/musicgen/evaluation/objective_eval](../config/solver/musicgen/evaluation/objective_eval.yaml). One can then activate evaluation the following way: ```shell # using the configuration dora run solver=musicgen/debug solver/musicgen/evaluation=objective_eval # specifying each of the fields, e.g. to activate KL computation dora run solver=musicgen/debug evaluate.metrics.kld=true ``` See [an example evaluation grid](../audiocraft/grids/musicgen/musicgen_pretrained_32khz_eval.py).
### Generation stage The generation stage allows generating samples conditionally and/or unconditionally and performing audio continuation (from a prompt). We currently support greedy sampling (argmax), sampling from softmax with a given temperature, top-K and top-P (nucleus) sampling. The number of samples generated and the batch size used are controlled by the `dataset.generate` configuration while the other generation parameters are defined in `generate.lm`. ```shell # control sampling parameters dora run solver=musicgen/debug generate.lm.gen_duration=10 generate.lm.use_sampling=true generate.lm.top_k=15 ``` #### Listening to samples Note that generation happens automatically every 25 epochs. You can easily access and compare samples between models (as long as they are trained on the same dataset) using the MOS tool. For that first `pip install Flask gunicorn`. Then ``` gunicorn -w 4 -b 127.0.0.1:8895 -t 120 'scripts.mos:app' --access-logfile - ``` And access the tool at [http://127.0.0.1:8895](http://127.0.0.1:8895). ### Playing with the model Once you have launched some experiments, you can easily get access to the Solver with the latest trained model using the following snippet. ```python from audiocraft.solvers.musicgen import MusicGen solver = MusicGen.get_eval_solver_from_sig('SIG', device='cpu', batch_size=8) solver.model solver.dataloaders ``` ### Importing / Exporting models We do not currently support loading a model from the Hugging Face implementation or exporting to it. If you want to export your model in a way that is compatible with the `audiocraft.models.MusicGen` API, you can run: ```python from audiocraft.utils import export from audiocraft import train xp = train.main.get_xp_from_sig('SIG_OF_LM') export.export_lm(xp.folder / 'checkpoint.th', '/checkpoints/my_audio_lm/state_dict.bin') # You also need to bundle the EnCodec model you used !!
## Case 1) you trained your own xp_encodec = train.main.get_xp_from_sig('SIG_OF_ENCODEC') export.export_encodec(xp_encodec.folder / 'checkpoint.th', '/checkpoints/my_audio_lm/compression_state_dict.bin') ## Case 2) you used a pretrained model. Give the name you used without the //pretrained/ prefix. ## This will actually not dump the actual model, simply a pointer to the right model to download. export.export_pretrained_compression_model('facebook/encodec_32khz', '/checkpoints/my_audio_lm/compression_state_dict.bin') ``` Now you can load your custom model with: ```python import audiocraft.models musicgen = audiocraft.models.MusicGen.get_pretrained('/checkpoints/my_audio_lm/') ``` ### Learn more Learn more about AudioCraft training pipelines in the [dedicated section](./TRAINING.md). ## FAQ #### I need help on Windows @FurkanGozukara made a complete tutorial for [AudioCraft/MusicGen on Windows](https://youtu.be/v-YpvPkhdO4) #### I need help for running the demo on Colab Check [@camenduru tutorial on YouTube](https://www.youtube.com/watch?v=EGfxuTy9Eeo). #### What are top-k, top-p, temperature and classifier-free guidance? Check out [@FurkanGozukara tutorial](https://github.com/FurkanGozukara/Stable-Diffusion/blob/main/Tutorials/AI-Music-Generation-Audiocraft-Tutorial.md#more-info-about-top-k-top-p-temperature-and-classifier-free-guidance-from-chatgpt). #### Should I use FSDP or autocast ? The two are mutually exclusive (because FSDP does autocast on its own). You can use autocast up to 1.5B (medium), if you have enough RAM on your GPU. FSDP makes everything more complex but will free up some memory for the actual activations by sharding the optimizer state. 
## Citation ``` @article{copet2023simple, title={Simple and Controllable Music Generation}, author={Jade Copet and Felix Kreuk and Itai Gat and Tal Remez and David Kant and Gabriel Synnaeve and Yossi Adi and Alexandre Défossez}, year={2023}, journal={arXiv preprint arXiv:2306.05284}, } ``` ## License See license information in the [model card](../model_cards/MUSICGEN_MODEL_CARD.md). [arxiv]: https://arxiv.org/abs/2306.05284 [musicgen_samples]: https://ai.honu.io/papers/musicgen/ ================================================ FILE: docs/TRAINING.md ================================================ # AudioCraft training pipelines AudioCraft training pipelines are built on top of PyTorch as our core deep learning library and [Flashy](https://github.com/facebookresearch/flashy) as our training pipeline design library, and [Dora](https://github.com/facebookresearch/dora) as our experiment manager. AudioCraft training pipelines are designed to be research and experiment-friendly. ## Environment setup For the base installation, follow the instructions from the [README.md](../README.md). Below are some additional instructions for setting up environment to train new models. ### Team and cluster configuration In order to support multiple teams and clusters, AudioCraft uses an environment configuration. The team configuration allows to specify cluster-specific configurations (e.g. SLURM configuration), or convenient mapping of paths between the supported environments. Each team can have a yaml file under the [configuration folder](../config). To select a team set the `AUDIOCRAFT_TEAM` environment variable to a valid team name (e.g. `labs` or `default`): ```shell conda env config vars set AUDIOCRAFT_TEAM=default ``` Alternatively, you can add it to your `.bashrc`: ```shell export AUDIOCRAFT_TEAM=default ``` If not defined, the environment will default to the `default` team. 
The cluster is automatically detected, but it is also possible to override it by setting the `AUDIOCRAFT_CLUSTER` environment variable. Based on this team and cluster, the environment is then configured with: * The dora experiment outputs directory. * The available slurm partitions: categorized by global and team. * A shared reference directory: In order to facilitate sharing research models while remaining agnostic to the compute cluster used, we created the `//reference` symbol that can be used in YAML config to point to a defined reference folder containing shared checkpoints (e.g. baselines, models for evaluation...). **Important:** The default output dir for trained models and checkpoints is under `/tmp/`. This is suitable only for quick testing. If you are doing anything serious you MUST edit the file `default.yaml` and properly set the `dora_dir` entries. #### Overriding environment configurations You can set the following environment variables to bypass the team's environment configuration: * `AUDIOCRAFT_CONFIG`: absolute path to a team config yaml file. * `AUDIOCRAFT_DORA_DIR`: absolute path to a custom dora directory. * `AUDIOCRAFT_REFERENCE_DIR`: absolute path to the shared reference directory. ## Training pipelines Each task supported in AudioCraft has its own training pipeline and dedicated solver. Learn more about solvers and key designs around AudioCraft training pipeline below. Please refer to the documentation of each task and model for specific information on a given task. ### Solvers The core training component in AudioCraft is the solver. A solver holds the definition of how to solve a given task: It implements the training pipeline logic, combining the datasets, model, optimization criterion and components and the full training loop. We refer the reader to [Flashy](https://github.com/facebookresearch/flashy) for core principles around solvers.
AudioCraft proposes an initial solver, the `StandardSolver` that is used as the base implementation for downstream solvers. This standard solver provides a nice base management of logging, checkpoints loading/saving, xp restoration, etc. on top of the base Flashy implementation. In AudioCraft, we made the assumption that all tasks are following the same set of stages: train, valid, evaluate and generation, each relying on a dedicated dataset. Each solver is responsible for defining the task to solve and the associated stages of the training loop in order to leave the full ownership of the training pipeline to the researchers. This includes loading the datasets, building the model and optimisation components, registering them and defining the execution of each stage. To create a new solver for a given task, one should extend the StandardSolver and define each stage of the training loop. One can further customise its own solver starting from scratch instead of inheriting from the standard solver. ```python from . import base from .. import optim class MyNewSolver(base.StandardSolver): def __init__(self, cfg: omegaconf.DictConfig): super().__init__(cfg) # one can add custom attributes to the solver self.criterion = torch.nn.L1Loss() def best_metric(self): # here optionally specify which metric to use to keep track of best state return 'loss' def build_model(self): # here you can instantiate your models and optimization related objects # this method will be called by the StandardSolver init method self.model = ... # the self.cfg attribute contains the raw configuration self.optimizer = optim.build_optimizer(self.model.parameters(), self.cfg.optim) # don't forget to register the states you'd like to include in your checkpoints! 
self.register_stateful('model', 'optimizer') # keep the model best state based on the best value achieved at validation for the given best_metric self.register_best('model') # if you want to add EMA around the model self.register_ema('model') def build_dataloaders(self): # here you can instantiate your dataloaders # this method will be called by the StandardSolver init method self.dataloaders = ... ... # For both train and valid stages, the StandardSolver relies on # a shared common_train_valid implementation that is in charge of # accessing the appropriate loader, iterating over the data up to # the specified number of updates_per_epoch, running the ``run_step`` # function that you need to implement to specify the behavior # and finally updating the EMA and collecting the metrics properly. @abstractmethod def run_step(self, idx: int, batch: tp.Any, metrics: dict): """Perform one training or valid step on a given batch. """ ... # provide your implementation of the solver over a batch def train(self): """Train stage. """ return self.common_train_valid('train') def valid(self): """Valid stage. """ return self.common_train_valid('valid') @abstractmethod def evaluate(self): """Evaluate stage. """ ... # provide your implementation here! @abstractmethod def generate(self): """Generate stage. """ ... # provide your implementation here! ``` ### About Epochs AudioCraft Solvers use the concept of Epoch. One epoch doesn't necessarily mean one pass over the entire dataset, but instead represents the smallest amount of computation that we want to work with before checkpointing. Typically, we find that having an Epoch time around 30min is ideal both in terms of safety (checkpointing often enough) and getting updates often enough. One Epoch is at least a `train` stage that lasts for `optim.updates_per_epoch` (2000 by default), and a `valid` stage. You can control how long the valid stage takes with `dataset.valid.num_samples`.
Other stages (`evaluate`, `generate`) will only happen every X epochs, as given by `evaluate.every` and `generate.every`. ### Models In AudioCraft, a model is a container object that wraps one or more torch modules together with potential processing logic to use in a solver. For example, a model would wrap an encoder module, a quantisation bottleneck module, a decoder and some tensor processing logic. Each of the previous components can be considered as a small « model unit » on its own but the container model is a practical component to manipulate and train a set of modules together. ### Datasets See the [dedicated documentation on datasets](./DATASETS.md). ### Metrics See the [dedicated documentation on metrics](./METRICS.md). ### Conditioners AudioCraft language models can be conditioned in various ways and the codebase offers a modular implementation of different conditioners that can be potentially combined together. Learn more in the [dedicated documentation on conditioning](./CONDITIONING.md). ### Configuration AudioCraft's configuration is defined in yaml files and the framework relies on [hydra](https://hydra.cc/docs/intro/) and [omegaconf](https://omegaconf.readthedocs.io/) to parse and manipulate the configuration through Dora. ##### :warning: Important considerations around configurations Our configuration management relies on Hydra and the concept of group configs to structure and compose configurations. Updating the root default configuration files will then have an impact on all solvers and tasks. **One should never change the default configuration files. Instead they should use Hydra config groups in order to store custom configuration.** Once this configuration is created and used for running experiments, you should not edit it anymore. Note that as we are using Dora as our experiment manager, all our experiment tracking is based on signatures computed from the delta between configurations.
**One must therefore ensure backward compatibility of the configuration at all times.** See [Dora's README](https://github.com/facebookresearch/dora) and the [section below introducing Dora](#running-experiments-with-dora). ##### Configuration structure The configuration is organized in config groups: * `conditioner`: default values for conditioning modules. * `dset`: contains all data source related information (paths to manifest files and metadata for a given dataset). * `model`: contains configuration for each model defined in AudioCraft and configurations for different variants of models. * `solver`: contains the default configuration for each solver as well as configuration for each solver task, combining all the above components. * `teams`: contains the cluster configuration per team. See environment setup for more details. The `config.yaml` file is the main configuration that composes the above groups and contains default configuration for AudioCraft. ##### Solver's core configuration structure The core configuration structure shared across solvers is available in `solver/default.yaml`. ##### Other configuration modules AudioCraft configuration contains the different setups we used for our research and publications. ## Running experiments with Dora ### Launching jobs Try launching jobs for different tasks locally with dora run: ```shell # run compression task with lightweight encodec dora run solver=compression/debug ``` Most of the time, the jobs are launched through dora grids, for example: ```shell # run compression task through debug grid dora grid compression.debug ``` Learn more about running experiments with Dora below. ### A small introduction to Dora [Dora](https://github.com/facebookresearch/dora) is the experiment manager tool used in AudioCraft. Check out the README to learn how Dora works. Here is a quick summary of what to know: * An XP is a unique set of hyper-parameters with a given signature.
The signature is a hash of those hyper-parameters. We always refer to an XP with its signature, e.g. 9357e12e. We will see later that one can retrieve the hyper-params and re-run it in a single command. * In fact, the hash is defined as a delta between the base config and the one obtained with the config overrides you passed from the command line. This means you must never change the `config/**.yaml` files directly, except for editing things like paths. Changing the default values in the config files means the XP signature won't reflect that change, and wrong checkpoints might be reused. I know, this is annoying, but the reason is that otherwise, any change to the config file would mean that all XPs ran so far would see their signature change. #### Dora commands ```shell dora info -f 81de367c # this will show the hyper-parameters used by a specific XP. # Be careful, some overrides might be present twice, and the rightmost one # will give you the right value for it. dora run -d -f 81de367c # run an XP with the hyper-parameters from XP 81de367c. # `-d` is for distributed, it will use all available GPUs. dora run -d -f 81de367c dataset.batch_size=32 # start from the config of XP 81de367c but change some hyper-params. # This will give you a new XP with a new signature (e.g. 3fe9c332). dora info -f SIG -t # will tail the log (if the XP has been scheduled). # if you need to access the logs of the process for rank > 0, in particular because a crash didn't happen in the main # process, then use `dora info -f SIG` to get the main log name (ending in something like `/5037674_0_0_log.out`) # and worker K can be accessed as `/5037674_0_{K}_log.out`. # This is only for scheduled jobs, for local distributed runs with `-d`, then you should go into the XP folder, # and look for `worker_{K}.log` logs. ``` An XP runs from a specific folder based on its signature, under the `//experiments/audiocraft/outputs/` folder.
You can safely interrupt a training and resume it, it will reuse any existing checkpoint, as it will reuse the same folder. If you made some change to the code and need to ignore a previous checkpoint you can use `dora run --clear [RUN ARGS]`. If you have a Slurm cluster, you can also use the dora grid command, e.g. ```shell # run a dummy grid located at `audiocraft/grids/my_grid_folder/my_grid_name.py` dora grid my_grid_folder.my_grid_name # Run the following will simply display the grid and also initialized the Dora experiments database. # You can then simply refer to a config using its signature (e.g. as `dora run -f SIG`). dora grid my_grid_folder.my_grid_name --dry_run --init ``` Please refer to the [Dora documentation](https://github.com/facebookresearch/dora) for more information. #### Clearing up past experiments ```shell # This will cancel all the XPs and delete their folder and checkpoints. # It will then reschedule them starting from scratch. dora grid my_grid_folder.my_grid_name --clear # The following will delete the folder and checkpoint for a single XP, # and then run it afresh. dora run [-f BASE_SIG] [ARGS] --clear ``` ================================================ FILE: egs/example/data.jsonl ================================================ {"path": "dataset/example/electro_1.mp3", "duration": 15.024, "sample_rate": 48000, "amplitude": null, "weight": null, "info_path": null} {"path": "dataset/example/electro_2.mp3", "duration": 20.035918367346937, "sample_rate": 44100, "amplitude": null, "weight": null, "info_path": null} ================================================ FILE: model_cards/AUDIOGEN_MODEL_CARD.md ================================================ # AudioGen Model Card ## Model details **Organization developing the model:** The FAIR team of Meta AI. **Model date:** This version of AudioGen was trained between July 2023 and August 2023. 
**Model version:** This is version 2 of the model, not to be confused with the original AudioGen model published in ["AudioGen: Textually Guided Audio Generation"][audiogen]. In this version (v2), AudioGen was trained on the same data, but with some other differences: 1. This model was trained on 10 seconds (vs. 5 seconds in v1). 2. The discrete representation used under the hood is extracted using a retrained EnCodec model on the environmental sound data, following the EnCodec setup detailed in the ["Simple and Controllable Music Generation" paper][musicgen]. 3. No audio mixing augmentations. **Model type:** AudioGen consists of an EnCodec model for audio tokenization, and an auto-regressive language model based on the transformer architecture for audio modeling. The released model has 1.5B parameters. **Paper or resource for more information:** More information can be found in the paper [AudioGen: Textually Guided Audio Generation](https://arxiv.org/abs/2209.15352). **Citation details:** See [AudioGen paper][audiogen] **License:** Code is released under MIT, model weights are released under CC-BY-NC 4.0. **Where to send questions or comments about the model:** Questions and comments about AudioGen can be sent via the [GitHub repository](https://github.com/facebookresearch/audiocraft) of the project, or by opening an issue. ## Intended use **Primary intended use:** The primary use of AudioGen is research on AI-based audio generation, including: - Research efforts, such as probing and better understanding the limitations of generative models to further improve the state of science - Generation of sound guided by text to understand current abilities of generative AI models by machine learning amateurs **Primary intended users:** The primary intended users of the model are researchers in audio, machine learning and artificial intelligence, as well as amateur seeking to better understand those models. 
**Out-of-scope use cases:** The model should not be used on downstream applications without further risk evaluation and mitigation. The model should not be used to intentionally create or disseminate audio pieces that create hostile or alienating environments for people. This includes generating audio that people would foreseeably find disturbing, distressing, or offensive; or content that propagates historical or current stereotypes. ## Metrics **Model performance measures:** We used the following objective measures to evaluate the model on a standard audio benchmark: - Frechet Audio Distance computed on features extracted from a pre-trained audio classifier (VGGish) - Kullback-Leibler Divergence on label distributions extracted from a pre-trained audio classifier (PaSST) Additionally, we ran qualitative studies with human participants, evaluating the performance of the model with the following axes: - Overall quality of the audio samples; - Text relevance to the provided text input. More details on performance measures and human studies can be found in the paper. **Decision thresholds:** Not applicable. ## Evaluation datasets The model was evaluated on the [AudioCaps benchmark](https://audiocaps.github.io/). ## Training datasets The model was trained on the following data sources: a subset of AudioSet (Gemmeke et al., 2017), [BBC sound effects](https://sound-effects.bbcrewind.co.uk/), AudioCaps (Kim et al., 2019), Clotho v2 (Drossos et al., 2020), VGG-Sound (Chen et al., 2020), FSD50K (Fonseca et al., 2021), [Free To Use Sounds](https://www.freetousesounds.com/all-in-one-bundle/), [Sonniss Game Effects](https://sonniss.com/gameaudiogdc), [WeSoundEffects](https://wesoundeffects.com/we-sound-effects-bundle-2020/), [Paramount Motion - Odeon Cinematic Sound Effects](https://www.paramountmotion.com/odeon-sound-effects). ## Evaluation results Below are the objective metrics obtained with the released model on AudioCaps (consisting of 10-second long samples).
Note that the model differs from the original AudioGen model introduced in the paper, hence the difference in the metrics. | Model | Frechet Audio Distance | KLD | Text consistency | |---|---|---|---| | facebook/audiogen-medium | 1.77 | 1.58 | 0.30 | More information can be found in the paper [AudioGen: Textually Guided Audio Generation][audiogen], in the Experiments section. ## Limitations and biases **Limitations:** - The model is not able to generate realistic vocals. - The model has been trained with English descriptions and will not perform as well in other languages. - It is sometimes difficult to assess what types of text descriptions provide the best generations. Prompt engineering may be required to obtain satisfying results. **Biases:** The datasets used for training may be lacking of diversity and are not representative of all possible sound events. The generated samples from the model will reflect the biases from the training data. **Risks and harms:** Biases and limitations of the model may lead to generation of samples that may be considered as biased, inappropriate or offensive. We believe that providing the code to reproduce the research and train new models will allow to broaden the application to new and more representative data. **Use cases:** Users must be aware of the biases, limitations and risks of the model. AudioGen is a model developed for artificial intelligence research on audio generation. As such, it should not be used for downstream applications without further investigation and mitigation of risks. [musicgen]: https://arxiv.org/abs/2306.05284 [audiogen]: https://arxiv.org/abs/2209.15352 ================================================ FILE: model_cards/MUSICGEN_MODEL_CARD.md ================================================ # MusicGen Model Card ## Model details **Organization developing the model:** The FAIR team of Meta AI. **Model date:** MusicGen was trained between April 2023 and May 2023. 
**Model version:** This is the version 1 of the model. **Model type:** MusicGen consists of an EnCodec model for audio tokenization, an auto-regressive language model based on the transformer architecture for music modeling. The model comes in different sizes: 300M, 1.5B and 3.3B parameters ; and two variants: a model trained for text-to-music generation task and a model trained for melody-guided music generation. **Paper or resources for more information:** More information can be found in the paper [Simple and Controllable Music Generation][arxiv]. **Citation details:** See [our paper][arxiv] **License:** Code is released under MIT, model weights are released under CC-BY-NC 4.0. **Where to send questions or comments about the model:** Questions and comments about MusicGen can be sent via the [GitHub repository](https://github.com/facebookresearch/audiocraft) of the project, or by opening an issue. ## Intended use **Primary intended use:** The primary use of MusicGen is research on AI-based music generation, including: - Research efforts, such as probing and better understanding the limitations of generative models to further improve the state of science - Generation of music guided by text or melody to understand current abilities of generative AI models by machine learning amateurs **Primary intended users:** The primary intended users of the model are researchers in audio, machine learning and artificial intelligence, as well as amateur seeking to better understand those models. **Out-of-scope use cases:** The model should not be used on downstream applications without further risk evaluation and mitigation. The model should not be used to intentionally create or disseminate music pieces that create hostile or alienating environments for people. This includes generating music that people would foreseeably find disturbing, distressing, or offensive; or content that propagates historical or current stereotypes. 
## Metrics **Models performance measures:** We used the following objective measure to evaluate the model on a standard music benchmark: - Frechet Audio Distance computed on features extracted from a pre-trained audio classifier (VGGish) - Kullback-Leibler Divergence on label distributions extracted from a pre-trained audio classifier (PaSST) - CLAP Score between audio embedding and text embedding extracted from a pre-trained CLAP model Additionally, we run qualitative studies with human participants, evaluating the performance of the model with the following axes: - Overall quality of the music samples; - Text relevance to the provided text input; - Adherence to the melody for melody-guided music generation. More details on performance measures and human studies can be found in the paper. **Decision thresholds:** Not applicable. ## Evaluation datasets The model was evaluated on the [MusicCaps benchmark](https://www.kaggle.com/datasets/googleai/musiccaps) and on an in-domain held-out evaluation set, with no artist overlap with the training set. ## Training datasets The model was trained on licensed data using the following sources: the [Meta Music Initiative Sound Collection](https://www.fb.com/sound), [Shutterstock music collection](https://www.shutterstock.com/music) and the [Pond5 music collection](https://www.pond5.com/). See the paper for more details about the training set and corresponding preprocessing. ## Evaluation results Below are the objective metrics obtained on MusicCaps with the released model. Note that for the publicly released models, we had all the datasets go through a state-of-the-art music source separation method, namely using the open source [Hybrid Transformer for Music Source Separation](https://github.com/facebookresearch/demucs) (HT-Demucs), in order to keep only the instrumental part. This explains the difference in objective metrics with the models used in the paper. 
| Model | Frechet Audio Distance | KLD | Text Consistency | Chroma Cosine Similarity | |---|---|---|---|---| | facebook/musicgen-small | 4.88 | 1.42 | 0.27 | - | | facebook/musicgen-medium | 5.14 | 1.38 | 0.28 | - | | facebook/musicgen-large | 5.48 | 1.37 | 0.28 | - | | facebook/musicgen-melody | 4.93 | 1.41 | 0.27 | 0.44 | More information can be found in the paper [Simple and Controllable Music Generation][arxiv], in the Results section. ## Limitations and biases **Data:** The data sources used to train the model are created by music professionals and covered by legal agreements with the right holders. The model is trained on 20K hours of data, we believe that scaling the model on larger datasets can further improve the performance of the model. **Mitigations:** Vocals have been removed from the data source using corresponding tags, and then using a state-of-the-art music source separation method, namely using the open source [Hybrid Transformer for Music Source Separation](https://github.com/facebookresearch/demucs) (HT-Demucs). **Limitations:** - The model is not able to generate realistic vocals. - The model has been trained with English descriptions and will not perform as well in other languages. - The model does not perform equally well for all music styles and cultures. - The model sometimes generates end of songs, collapsing to silence. - It is sometimes difficult to assess what types of text descriptions provide the best generations. Prompt engineering may be required to obtain satisfying results. **Biases:** The source of data is potentially lacking diversity and all music cultures are not equally represented in the dataset. The model may not perform equally well on the wide variety of music genres that exists. The generated samples from the model will reflect the biases from the training data. 
Further work on this model should include methods for balanced and just representations of cultures, for example, by scaling the training data to be both diverse and inclusive. **Risks and harms:** Biases and limitations of the model may lead to generation of samples that may be considered as biased, inappropriate or offensive. We believe that providing the code to reproduce the research and train new models will allow to broaden the application to new and more representative data. **Use cases:** Users must be aware of the biases, limitations and risks of the model. MusicGen is a model developed for artificial intelligence research on controllable music generation. As such, it should not be used for downstream applications without further investigation and mitigation of risks. [arxiv]: https://arxiv.org/abs/2306.05284 ================================================ FILE: models/Put your models here.txt ================================================ nothing here ================================================ FILE: mypy.ini ================================================ [mypy] [mypy-treetable,torchaudio.*,soundfile,einops.*,av.*,tqdm.*,num2words.*,spacy,xformers.*,scipy,huggingface_hub,transformers,dac.*] ignore_missing_imports = True ================================================ FILE: requirements.txt ================================================ # please make sure you have already a pytorch install that is cuda enabled! av einops flashy>=0.0.1 hydra-core>=1.1 hydra_colorlog julius num2words numpy sentencepiece spacy==3.5.2 torch>=2.0.0 torchaudio>=2.0.0 huggingface_hub tqdm transformers>=4.31.0 # need Encodec there. xformers demucs librosa gradio torchmetrics encodec pytaglib ================================================ FILE: scripts/__init__.py ================================================ # Copyright (c) Meta Platforms, Inc. and affiliates. # All rights reserved. 
def normalize_path(path: Path):
    """Just to make path a bit nicer, make them relative to the Dora root dir.

    The path is resolved first, then expressed relative to `<dora dir>/xps`.
    """
    path = path.resolve()
    dora_dir = train.main.dora.dir.resolve() / 'xps'
    return path.relative_to(dora_dir)


def get_full_path(normalized_path: Path):
    """Revert `normalize_path`: prepend `<dora dir>/xps` to a normalized path.
    """
    return train.main.dora.dir.resolve() / 'xps' / normalized_path


def get_signature(xps: tp.List[str]):
    """Return a signature for a list of XP signatures.

    The signature is the first 10 hex digits of the SHA1 of the JSON encoded
    list, so it depends on the order of `xps`.
    """
    return sha1(json.dumps(xps).encode()).hexdigest()[:10]


def ensure_logged(func):
    """Ensure user is logged in.

    Decorator for views: when the session holds no `user`, redirect to the
    login page and pass the current URL as `redirect_to`.
    """
    @wraps(func)
    def _wrapped(*args, **kwargs):
        user = session.get('user')
        if user is None:
            return redirect(url_for('login', redirect_to=request.url))
        return func(*args, **kwargs)
    return _wrapped


@app.route('/login', methods=['GET', 'POST'])
def login():
    """Login user if not already, then redirect.

    The user name is read from the `user` form field on POST and stored in
    the session. An empty name re-renders the login page with an error.
    """
    user = session.get('user')
    if user is None:
        error = None
        if request.method == 'POST':
            user = request.form['user']
            if not user:
                error = 'User cannot be empty'
        if user is None or error:
            return render_template('login.html', error=error)
    assert user
    session['user'] = user
    # NOTE(review): `redirect_to` is taken from the query string without any
    # validation, so a crafted link can send the user to an external site
    # after login. Consider only accepting same-host targets.
    redirect_to = request.args.get('redirect_to')
    if redirect_to is None:
        redirect_to = url_for('index')
    return redirect(redirect_to)


@app.route('/', methods=['GET', 'POST'])
@ensure_logged
def index():
    """Offer to create a new study.

    On POST, the whitespace separated `xps` form field is interpreted as XP
    signatures and/or grid names. Grids are expanded to the XPs symlinked in
    their folder. When everything resolves, a survey folder with a
    `manifest.json` is created and the user is redirected to the survey.
    """
    errors = []
    if request.method == 'POST':
        xps_or_grids = [part.strip() for part in request.form['xps'].split()]
        xps = set()
        for xp_or_grid in xps_or_grids:
            xp_path = train.main.dora.dir / 'xps' / xp_or_grid
            if xp_path.exists():
                xps.add(xp_or_grid)
                continue
            grid_path = train.main.dora.dir / 'grids' / xp_or_grid
            if grid_path.exists():
                for child in grid_path.iterdir():
                    if child.is_symlink():
                        xps.add(child.name)
                continue
            errors.append(f'{xp_or_grid} is neither an XP nor a grid!')
        if not xps and not errors:
            # Empty form, or a grid folder that contains no XP symlink: report it
            # to the user rather than failing an assertion.
            errors.append('No XP found, please provide at least one XP or grid.')
        blind = 'true' if request.form.get('blind') == 'on' else 'false'
        # Sort so that the same set of XPs always maps to the same signature:
        # iteration order of a set of strings is not stable across processes.
        xps = sorted(xps)
        if not errors:
            signature = get_signature(xps)
            manifest = {
                'xps': xps,
            }
            survey_path = surveys / signature
            survey_path.mkdir(exist_ok=True)
            with open(survey_path / 'manifest.json', 'w') as f:
                json.dump(manifest, f, indent=2)
            return redirect(url_for('survey', blind=blind, signature=signature))
    return render_template('index.html', errors=errors)


# The view takes `signature` as argument, so the rule must declare it.
@app.route('/survey/<signature>', methods=['GET', 'POST'])
@ensure_logged
def survey(signature):
    """Show one page of samples for a survey and store the submitted ratings.

    Query arguments: `seed` (selects and orders the samples of the page),
    `blind`, `exclude_prompted`, `exclude_unprompted`, `max_epoch` and
    `success`. On a complete POST, ratings are written to
    `results/<user>_<seed>.json` in the survey folder and the user is
    redirected to the same survey with `seed + 1`.
    """
    success = request.args.get('success', False)
    seed = int(request.args.get('seed', 4321))
    blind = request.args.get('blind', 'false') in ['true', 'on', 'True']
    exclude_prompted = request.args.get('exclude_prompted', 'false') in ['true', 'on', 'True']
    exclude_unprompted = request.args.get('exclude_unprompted', 'false') in ['true', 'on', 'True']
    max_epoch = int(request.args.get('max_epoch', '-1'))
    survey_path = surveys / signature
    assert survey_path.exists(), survey_path

    user = session['user']
    result_folder = survey_path / 'results'
    result_folder.mkdir(exist_ok=True)
    # NOTE(review): `user` is free text from the login form and ends up in a
    # file name; consider sanitizing it.
    result_file = result_folder / f'{user}_{seed}.json'

    with open(survey_path / 'manifest.json') as f:
        manifest = json.load(f)

    xps = [train.main.get_xp_from_sig(xp) for xp in manifest['xps']]
    names, ref_name = train.main.get_names(xps)

    samples_kwargs = {
        'exclude_prompted': exclude_prompted,
        'exclude_unprompted': exclude_unprompted,
        'max_epoch': max_epoch,
    }
    matched_samples = get_samples_for_xps(xps, epoch=-1, **samples_kwargs)  # fetch latest epoch
    # For each sample id, one entry per XP, in the same order as `xps`.
    models_by_id = {
        sample_id: [{
            'xp': xps[idx],
            'xp_name': names[idx],
            'model_id': f'{xps[idx].sig}-{sample.id}',
            'sample': sample,
            'is_prompted': sample.prompt is not None,
            'errors': [],
        } for idx, sample in enumerate(samples)]
        for sample_id, samples in matched_samples.items()
    }
    experiments = [
        {'xp': xp, 'name': names[idx], 'epoch': list(matched_samples.values())[0][idx].epoch}
        for idx, xp in enumerate(xps)
    ]

    # Sort before shuffling so that a given seed always selects the same page.
    keys = list(matched_samples.keys())
    keys.sort()
    rng = random.Random(seed)
    rng.shuffle(keys)
    model_ids = keys[:SAMPLES_PER_PAGE]

    if blind:
        for key in model_ids:
            rng.shuffle(models_by_id[key])

    ok = True
    if request.method == 'POST':
        all_samples_results = []
        for sample_id in model_ids:
            models = models_by_id[sample_id]
            result = {
                'id': sample_id,
                'is_prompted': models[0]['is_prompted'],
                'models': {}
            }
            all_samples_results.append(result)
            for model in models:
                rating = request.form[model['model_id']]
                if rating:
                    rating = int(rating)
                    assert rating <= MAX_RATING and rating >= 1
                    result['models'][model['xp'].sig] = rating
                    model['rating'] = rating
                else:
                    ok = False
                    model['errors'].append('Please rate this model.')
        if ok:
            result = {
                'results': all_samples_results,
                'seed': seed,
                'user': user,
                'blind': blind,
                'exclude_prompted': exclude_prompted,
                'exclude_unprompted': exclude_unprompted,
            }
            print(result)
            with open(result_file, 'w') as f:
                json.dump(result, f)
            seed = seed + 1
            return redirect(url_for(
                'survey', signature=signature, blind=blind, seed=seed,
                exclude_prompted=exclude_prompted, exclude_unprompted=exclude_unprompted,
                max_epoch=max_epoch, success=True))

    ratings = list(range(1, MAX_RATING + 1))
    return render_template(
        'survey.html', ratings=ratings, blind=blind, seed=seed, signature=signature, success=success,
        exclude_prompted=exclude_prompted, exclude_unprompted=exclude_unprompted, max_epoch=max_epoch,
        experiments=experiments, models_by_id=models_by_id, model_ids=model_ids, errors=[],
        ref_name=ref_name, already_filled=result_file.exists())


# `path` contains slashes, hence the `path` converter in the rule.
@app.route('/audio/<path:path>')
def audio(path: str):
    """Serve the `.mp3` or `.wav` file located at the absolute path `/<path>`.

    Answers 404 for any other extension or when the file does not exist.
    """
    # NOTE(review): this view is not behind `ensure_logged` and serves any
    # .mp3/.wav file readable by the server, wherever it is on disk. Consider
    # restricting it to the folder holding the generated samples.
    content_types = {'.mp3': 'audio/mpeg', '.wav': 'audio/wav'}
    full_path = Path('/') / path
    content_type = content_types.get(full_path.suffix)
    if content_type is None or not full_path.is_file():
        return 'Audio file not found.', 404
    return full_path.read_bytes(), {'Content-Type': content_type}


def mean(x):
    """Return the arithmetic mean of the non empty sequence `x`."""
    return sum(x) / len(x)


def std(x):
    """Return the population standard deviation (divides by `len(x)`) of `x`."""
    m = mean(x)
    return math.sqrt(sum((i - m)**2 for i in x) / len(x))
def read_txt_files(path: tp.Union[str, Path]):
    """Read a list of audio files from a text file, one path per line.

    Lines whose extension is `.json`, `.txt` or `.csv` are dropped.

    Args:
        path: Location of the text file to read.
    Returns:
        list of str: The remaining lines, stripped of trailing whitespace.
    """
    # Read from the `path` argument, not from the global `args`, which only
    # exists when this file is run as a script.
    with open(path) as f:
        lines = [line.rstrip() for line in f]
    print(f"Read {len(lines)} in .txt")
    lines = [line for line in lines if Path(line).suffix not in ['.json', '.txt', '.csv']]
    print(f"Filtered and keep {len(lines)} from .txt")
    return lines


def read_egs_files(path: tp.Union[str, Path]):
    """Read the audio paths listed in an egs manifest.

    Args:
        path: Either a manifest file, or a folder containing `data.jsonl`
            or `data.jsonl.gz`.
    Returns:
        list: The `path` attribute of every entry returned by `load_audio_meta`.
    Raises:
        ValueError: If `path` is a folder that contains neither manifest.
    """
    path = Path(path)
    if path.is_dir():
        if (path / 'data.jsonl').exists():
            path = path / 'data.jsonl'
        elif (path / 'data.jsonl.gz').exists():
            path = path / 'data.jsonl.gz'
        else:
            raise ValueError("Don't know where to read metadata from in the dir. "
                             "Expecting either a data.jsonl or data.jsonl.gz file but none found.")
    meta = load_audio_meta(path)
    return [m.path for m in meta]


def process_dataset(args, n_shards: int, node_index: int, task_index: tp.Optional[int] = None):
    """Resample the files of one shard and write them under `args.out_path`.

    File number `idx` belongs to the shard `idx % n_shards`. For each file of
    the current shard, the audio is converted to the requested sample rate and
    number of channels (at most 2), written under `args.out_path` with the path
    relative to `args.root_path`, the `.json` metadata next to the source is
    copied when present, and a `.done` token is created so that the file is
    skipped on a later run. Errors on a single file are printed and skipped.

    Args:
        args: Parsed command line arguments of this script.
        n_shards: Total number of shards.
        node_index: Index of the node running this shard.
        task_index: Index of the task on the node. When None, it is read from
            the submitit job environment.
    """
    if task_index is None:
        env = submitit.JobEnvironment()
        task_index = env.global_rank
    shard_index = node_index * args.tasks_per_node + task_index

    if args.files_path is None:
        lines = [m.path for m in find_audio_files(args.root_path, resolve=False, progress=True, workers=8)]
    else:
        files_path = Path(args.files_path)
        if files_path.suffix == '.txt':
            print(f"Reading file list from .txt file: {args.files_path}")
            lines = read_txt_files(args.files_path)
        else:
            print(f"Reading file list from egs: {args.files_path}")
            lines = read_egs_files(args.files_path)

    total_files = len(lines)
    print(
        f"Total of {total_files} processed with {n_shards} shards. " +
        f"Current idx = {shard_index} -> {total_files // n_shards} files to process"
    )

    # The prefix to strip is the same for every file, compute it once.
    root_path = str(args.root_path)
    if not root_path.endswith('/'):
        root_path += '/'

    # `enumerate` has no length, give the total to tqdm explicitly.
    for idx, line in tqdm.tqdm(enumerate(lines), total=total_files):

        # skip if not part of this shard
        if idx % n_shards != shard_index:
            continue

        path = str(AudioCraftEnvironment.apply_dataset_mappers(line))
        assert path.startswith(str(root_path)), \
            f"Mismatch between path and provided root: {path} VS {root_path}"

        try:
            metadata_path = Path(path).with_suffix('.json')
            out_path = args.out_path / path[len(root_path):]
            out_metadata_path = out_path.with_suffix('.json')
            out_done_token = out_path.with_suffix('.done')

            # don't reprocess existing files
            if out_done_token.exists():
                continue

            print(idx, out_path, path)
            mix, sr = audio_read(path)
            mix_channels = args.channels if args.channels is not None and args.channels > 0 else mix.size(0)
            # enforce simple stereo
            out_channels = mix_channels
            if out_channels > 2:
                print(f"Mix has more than two channels: {out_channels}, enforcing 2 channels")
                out_channels = 2
            out_sr = args.sample_rate if args.sample_rate is not None else sr
            out_wav = convert_audio(mix, sr, out_sr, out_channels)
            audio_write(out_path.with_suffix(''), out_wav, sample_rate=out_sr,
                        format=args.format, normalize=False, strategy='clip')
            if metadata_path.exists():
                shutil.copy(metadata_path, out_metadata_path)
            else:
                print(f"No metadata found at {str(metadata_path)}")
            out_done_token.touch()
        except Exception as e:
            # Best effort: report the failure and move on to the next file.
            print(f"Error processing file line: {line}, {e}")


if __name__ == '__main__':
    parser = argparse.ArgumentParser(description="Resample dataset with SLURM.")
    parser.add_argument(
        "--log_root",
        type=Path,
        default=Path.home() / 'tmp' / 'resample_logs',
    )
    parser.add_argument(
        "--files_path",
        type=Path,
        help="List of files to process, either .txt (one file per line) or a jsonl[.gz].",
    )
    parser.add_argument(
        "--root_path",
        type=Path,
        required=True,
        help="When rewriting paths, this will be the prefix to remove.",
    )
    parser.add_argument(
        "--out_path",
        type=Path,
        required=True,
        help="When rewriting paths, `root_path` will be replaced by this.",
    )
    parser.add_argument("--xp_name", type=str, default="shutterstock")
    parser.add_argument(
        "--nodes",
        type=int,
        default=4,
    )
    parser.add_argument(
        "--tasks_per_node",
        type=int,
        default=20,
    )
    parser.add_argument(
        "--cpus_per_task",
        type=int,
        default=4,
    )
    parser.add_argument(
        "--memory_gb",
        type=int,
        help="Memory in GB."
    )
    parser.add_argument(
        "--format",
        type=str,
        default="wav",
    )
    parser.add_argument(
        "--sample_rate",
        type=int,
        default=32000,
    )
    parser.add_argument(
        "--channels",
        type=int,
    )
    parser.add_argument(
        "--partition",
        default='learnfair',
    )
    parser.add_argument("--qos")
    parser.add_argument("--account")
    parser.add_argument("--timeout", type=int, default=4320)
    parser.add_argument('--debug', action='store_true', help='debug mode (local run)')
    args = parser.parse_args()
    n_shards = args.tasks_per_node * args.nodes
    if args.files_path is None:
        print("Warning: --files_path not provided, not recommended when processing more than 10k files.")
    if args.debug:
        # Local run: only the first shard is processed, in the current process.
        print("Debugging mode")
        process_dataset(args, n_shards=n_shards, node_index=0, task_index=0)
    else:
        log_folder = Path(args.log_root) / args.xp_name / '%j'
        print(f"Logging to: {log_folder}")
        log_folder.parent.mkdir(parents=True, exist_ok=True)
        executor = submitit.AutoExecutor(folder=str(log_folder))
        if args.qos:
            executor.update_parameters(slurm_partition=args.partition, slurm_qos=args.qos, slurm_account=args.account)
        else:
            executor.update_parameters(slurm_partition=args.partition)
        executor.update_parameters(
            slurm_job_name=args.xp_name, timeout_min=args.timeout,
            cpus_per_task=args.cpus_per_task, tasks_per_node=args.tasks_per_node, nodes=1)
        if args.memory_gb:
            executor.update_parameters(mem=f'{args.memory_gb}GB')
        # One single-node job is submitted per node index.
        jobs = []
        with executor.batch():
            for node_index in range(args.nodes):
                job = executor.submit(process_dataset, args, n_shards=n_shards, node_index=node_index)
                jobs.append(job)
        for job in jobs:
            print(f"Waiting on job {job.job_id}")
            job.results()
.rating { background-color: grey; padding-top: 5px; padding-bottom: 5px; padding-left: 8px; padding-right: 8px; margin-right: 2px; cursor:pointer; } .rating_selected { background-color: purple; } .content { font-family: sans-serif; background-color: #f6f6f6; padding: 40px; margin: 0 auto; max-width: 1000px; } .track label { padding-top: 10px; padding-bottom: 10px; } .track { padding: 15px; margin: 5px; background-color: #c8c8c8; } .submit-big { width:400px; height:30px; font-size: 20px; } .error { color: red; } .ratings { margin-left: 10px; } .important { font-weight: bold; } .survey { margin-bottom: 100px; } .success { color: #25901b; font-weight: bold; } .warning { color: #8a1f19; font-weight: bold; } .track>section { display: flex; align-items: center; } .prompt { display: flex; align-items: center; } .track>section>div { padding-left: 10px; } audio { max-width: 280px; max-height: 40px; margin-left: 10px; margin-right: 10px; } .special { font-weight: bold; color: #2c2c2c; } ================================================ FILE: scripts/templates/base.html ================================================ {% block head %} AudioCraft — MOS {% endblock %}

AudioCraft — MOS

{% block content %}{% endblock %}
================================================ FILE: scripts/templates/index.html ================================================ {% extends "base.html" %} {% block content %}

Welcome {{session['user']}} to the internal MOS assistant for AudioCraft. You can create custom surveys between your models, that you can evaluate yourself, or with the help of your teammates, by simply sharing a link!

{% for error in errors %}

{{error}}

{% endfor %}

{% endblock %} ================================================ FILE: scripts/templates/login.html ================================================ {% extends "base.html" %} {% block content %}

You must identify yourself first! We use a highly secured protocol where you just decide your username, and that's it. No password, no encryption, just pure trust.

{% if error %}

{{error}}

{% endif %} {% endblock %} ================================================ FILE: scripts/templates/results.html ================================================ {% extends "base.html" %} {% block content %}

Results for survey #{{signature}}

Check out the survey page for details on the models.

The following users voted: {% for user in users %} {{user}} {% endfor %} {% for model in models %}

{{model['sig']}} ({{model['samples']}} samples)

Ratings: {{model['mean_rating']}} ± {{model['std_rating']}}

{% endfor %} {% endblock %} ================================================ FILE: scripts/templates/survey.html ================================================ {% extends "base.html" %} {% block content %}

Survey #{{signature}}

{% if success %}

Your ratings have been saved! You have been moved to the next random seed, if you want to keep rating more samples.

{% endif %} {% if already_filled %}

You already rated those samples in the past, filling this form will override your previous ratings.

{% endif %}

Welcome {{session['user']}} to the survey #{{signature}}. Go to the result page to check the results. Go to the home page to start a new survey.

{% for error in errors %}

{{error}}

{% endfor %} {% if not blind %}

Base config is: {{ref_name}}

The following experiments are compared:

    {% for experiment in experiments %}
  • {{experiment.xp.sig}} ({{experiment.epoch}} epochs): {{experiment.name}}
  • {% endfor %}
{% else %}

This is a blind experiment, the order of all XPs is shuffled with every sample.

{% endif %}

The current random seed is {{seed}}. You can change it with the following form, and also update blind/non blind.

Samples

{% for id in model_ids %}

{{id}}

{% for model in models_by_id[id] %} {% if loop.index == 1 and model.is_prompted %}

Prompt is

Ground truth is

{% endif %} {% for err in model['errors'] %}

{{err}}

{% endfor %}
{% if not blind %}

{{model.xp.sig}}:

{% endif %}

Rating:

{% for rating in ratings %} {{rating}} {% endfor %}

{% endfor %}

{% endfor %}
NAME = 'audiocraft'
DESCRIPTION = 'Audio generation research library for PyTorch'

URL = 'https://github.com/facebookresearch/audiocraft'
AUTHOR = 'FAIR Speech & Audio'
EMAIL = 'defossez@meta.com, jadecopet@meta.com'
REQUIRES_PYTHON = '>=3.8.0'

# All files are located relative to this script, not to the current directory.
HERE = Path(__file__).parent

# The version is read from the line of `audiocraft/__init__.py` defining
# `__version__`, without importing the package.
VERSION = None
with open(HERE / 'audiocraft' / '__init__.py') as f:
    for line in f:
        line = line.strip()
        if '__version__' in line:
            context = {}
            exec(line, context)
            VERSION = context['__version__']
if VERSION is None:
    raise RuntimeError("Could not find `__version__` in audiocraft/__init__.py")

try:
    with open(HERE / "README.md", encoding='utf-8') as f:
        long_description = '\n' + f.read()
except FileNotFoundError:
    long_description = DESCRIPTION

# One requirement per line, skipping blank lines and full line comments.
with open(HERE / 'requirements.txt') as f:
    REQUIRED = [i.strip() for i in f if i.strip() and not i.startswith('#')]

setup(
    name=NAME,
    version=VERSION,
    description=DESCRIPTION,
    author_email=EMAIL,
    long_description=long_description,
    long_description_content_type='text/markdown',
    author=AUTHOR,
    url=URL,
    python_requires=REQUIRES_PYTHON,
    install_requires=REQUIRED,
    extras_require={
        'dev': ['coverage', 'flake8', 'mypy', 'pdoc3', 'pytest'],
    },
    packages=find_packages(),
    package_data={'audiocraft': ['py.typed']},
    include_package_data=True,
    license='MIT License',
    classifiers=[
        # Trove classifiers
        # Full list: https://pypi.python.org/pypi?%3Aaction=list_classifiers
        'License :: OSI Approved :: MIT License',
        'Topic :: Multimedia :: Sound/Audio',
        'Topic :: Scientific/Engineering :: Artificial Intelligence',
    ],
)
class TestMultiPeriodDiscriminator:
    """Output structure of `MultiPeriodDiscriminator` on random audio."""

    def test_mpd_discriminator(self):
        batch, channels = 2, 2
        length = random.randrange(1, 100_000)
        wav = torch.randn(batch, channels, length)
        periods = [1, 2, 3]
        discriminator = MultiPeriodDiscriminator(periods=periods, in_channels=channels)
        logits, fmaps = discriminator(wav)

        # one logit and one list of feature maps per period
        assert len(logits) == len(periods)
        assert len(fmaps) == len(periods)
        for logit in logits:
            assert logit.shape[0] == batch and logit.ndim == 4
        for fmap in fmaps:
            for feature in fmap:
                assert feature.shape[0] == batch


class TestMultiScaleDiscriminator:
    """Output structure of `MultiScaleDiscriminator` on random audio."""

    def test_msd_discriminator(self):
        batch, channels = 2, 2
        length = random.randrange(1, 100_000)
        wav = torch.randn(batch, channels, length)
        scale_norms = ['weight_norm', 'weight_norm']
        discriminator = MultiScaleDiscriminator(scale_norms=scale_norms, in_channels=channels)
        logits, fmaps = discriminator(wav)

        # one logit and one list of feature maps per scale
        assert len(logits) == len(scale_norms)
        assert len(fmaps) == len(scale_norms)
        for logit in logits:
            assert logit.shape[0] == batch and logit.ndim == 3
        for fmap in fmaps:
            for feature in fmap:
                assert feature.shape[0] == batch


class TestMultiScaleStftDiscriminator:
    """Output structure of `MultiScaleSTFTDiscriminator` on random audio."""

    def test_msstftd_discriminator(self):
        batch, channels = 2, 2
        length = random.randrange(1, 100_000)
        wav = torch.randn(batch, channels, length)

        stft_config = {
            'n_ffts': [128, 256, 64],
            'hop_lengths': [32, 64, 16],
            'win_lengths': [128, 256, 64],
        }
        discriminator = MultiScaleSTFTDiscriminator(filters=4, in_channels=channels, **stft_config)
        logits, fmaps = discriminator(wav)

        # one logit and one list of feature maps per STFT resolution
        n_scales = len(stft_config['n_ffts'])
        assert len(logits) == n_scales
        assert len(fmaps) == n_scales
        for logit in logits:
            assert logit.shape[0] == batch and logit.ndim == 4
        for fmap in fmaps:
            for feature in fmap:
                assert feature.shape[0] == batch
class TestAdversarialLoss:
    """Tests for `AdversarialLoss` wrapping a `MultiScaleDiscriminator`."""

    def test_adversarial_single_multidiscriminator(self):
        adv = MultiScaleDiscriminator()
        optimizer = torch.optim.Adam(
            adv.parameters(),
            lr=1e-4,
        )
        loss, loss_real, loss_fake = get_adv_criterion('mse'), get_real_criterion('mse'), get_fake_criterion('mse')
        adv_loss = AdversarialLoss(adv, optimizer, loss, loss_real, loss_fake)

        B, C, T = 4, 1, random.randint(1000, 5000)
        real = torch.randn(B, C, T)
        fake = torch.randn(B, C, T)

        disc_loss = adv_loss.train_adv(fake, real)
        assert isinstance(disc_loss, torch.Tensor) and isinstance(disc_loss.item(), float)

        loss, loss_feat = adv_loss(fake, real)
        assert isinstance(loss, torch.Tensor) and isinstance(loss.item(), float)
        # we did not specify feature loss
        assert loss_feat.item() == 0.

    def test_adversarial_feat_loss(self):
        adv = MultiScaleDiscriminator()
        optimizer = torch.optim.Adam(
            adv.parameters(),
            lr=1e-4,
        )
        loss, loss_real, loss_fake = get_adv_criterion('mse'), get_real_criterion('mse'), get_fake_criterion('mse')
        feat_loss = FeatureMatchingLoss()
        adv_loss = AdversarialLoss(adv, optimizer, loss, loss_real, loss_fake, feat_loss)

        B, C, T = 4, 1, random.randint(1000, 5000)
        real = torch.randn(B, C, T)
        fake = torch.randn(B, C, T)

        loss, loss_feat = adv_loss(fake, real)

        assert isinstance(loss, torch.Tensor) and isinstance(loss.item(), float)
        # check the feature matching loss itself, not the adversarial loss again
        assert isinstance(loss_feat, torch.Tensor) and isinstance(loss_feat.item(), float)


class TestGeneratorAdversarialLoss:
    """Expected values of the generator adversarial criteria."""

    def test_hinge_generator_adv_loss(self):
        adv_loss = get_adv_criterion(loss_type='hinge')

        t0 = torch.randn(1, 2, 0)
        t1 = torch.FloatTensor([1.0, 2.0, 3.0])

        assert adv_loss(t0).item() == 0.0
        assert adv_loss(t1).item() == -2.0

    def test_mse_generator_adv_loss(self):
        adv_loss = get_adv_criterion(loss_type='mse')

        t0 = torch.randn(1, 2, 0)
        t1 = torch.FloatTensor([1.0, 1.0, 1.0])
        t2 = torch.FloatTensor([2.0, 5.0, 5.0])

        assert adv_loss(t0).item() == 0.0
        assert adv_loss(t1).item() == 0.0
        assert adv_loss(t2).item() == 11.0


class TestDiscriminatorAdversarialLoss:
    """Expected values of the discriminator criteria (fake loss + real loss)."""

    def _disc_loss(self, loss_type: str, fake: torch.Tensor, real: torch.Tensor):
        """Return the sum of the fake criterion on `fake` and the real criterion on `real`."""
        disc_loss_real = get_real_criterion(loss_type)
        disc_loss_fake = get_fake_criterion(loss_type)

        loss = disc_loss_fake(fake) + disc_loss_real(real)
        return loss

    def test_hinge_discriminator_adv_loss(self):
        loss_type = 'hinge'
        t0 = torch.FloatTensor([0.0, 0.0, 0.0])
        t1 = torch.FloatTensor([1.0, 2.0, 3.0])

        assert self._disc_loss(loss_type, t0, t0).item() == 2.0
        assert self._disc_loss(loss_type, t1, t1).item() == 3.0

    def test_mse_discriminator_adv_loss(self):
        loss_type = 'mse'

        t0 = torch.FloatTensor([0.0, 0.0, 0.0])
        t1 = torch.FloatTensor([1.0, 1.0, 1.0])

        assert self._disc_loss(loss_type, t0, t0).item() == 1.0
        assert self._disc_loss(loss_type, t1, t0).item() == 2.0


class TestFeatureMatchingLoss:
    """Tests for `FeatureMatchingLoss` with and without normalization."""

    def test_features_matching_loss_base(self):
        ft_matching_loss = FeatureMatchingLoss()
        length = random.randrange(1, 100_000)
        t1 = torch.randn(1, 2, length)

        loss = ft_matching_loss([t1], [t1])
        assert isinstance(loss, torch.Tensor)
        assert loss.item() == 0.0

    def test_features_matching_loss_raises_exception(self):
        ft_matching_loss = FeatureMatchingLoss()
        length = random.randrange(1, 100_000)
        t1 = torch.randn(1, 2, length)
        t2 = torch.randn(1, 2, length + 1)

        # no feature at all
        with pytest.raises(AssertionError):
            ft_matching_loss([], [])

        # different number of features
        with pytest.raises(AssertionError):
            ft_matching_loss([t1], [t1, t1])

        # features of different shapes
        with pytest.raises(AssertionError):
            ft_matching_loss([t1], [t2])

    def test_features_matching_loss_output(self):
        loss_nonorm = FeatureMatchingLoss(normalize=False)
        loss_layer_normed = FeatureMatchingLoss(normalize=True)

        length = random.randrange(1, 100_000)
        t1 = torch.randn(1, 2, length)
        t2 = torch.randn(1, 2, length)

        assert loss_nonorm([t1, t2], [t1, t2]).item() == 0.0
        assert loss_layer_normed([t1, t2], [t1, t2]).item() == 0.0

        t3 = torch.FloatTensor([1.0, 2.0, 3.0])
        t4 = torch.FloatTensor([2.0, 10.0, 3.0])

        assert loss_nonorm([t3], [t4]).item() == 3.0
        assert loss_nonorm([t3, t3], [t4, t4]).item() == 6.0

        assert loss_layer_normed([t3], [t4]).item() == 3.0
        assert loss_layer_normed([t3, t3], [t4, t4]).item() == 3.0
class TempDirMixin:
    """Mixin to provide easy access to temp dir.

    Each class using the mixin gets its own sub folder, named after the class,
    inside a base temporary directory that is created on first use.
    """

    # `tempfile.TemporaryDirectory` created lazily by `get_base_temp_dir`.
    temp_dir_ = None

    @classmethod
    def get_base_temp_dir(cls):
        """Return the base directory under which temporary files are created."""
        # If AUDIOCRAFT_TEST_DIR is set, use it instead of temporary directory.
        # this is handy for debugging.
        key = "AUDIOCRAFT_TEST_DIR"
        if key in os.environ:
            return os.environ[key]
        if cls.temp_dir_ is None:
            cls.temp_dir_ = tempfile.TemporaryDirectory()
        return cls.temp_dir_.name

    @classmethod
    def tearDownClass(cls):
        """Remove the temporary directory, if one was created."""
        if cls.temp_dir_ is not None:
            try:
                cls.temp_dir_.cleanup()
                cls.temp_dir_ = None
            except PermissionError:
                # On Windows there is a know issue with `shutil.rmtree`,
                # which fails intermittently.
                # https://github.com/python/cpython/issues/74168
                # Following the above thread, we ignore it.
                pass
        # The mixin is also used on plain classes, with no `unittest.TestCase`
        # in the MRO: only forward the call when a parent actually defines it.
        parent_tear_down = getattr(super(), 'tearDownClass', None)
        if parent_tear_down is not None:
            parent_tear_down()

    @property
    def id(self):
        """Name of the concrete class, used as sub folder name."""
        return self.__class__.__name__

    def get_temp_path(self, *paths):
        """Return the path of a file in the class folder, creating its parent folders.

        The file itself is not created.
        """
        temp_dir = os.path.join(self.get_base_temp_dir(), self.id)
        path = os.path.join(temp_dir, *paths)
        os.makedirs(os.path.dirname(path), exist_ok=True)
        return path

    def get_temp_dir(self, *paths):
        """Return the path of a folder in the class folder, creating it if needed."""
        temp_dir = os.path.join(self.get_base_temp_dir(), self.id)
        path = os.path.join(temp_dir, *paths)
        os.makedirs(path, exist_ok=True)
        return path
from pathlib import Path import typing as tp import torch import torchaudio def get_white_noise(chs: int = 1, num_frames: int = 1): wav = torch.randn(chs, num_frames) return wav def get_batch_white_noise(bs: int = 1, chs: int = 1, num_frames: int = 1): wav = torch.randn(bs, chs, num_frames) return wav def save_wav(path: str, wav: torch.Tensor, sample_rate: int): fp = Path(path) kwargs: tp.Dict[str, tp.Any] = {} if fp.suffix == '.wav': kwargs['encoding'] = 'PCM_S' kwargs['bits_per_sample'] = 16 elif fp.suffix == '.mp3': kwargs['compression'] = 320 torchaudio.save(str(fp), wav, sample_rate, **kwargs) ================================================ FILE: tests/data/__init__.py ================================================ # Copyright (c) Meta Platforms, Inc. and affiliates. # All rights reserved. # # This source code is licensed under the license found in the # LICENSE file in the root directory of this source tree. ================================================ FILE: tests/data/test_audio.py ================================================ # Copyright (c) Meta Platforms, Inc. and affiliates. # All rights reserved. # # This source code is licensed under the license found in the # LICENSE file in the root directory of this source tree. from itertools import product import random import numpy as np import torch import torchaudio from audiocraft.data.audio import audio_info, audio_read, audio_write, _av_read from ..common_utils import TempDirMixin, get_white_noise, save_wav class TestInfo(TempDirMixin): def test_info_mp3(self): sample_rates = [8000, 16_000] channels = [1, 2] duration = 1. 
for sample_rate, ch in product(sample_rates, channels): wav = get_white_noise(ch, int(sample_rate * duration)) path = self.get_temp_path('sample_wav.mp3') save_wav(path, wav, sample_rate) info = audio_info(path) assert info.sample_rate == sample_rate assert info.channels == ch # we cannot trust torchaudio for num_frames, so we don't check def _test_info_format(self, ext: str): sample_rates = [8000, 16_000] channels = [1, 2] duration = 1. for sample_rate, ch in product(sample_rates, channels): n_frames = int(sample_rate * duration) wav = get_white_noise(ch, n_frames) path = self.get_temp_path(f'sample_wav{ext}') save_wav(path, wav, sample_rate) info = audio_info(path) assert info.sample_rate == sample_rate assert info.channels == ch assert np.isclose(info.duration, duration, atol=1e-5) def test_info_wav(self): self._test_info_format('.wav') def test_info_flac(self): self._test_info_format('.flac') def test_info_ogg(self): self._test_info_format('.ogg') def test_info_m4a(self): # TODO: generate m4a file programmatically # self._test_info_format('.m4a') pass class TestRead(TempDirMixin): def test_read_full_wav(self): sample_rates = [8000, 16_000] channels = [1, 2] duration = 1. for sample_rate, ch in product(sample_rates, channels): n_frames = int(sample_rate * duration) wav = get_white_noise(ch, n_frames).clamp(-0.99, 0.99) path = self.get_temp_path('sample_wav.wav') save_wav(path, wav, sample_rate) read_wav, read_sr = audio_read(path) assert read_sr == sample_rate assert read_wav.shape[0] == wav.shape[0] assert read_wav.shape[1] == wav.shape[1] assert torch.allclose(read_wav, wav, rtol=1e-03, atol=1e-04) def test_read_partial_wav(self): sample_rates = [8000, 16_000] channels = [1, 2] duration = 1. 
read_duration = torch.rand(1).item() for sample_rate, ch in product(sample_rates, channels): n_frames = int(sample_rate * duration) read_frames = int(sample_rate * read_duration) wav = get_white_noise(ch, n_frames).clamp(-0.99, 0.99) path = self.get_temp_path('sample_wav.wav') save_wav(path, wav, sample_rate) read_wav, read_sr = audio_read(path, 0, read_duration) assert read_sr == sample_rate assert read_wav.shape[0] == wav.shape[0] assert read_wav.shape[1] == read_frames assert torch.allclose(read_wav[..., 0:read_frames], wav[..., 0:read_frames], rtol=1e-03, atol=1e-04) def test_read_seek_time_wav(self): sample_rates = [8000, 16_000] channels = [1, 2] duration = 1. read_duration = 1. for sample_rate, ch in product(sample_rates, channels): n_frames = int(sample_rate * duration) wav = get_white_noise(ch, n_frames).clamp(-0.99, 0.99) path = self.get_temp_path('sample_wav.wav') save_wav(path, wav, sample_rate) seek_time = torch.rand(1).item() read_wav, read_sr = audio_read(path, seek_time, read_duration) seek_frames = int(sample_rate * seek_time) expected_frames = n_frames - seek_frames assert read_sr == sample_rate assert read_wav.shape[0] == wav.shape[0] assert read_wav.shape[1] == expected_frames assert torch.allclose(read_wav, wav[..., seek_frames:], rtol=1e-03, atol=1e-04) def test_read_seek_time_wav_padded(self): sample_rates = [8000, 16_000] channels = [1, 2] duration = 1. read_duration = 1. 
for sample_rate, ch in product(sample_rates, channels): n_frames = int(sample_rate * duration) read_frames = int(sample_rate * read_duration) wav = get_white_noise(ch, n_frames).clamp(-0.99, 0.99) path = self.get_temp_path('sample_wav.wav') save_wav(path, wav, sample_rate) seek_time = torch.rand(1).item() seek_frames = int(sample_rate * seek_time) expected_frames = n_frames - seek_frames read_wav, read_sr = audio_read(path, seek_time, read_duration, pad=True) expected_pad_wav = torch.zeros(wav.shape[0], read_frames - expected_frames) assert read_sr == sample_rate assert read_wav.shape[0] == wav.shape[0] assert read_wav.shape[1] == read_frames assert torch.allclose(read_wav[..., :expected_frames], wav[..., seek_frames:], rtol=1e-03, atol=1e-04) assert torch.allclose(read_wav[..., expected_frames:], expected_pad_wav) class TestAvRead(TempDirMixin): def test_avread_seek_base(self): sample_rates = [8000, 16_000] channels = [1, 2] duration = 2. for sample_rate, ch in product(sample_rates, channels): n_frames = int(sample_rate * duration) wav = get_white_noise(ch, n_frames) path = self.get_temp_path(f'reference_a_{sample_rate}_{ch}.wav') save_wav(path, wav, sample_rate) for _ in range(100): # seek will always load a full duration segment in the file seek_time = random.uniform(0.0, 1.0) seek_duration = random.uniform(0.001, 1.0) read_wav, read_sr = _av_read(path, seek_time, seek_duration) assert read_sr == sample_rate assert read_wav.shape[0] == wav.shape[0] assert read_wav.shape[-1] == int(seek_duration * sample_rate) def test_avread_seek_partial(self): sample_rates = [8000, 16_000] channels = [1, 2] duration = 1. for sample_rate, ch in product(sample_rates, channels): n_frames = int(sample_rate * duration) wav = get_white_noise(ch, n_frames) path = self.get_temp_path(f'reference_b_{sample_rate}_{ch}.wav') save_wav(path, wav, sample_rate) for _ in range(100): # seek will always load a partial segment seek_time = random.uniform(0.5, 1.) seek_duration = 1. 
expected_num_frames = n_frames - int(seek_time * sample_rate) read_wav, read_sr = _av_read(path, seek_time, seek_duration) assert read_sr == sample_rate assert read_wav.shape[0] == wav.shape[0] assert read_wav.shape[-1] == expected_num_frames def test_avread_seek_outofbound(self): sample_rates = [8000, 16_000] channels = [1, 2] duration = 1. for sample_rate, ch in product(sample_rates, channels): n_frames = int(sample_rate * duration) wav = get_white_noise(ch, n_frames) path = self.get_temp_path(f'reference_c_{sample_rate}_{ch}.wav') save_wav(path, wav, sample_rate) seek_time = 1.5 read_wav, read_sr = _av_read(path, seek_time, 1.) assert read_sr == sample_rate assert read_wav.shape[0] == wav.shape[0] assert read_wav.shape[-1] == 0 def test_avread_seek_edge(self): sample_rates = [8000, 16_000] # some of these values will have # int(((frames - 1) / sample_rate) * sample_rate) != (frames - 1) n_frames = [1000, 1001, 1002] channels = [1, 2] for sample_rate, ch, frames in product(sample_rates, channels, n_frames): duration = frames / sample_rate wav = get_white_noise(ch, frames) path = self.get_temp_path(f'reference_d_{sample_rate}_{ch}.wav') save_wav(path, wav, sample_rate) seek_time = (frames - 1) / sample_rate seek_frames = int(seek_time * sample_rate) read_wav, read_sr = _av_read(path, seek_time, duration) assert read_sr == sample_rate assert read_wav.shape[0] == wav.shape[0] assert read_wav.shape[-1] == (frames - seek_frames) class TestAudioWrite(TempDirMixin): def test_audio_write_wav(self): torch.manual_seed(1234) sample_rates = [8000, 16_000] n_frames = [1000, 1001, 1002] channels = [1, 2] strategies = ["peak", "clip", "rms"] formats = ["wav", "mp3"] for sample_rate, ch, frames in product(sample_rates, channels, n_frames): for format_, strategy in product(formats, strategies): wav = get_white_noise(ch, frames) path = self.get_temp_path(f'pred_{sample_rate}_{ch}') audio_write(path, wav, sample_rate, format_, strategy=strategy) read_wav, read_sr = 
torchaudio.load(f'{path}.{format_}') if format_ == "wav": assert read_wav.shape == wav.shape if format_ == "wav" and strategy in ["peak", "rms"]: rescaled_read_wav = read_wav / read_wav.abs().max() * wav.abs().max() # for a Gaussian, the typical max scale will be less than ~5x the std. # The error when writing to disk will ~ 1/2**15, and when rescaling, 5x that. # For RMS target, rescaling leaves more headroom by default, leading # to a 20x rescaling typically atol = (5 if strategy == "peak" else 20) / 2**15 delta = (rescaled_read_wav - wav).abs().max() assert torch.allclose(wav, rescaled_read_wav, rtol=0, atol=atol), (delta, atol) formats = ["wav"] # faster unit tests ================================================ FILE: tests/data/test_audio_dataset.py ================================================ # Copyright (c) Meta Platforms, Inc. and affiliates. # All rights reserved. # # This source code is licensed under the license found in the # LICENSE file in the root directory of this source tree. from functools import partial from itertools import product import json import math import os import random import typing as tp import pytest import torch from torch.utils.data import DataLoader from audiocraft.data.audio_dataset import ( AudioDataset, AudioMeta, _get_audio_meta, load_audio_meta, save_audio_meta ) from audiocraft.data.zip import PathInZip from ..common_utils import TempDirMixin, get_white_noise, save_wav class TestAudioMeta(TempDirMixin): def test_get_audio_meta(self): sample_rates = [8000, 16_000] channels = [1, 2] duration = 1. 
for sample_rate, ch in product(sample_rates, channels): n_frames = int(duration * sample_rate) wav = get_white_noise(ch, n_frames) path = self.get_temp_path('sample.wav') save_wav(path, wav, sample_rate) m = _get_audio_meta(path, minimal=True) assert m.path == path, 'path does not match' assert m.sample_rate == sample_rate, 'sample rate does not match' assert m.duration == duration, 'duration does not match' assert m.amplitude is None assert m.info_path is None def test_save_audio_meta(self): audio_meta = [ AudioMeta("mypath1", 1., 16_000, None, None, PathInZip('/foo/bar.zip:/relative/file1.json')), AudioMeta("mypath2", 2., 16_000, None, None, PathInZip('/foo/bar.zip:/relative/file2.json')) ] empty_audio_meta = [] for idx, meta in enumerate([audio_meta, empty_audio_meta]): path = self.get_temp_path(f'data_{idx}_save.jsonl') save_audio_meta(path, meta) with open(path, 'r') as f: lines = f.readlines() read_meta = [AudioMeta.from_dict(json.loads(line)) for line in lines] assert len(read_meta) == len(meta) for m, read_m in zip(meta, read_meta): assert m == read_m def test_load_audio_meta(self): try: import dora except ImportError: dora = None # type: ignore audio_meta = [ AudioMeta("mypath1", 1., 16_000, None, None, PathInZip('/foo/bar.zip:/relative/file1.json')), AudioMeta("mypath2", 2., 16_000, None, None, PathInZip('/foo/bar.zip:/relative/file2.json')) ] empty_meta = [] for idx, meta in enumerate([audio_meta, empty_meta]): path = self.get_temp_path(f'data_{idx}_load.jsonl') with open(path, 'w') as f: for m in meta: json_str = json.dumps(m.to_dict()) + '\n' f.write(json_str) read_meta = load_audio_meta(path) assert len(read_meta) == len(meta) for m, read_m in zip(meta, read_meta): if dora: m.path = dora.git_save.to_absolute_path(m.path) assert m == read_m, f'original={m}, read={read_m}' class TestAudioDataset(TempDirMixin): def _create_audio_files(self, root_name: str, num_examples: int, durations: tp.Union[float, tp.Tuple[float, float]] = (0.1, 1.), sample_rate: int 
= 16_000, channels: int = 1): root_dir = self.get_temp_dir(root_name) for i in range(num_examples): if isinstance(durations, float): duration = durations elif isinstance(durations, tuple) and len(durations) == 1: duration = durations[0] elif isinstance(durations, tuple) and len(durations) == 2: duration = random.uniform(durations[0], durations[1]) else: assert False n_frames = int(duration * sample_rate) wav = get_white_noise(channels, n_frames) path = os.path.join(root_dir, f'example_{i}.wav') save_wav(path, wav, sample_rate) return root_dir def _create_audio_dataset(self, root_name: str, total_num_examples: int, durations: tp.Union[float, tp.Tuple[float, float]] = (0.1, 1.), sample_rate: int = 16_000, channels: int = 1, segment_duration: tp.Optional[float] = None, num_examples: int = 10, shuffle: bool = True, return_info: bool = False): root_dir = self._create_audio_files(root_name, total_num_examples, durations, sample_rate, channels) dataset = AudioDataset.from_path(root_dir, minimal_meta=True, segment_duration=segment_duration, num_samples=num_examples, sample_rate=sample_rate, channels=channels, shuffle=shuffle, return_info=return_info) return dataset def test_dataset_full(self): total_examples = 10 min_duration, max_duration = 1., 4. sample_rate = 16_000 channels = 1 dataset = self._create_audio_dataset( 'dset', total_examples, durations=(min_duration, max_duration), sample_rate=sample_rate, channels=channels, segment_duration=None) assert len(dataset) == total_examples assert dataset.sample_rate == sample_rate assert dataset.channels == channels for idx in range(len(dataset)): sample = dataset[idx] assert sample.shape[0] == channels assert sample.shape[1] <= int(max_duration * sample_rate) assert sample.shape[1] >= int(min_duration * sample_rate) def test_dataset_segment(self): total_examples = 10 num_samples = 20 min_duration, max_duration = 1., 4. segment_duration = 1. 
sample_rate = 16_000 channels = 1 dataset = self._create_audio_dataset( 'dset', total_examples, durations=(min_duration, max_duration), sample_rate=sample_rate, channels=channels, segment_duration=segment_duration, num_examples=num_samples) assert len(dataset) == num_samples assert dataset.sample_rate == sample_rate assert dataset.channels == channels for idx in range(len(dataset)): sample = dataset[idx] assert sample.shape[0] == channels assert sample.shape[1] == int(segment_duration * sample_rate) def test_dataset_equal_audio_and_segment_durations(self): total_examples = 1 num_samples = 2 audio_duration = 1. segment_duration = 1. sample_rate = 16_000 channels = 1 dataset = self._create_audio_dataset( 'dset', total_examples, durations=audio_duration, sample_rate=sample_rate, channels=channels, segment_duration=segment_duration, num_examples=num_samples) assert len(dataset) == num_samples assert dataset.sample_rate == sample_rate assert dataset.channels == channels for idx in range(len(dataset)): sample = dataset[idx] assert sample.shape[0] == channels assert sample.shape[1] == int(segment_duration * sample_rate) # the random seek_time adds variability on audio read sample_1 = dataset[0] sample_2 = dataset[1] assert not torch.allclose(sample_1, sample_2) def test_dataset_samples(self): total_examples = 1 num_samples = 2 audio_duration = 1. segment_duration = 1. 
sample_rate = 16_000 channels = 1 create_dataset = partial( self._create_audio_dataset, 'dset', total_examples, durations=audio_duration, sample_rate=sample_rate, channels=channels, segment_duration=segment_duration, num_examples=num_samples, ) dataset = create_dataset(shuffle=True) # when shuffle = True, we have different inputs for the same index across epoch sample_1 = dataset[0] sample_2 = dataset[0] assert not torch.allclose(sample_1, sample_2) dataset_noshuffle = create_dataset(shuffle=False) # when shuffle = False, we have same inputs for the same index across epoch sample_1 = dataset_noshuffle[0] sample_2 = dataset_noshuffle[0] assert torch.allclose(sample_1, sample_2) def test_dataset_return_info(self): total_examples = 10 num_samples = 20 min_duration, max_duration = 1., 4. segment_duration = 1. sample_rate = 16_000 channels = 1 dataset = self._create_audio_dataset( 'dset', total_examples, durations=(min_duration, max_duration), sample_rate=sample_rate, channels=channels, segment_duration=segment_duration, num_examples=num_samples, return_info=True) assert len(dataset) == num_samples assert dataset.sample_rate == sample_rate assert dataset.channels == channels for idx in range(len(dataset)): sample, segment_info = dataset[idx] assert sample.shape[0] == channels assert sample.shape[1] == int(segment_duration * sample_rate) assert segment_info.sample_rate == sample_rate assert segment_info.total_frames == int(segment_duration * sample_rate) assert segment_info.n_frames <= int(segment_duration * sample_rate) assert segment_info.seek_time >= 0 def test_dataset_return_info_no_segment_duration(self): total_examples = 10 num_samples = 20 min_duration, max_duration = 1., 4. 
segment_duration = None sample_rate = 16_000 channels = 1 dataset = self._create_audio_dataset( 'dset', total_examples, durations=(min_duration, max_duration), sample_rate=sample_rate, channels=channels, segment_duration=segment_duration, num_examples=num_samples, return_info=True) assert len(dataset) == total_examples assert dataset.sample_rate == sample_rate assert dataset.channels == channels for idx in range(len(dataset)): sample, segment_info = dataset[idx] assert sample.shape[0] == channels assert sample.shape[1] == segment_info.total_frames assert segment_info.sample_rate == sample_rate assert segment_info.n_frames <= segment_info.total_frames def test_dataset_collate_fn(self): total_examples = 10 num_samples = 20 min_duration, max_duration = 1., 4. segment_duration = 1. sample_rate = 16_000 channels = 1 dataset = self._create_audio_dataset( 'dset', total_examples, durations=(min_duration, max_duration), sample_rate=sample_rate, channels=channels, segment_duration=segment_duration, num_examples=num_samples, return_info=False) batch_size = 4 dataloader = DataLoader( dataset, batch_size=batch_size, num_workers=0 ) for idx, batch in enumerate(dataloader): assert batch.shape[0] == batch_size @pytest.mark.parametrize("segment_duration", [1.0, None]) def test_dataset_with_meta_collate_fn(self, segment_duration): total_examples = 10 num_samples = 20 min_duration, max_duration = 1., 4. segment_duration = 1. 
sample_rate = 16_000 channels = 1 dataset = self._create_audio_dataset( 'dset', total_examples, durations=(min_duration, max_duration), sample_rate=sample_rate, channels=channels, segment_duration=segment_duration, num_examples=num_samples, return_info=True) batch_size = 4 dataloader = DataLoader( dataset, batch_size=batch_size, collate_fn=dataset.collater, num_workers=0 ) for idx, batch in enumerate(dataloader): wav, infos = batch assert wav.shape[0] == batch_size assert len(infos) == batch_size @pytest.mark.parametrize("segment_duration,sample_on_weight,sample_on_duration,a_hist,b_hist,c_hist", [ [1, True, True, 0.5, 0.5, 0.0], [1, False, True, 0.25, 0.5, 0.25], [1, True, False, 0.666, 0.333, 0.0], [1, False, False, 0.333, 0.333, 0.333], [None, False, False, 0.333, 0.333, 0.333]]) def test_sample_with_weight(self, segment_duration, sample_on_weight, sample_on_duration, a_hist, b_hist, c_hist): random.seed(1234) rng = torch.Generator() rng.manual_seed(1234) def _get_histogram(dataset, repetitions=20_000): counts = {file_meta.path: 0. 
for file_meta in meta} for _ in range(repetitions): file_meta = dataset.sample_file(0, rng) counts[file_meta.path] += 1 return {name: count / repetitions for name, count in counts.items()} meta = [ AudioMeta(path='a', duration=5, sample_rate=1, weight=2), AudioMeta(path='b', duration=10, sample_rate=1, weight=None), AudioMeta(path='c', duration=5, sample_rate=1, weight=0), ] dataset = AudioDataset( meta, segment_duration=segment_duration, sample_on_weight=sample_on_weight, sample_on_duration=sample_on_duration) hist = _get_histogram(dataset) assert math.isclose(hist['a'], a_hist, abs_tol=0.01) assert math.isclose(hist['b'], b_hist, abs_tol=0.01) assert math.isclose(hist['c'], c_hist, abs_tol=0.01) def test_meta_duration_filter_all(self): meta = [ AudioMeta(path='a', duration=5, sample_rate=1, weight=2), AudioMeta(path='b', duration=10, sample_rate=1, weight=None), AudioMeta(path='c', duration=5, sample_rate=1, weight=0), ] try: AudioDataset(meta, segment_duration=11, min_segment_ratio=1) assert False except AssertionError: assert True def test_meta_duration_filter_long(self): meta = [ AudioMeta(path='a', duration=5, sample_rate=1, weight=2), AudioMeta(path='b', duration=10, sample_rate=1, weight=None), AudioMeta(path='c', duration=5, sample_rate=1, weight=0), ] dataset = AudioDataset(meta, segment_duration=None, min_segment_ratio=1, max_audio_duration=7) assert len(dataset) == 2 ================================================ FILE: tests/data/test_audio_utils.py ================================================ # Copyright (c) Meta Platforms, Inc. and affiliates. # All rights reserved. # # This source code is licensed under the license found in the # LICENSE file in the root directory of this source tree. 
import julius import torch import pytest from audiocraft.data.audio_utils import ( _clip_wav, convert_audio_channels, convert_audio, normalize_audio ) from ..common_utils import get_batch_white_noise class TestConvertAudioChannels: def test_convert_audio_channels_downmix(self): b, c, t = 2, 3, 100 audio = get_batch_white_noise(b, c, t) mixed = convert_audio_channels(audio, channels=2) assert list(mixed.shape) == [b, 2, t] def test_convert_audio_channels_nochange(self): b, c, t = 2, 3, 100 audio = get_batch_white_noise(b, c, t) mixed = convert_audio_channels(audio, channels=c) assert list(mixed.shape) == list(audio.shape) def test_convert_audio_channels_upmix(self): b, c, t = 2, 1, 100 audio = get_batch_white_noise(b, c, t) mixed = convert_audio_channels(audio, channels=3) assert list(mixed.shape) == [b, 3, t] def test_convert_audio_channels_upmix_error(self): b, c, t = 2, 2, 100 audio = get_batch_white_noise(b, c, t) with pytest.raises(ValueError): convert_audio_channels(audio, channels=3) class TestConvertAudio: def test_convert_audio_channels_downmix(self): b, c, dur = 2, 3, 4. sr = 128 audio = get_batch_white_noise(b, c, int(sr * dur)) out = convert_audio(audio, from_rate=sr, to_rate=sr, to_channels=2) assert list(out.shape) == [audio.shape[0], 2, audio.shape[-1]] def test_convert_audio_channels_upmix(self): b, c, dur = 2, 1, 4. sr = 128 audio = get_batch_white_noise(b, c, int(sr * dur)) out = convert_audio(audio, from_rate=sr, to_rate=sr, to_channels=3) assert list(out.shape) == [audio.shape[0], 3, audio.shape[-1]] def test_convert_audio_upsample(self): b, c, dur = 2, 1, 4. sr = 2 new_sr = 3 audio = get_batch_white_noise(b, c, int(sr * dur)) out = convert_audio(audio, from_rate=sr, to_rate=new_sr, to_channels=c) out_j = julius.resample.resample_frac(audio, old_sr=sr, new_sr=new_sr) assert torch.allclose(out, out_j) def test_convert_audio_resample(self): b, c, dur = 2, 1, 4. 
sr = 3 new_sr = 2 audio = get_batch_white_noise(b, c, int(sr * dur)) out = convert_audio(audio, from_rate=sr, to_rate=new_sr, to_channels=c) out_j = julius.resample.resample_frac(audio, old_sr=sr, new_sr=new_sr) assert torch.allclose(out, out_j) class TestNormalizeAudio: def test_clip_wav(self): b, c, dur = 2, 1, 4. sr = 3 audio = 10.0 * get_batch_white_noise(b, c, int(sr * dur)) _clip_wav(audio) assert audio.abs().max() <= 1 def test_normalize_audio_clip(self): b, c, dur = 2, 1, 4. sr = 3 audio = 10.0 * get_batch_white_noise(b, c, int(sr * dur)) norm_audio = normalize_audio(audio, strategy='clip') assert norm_audio.abs().max() <= 1 def test_normalize_audio_rms(self): b, c, dur = 2, 1, 4. sr = 3 audio = 10.0 * get_batch_white_noise(b, c, int(sr * dur)) norm_audio = normalize_audio(audio, strategy='rms') assert norm_audio.abs().max() <= 1 def test_normalize_audio_peak(self): b, c, dur = 2, 1, 4. sr = 3 audio = 10.0 * get_batch_white_noise(b, c, int(sr * dur)) norm_audio = normalize_audio(audio, strategy='peak') assert norm_audio.abs().max() <= 1 ================================================ FILE: tests/losses/__init__.py ================================================ # Copyright (c) Meta Platforms, Inc. and affiliates. # All rights reserved. # # This source code is licensed under the license found in the # LICENSE file in the root directory of this source tree. ================================================ FILE: tests/losses/test_losses.py ================================================ # Copyright (c) Meta Platforms, Inc. and affiliates. # All rights reserved. # # This source code is licensed under the license found in the # LICENSE file in the root directory of this source tree. 
import random import torch from audiocraft.losses import ( MelSpectrogramL1Loss, MultiScaleMelSpectrogramLoss, MRSTFTLoss, SISNR, STFTLoss, ) def test_mel_l1_loss(): N, C, T = 2, 2, random.randrange(1000, 100_000) t1 = torch.randn(N, C, T) t2 = torch.randn(N, C, T) mel_l1 = MelSpectrogramL1Loss(sample_rate=22_050) loss = mel_l1(t1, t2) loss_same = mel_l1(t1, t1) assert isinstance(loss, torch.Tensor) assert isinstance(loss_same, torch.Tensor) assert loss_same.item() == 0.0 def test_msspec_loss(): N, C, T = 2, 2, random.randrange(1000, 100_000) t1 = torch.randn(N, C, T) t2 = torch.randn(N, C, T) msspec = MultiScaleMelSpectrogramLoss(sample_rate=22_050) loss = msspec(t1, t2) loss_same = msspec(t1, t1) assert isinstance(loss, torch.Tensor) assert isinstance(loss_same, torch.Tensor) assert loss_same.item() == 0.0 def test_mrstft_loss(): N, C, T = 2, 2, random.randrange(1000, 100_000) t1 = torch.randn(N, C, T) t2 = torch.randn(N, C, T) mrstft = MRSTFTLoss() loss = mrstft(t1, t2) assert isinstance(loss, torch.Tensor) def test_sisnr_loss(): N, C, T = 2, 2, random.randrange(1000, 100_000) t1 = torch.randn(N, C, T) t2 = torch.randn(N, C, T) sisnr = SISNR() loss = sisnr(t1, t2) assert isinstance(loss, torch.Tensor) def test_stft_loss(): N, C, T = 2, 2, random.randrange(1000, 100_000) t1 = torch.randn(N, C, T) t2 = torch.randn(N, C, T) mrstft = STFTLoss() loss = mrstft(t1, t2) assert isinstance(loss, torch.Tensor) ================================================ FILE: tests/models/test_audiogen.py ================================================ # Copyright (c) Meta Platforms, Inc. and affiliates. # All rights reserved. # # This source code is licensed under the license found in the # LICENSE file in the root directory of this source tree. import pytest import torch from audiocraft.models import AudioGen class TestAudioGenModel: def get_audiogen(self): ag = AudioGen.get_pretrained(name='debug', device='cpu') ag.set_generation_params(duration=2.0, extend_stride=2.) 
return ag def test_base(self): ag = self.get_audiogen() assert ag.frame_rate == 25 assert ag.sample_rate == 16000 assert ag.audio_channels == 1 def test_generate_continuation(self): ag = self.get_audiogen() prompt = torch.randn(3, 1, 16000) wav = ag.generate_continuation(prompt, 16000) assert list(wav.shape) == [3, 1, 32000] prompt = torch.randn(2, 1, 16000) wav = ag.generate_continuation( prompt, 16000, ['youpi', 'lapin dort']) assert list(wav.shape) == [2, 1, 32000] prompt = torch.randn(2, 1, 16000) with pytest.raises(AssertionError): wav = ag.generate_continuation( prompt, 16000, ['youpi', 'lapin dort', 'one too many']) def test_generate(self): ag = self.get_audiogen() wav = ag.generate( ['youpi', 'lapin dort']) assert list(wav.shape) == [2, 1, 32000] def test_generate_long(self): ag = self.get_audiogen() ag.max_duration = 3. ag.set_generation_params(duration=4., extend_stride=2.) wav = ag.generate( ['youpi', 'lapin dort']) assert list(wav.shape) == [2, 1, 16000 * 4] ================================================ FILE: tests/models/test_encodec_model.py ================================================ # Copyright (c) Meta Platforms, Inc. and affiliates. # All rights reserved. # # This source code is licensed under the license found in the # LICENSE file in the root directory of this source tree. 
import random import numpy as np import torch from audiocraft.models import EncodecModel from audiocraft.modules import SEANetEncoder, SEANetDecoder from audiocraft.quantization import DummyQuantizer class TestEncodecModel: def _create_encodec_model(self, sample_rate: int, channels: int, dim: int = 5, n_filters: int = 3, n_residual_layers: int = 1, ratios: list = [5, 4, 3, 2], **kwargs): frame_rate = np.prod(ratios) encoder = SEANetEncoder(channels=channels, dimension=dim, n_filters=n_filters, n_residual_layers=n_residual_layers, ratios=ratios) decoder = SEANetDecoder(channels=channels, dimension=dim, n_filters=n_filters, n_residual_layers=n_residual_layers, ratios=ratios) quantizer = DummyQuantizer() model = EncodecModel(encoder, decoder, quantizer, frame_rate=frame_rate, sample_rate=sample_rate, channels=channels, **kwargs) return model def test_model(self): random.seed(1234) sample_rate = 24_000 channels = 1 model = self._create_encodec_model(sample_rate, channels) for _ in range(10): length = random.randrange(1, 10_000) x = torch.randn(2, channels, length) res = model(x) assert res.x.shape == x.shape def test_model_renorm(self): random.seed(1234) sample_rate = 24_000 channels = 1 model_nonorm = self._create_encodec_model(sample_rate, channels, renormalize=False) model_renorm = self._create_encodec_model(sample_rate, channels, renormalize=True) for _ in range(10): length = random.randrange(1, 10_000) x = torch.randn(2, channels, length) codes, scales = model_nonorm.encode(x) codes, scales = model_renorm.encode(x) assert scales is not None ================================================ FILE: tests/models/test_multibanddiffusion.py ================================================ # Copyright (c) Meta Platforms, Inc. and affiliates. # All rights reserved. # # This source code is licensed under the license found in the # LICENSE file in the root directory of this source tree. 
import random import numpy as np import torch from audiocraft.models.multibanddiffusion import MultiBandDiffusion, DiffusionProcess from audiocraft.models import EncodecModel, DiffusionUnet from audiocraft.modules import SEANetEncoder, SEANetDecoder from audiocraft.modules.diffusion_schedule import NoiseSchedule from audiocraft.quantization import DummyQuantizer class TestMBD: def _create_mbd(self, sample_rate: int, channels: int, n_filters: int = 3, n_residual_layers: int = 1, ratios: list = [5, 4, 3, 2], num_steps: int = 1000, codec_dim: int = 128, **kwargs): frame_rate = np.prod(ratios) encoder = SEANetEncoder(channels=channels, dimension=codec_dim, n_filters=n_filters, n_residual_layers=n_residual_layers, ratios=ratios) decoder = SEANetDecoder(channels=channels, dimension=codec_dim, n_filters=n_filters, n_residual_layers=n_residual_layers, ratios=ratios) quantizer = DummyQuantizer() compression_model = EncodecModel(encoder, decoder, quantizer, frame_rate=frame_rate, sample_rate=sample_rate, channels=channels, **kwargs) diffusion_model = DiffusionUnet(chin=channels, num_steps=num_steps, codec_dim=codec_dim) schedule = NoiseSchedule(device='cpu', num_steps=num_steps) DP = DiffusionProcess(model=diffusion_model, noise_schedule=schedule) mbd = MultiBandDiffusion(DPs=[DP], codec_model=compression_model) return mbd def test_model(self): random.seed(1234) sample_rate = 24_000 channels = 1 codec_dim = 128 mbd = self._create_mbd(sample_rate=sample_rate, channels=channels, codec_dim=codec_dim) for _ in range(10): length = random.randrange(1, 10_000) x = torch.randn(2, channels, length) res = mbd.regenerate(x, sample_rate) assert res.shape == x.shape ================================================ FILE: tests/models/test_musicgen.py ================================================ # Copyright (c) Meta Platforms, Inc. and affiliates. # All rights reserved. 
# # This source code is licensed under the license found in the # LICENSE file in the root directory of this source tree. import pytest import torch from audiocraft.models import MusicGen class TestMusicGenModel: def get_musicgen(self): mg = MusicGen.get_pretrained(name='debug', device='cpu') mg.set_generation_params(duration=2.0, extend_stride=2.) return mg def test_base(self): mg = self.get_musicgen() assert mg.frame_rate == 25 assert mg.sample_rate == 32000 assert mg.audio_channels == 1 def test_generate_unconditional(self): mg = self.get_musicgen() wav = mg.generate_unconditional(3) assert list(wav.shape) == [3, 1, 64000] def test_generate_continuation(self): mg = self.get_musicgen() prompt = torch.randn(3, 1, 32000) wav = mg.generate_continuation(prompt, 32000) assert list(wav.shape) == [3, 1, 64000] prompt = torch.randn(2, 1, 32000) wav = mg.generate_continuation( prompt, 32000, ['youpi', 'lapin dort']) assert list(wav.shape) == [2, 1, 64000] prompt = torch.randn(2, 1, 32000) with pytest.raises(AssertionError): wav = mg.generate_continuation( prompt, 32000, ['youpi', 'lapin dort', 'one too many']) def test_generate(self): mg = self.get_musicgen() wav = mg.generate( ['youpi', 'lapin dort']) assert list(wav.shape) == [2, 1, 64000] def test_generate_long(self): mg = self.get_musicgen() mg.max_duration = 3. mg.set_generation_params(duration=4., extend_stride=2.) wav = mg.generate( ['youpi', 'lapin dort']) assert list(wav.shape) == [2, 1, 32000 * 4] ================================================ FILE: tests/modules/__init__.py ================================================ # Copyright (c) Meta Platforms, Inc. and affiliates. # All rights reserved. # # This source code is licensed under the license found in the # LICENSE file in the root directory of this source tree. ================================================ FILE: tests/modules/test_activations.py ================================================ # Copyright (c) Meta Platforms, Inc. and affiliates. 
# All rights reserved. # # This source code is licensed under the license found in the # LICENSE file in the root directory of this source tree. import torch from torch import nn from audiocraft.modules.activations import CustomGLU class TestActivations: def test_custom_glu_calculation(self): activation = CustomGLU(nn.Identity()) initial_shape = (4, 8, 8) part_a = torch.ones(initial_shape) * 2 part_b = torch.ones(initial_shape) * -1 input = torch.cat((part_a, part_b), dim=-1) output = activation(input) # ensure all dimensions match initial shape assert output.shape == initial_shape # ensure the gating was calculated correctly a * f(b) assert torch.all(output == -2).item() ================================================ FILE: tests/modules/test_codebooks_patterns.py ================================================ # Copyright (c) Meta Platforms, Inc. and affiliates. # All rights reserved. # # This source code is licensed under the license found in the # LICENSE file in the root directory of this source tree. 
import pytest import torch from audiocraft.modules.codebooks_patterns import ( DelayedPatternProvider, ParallelPatternProvider, Pattern, UnrolledPatternProvider, ) class TestParallelPatternProvider: @pytest.mark.parametrize("n_q", [1, 4, 32]) @pytest.mark.parametrize("timesteps", [0, 1, 16, 100]) def test_get_pattern(self, n_q: int, timesteps: int): provider = ParallelPatternProvider(n_q) pattern = provider.get_pattern(timesteps) # + 1 to account for 1st step assert len(pattern.layout) == timesteps + 1 @pytest.mark.parametrize("n_q", [1, 4, 32]) @pytest.mark.parametrize("timesteps", [8, 16, 100]) def test_pattern_content(self, n_q: int, timesteps: int): provider = ParallelPatternProvider(n_q) pattern = provider.get_pattern(timesteps) for s, v in enumerate(pattern.layout): for i, code in enumerate(v): assert i == code.q assert code.t == s - 1 # account for the 1st empty step @pytest.mark.parametrize("n_q", [1, 4, 32]) @pytest.mark.parametrize("timesteps", [8, 16, 100]) def test_pattern_max_delay(self, n_q: int, timesteps: int): provider = ParallelPatternProvider(n_q) pattern = provider.get_pattern(timesteps) assert pattern.max_delay == 0 assert len(pattern.valid_layout) == len(pattern.layout) - pattern.max_delay class TestDelayedPatternProvider: @pytest.mark.parametrize("n_q", [1, 4, 32]) @pytest.mark.parametrize("timesteps", [0, 1, 16, 100]) def test_get_pattern(self, n_q: int, timesteps: int): delays = [ list(range(n_q)), [0] + [1] * (n_q - 1), [0] + [4] * (n_q - 1), ] for delay in delays: provider = DelayedPatternProvider(n_q, delay) pattern = provider.get_pattern(timesteps) # + 1 to account for 1st step assert len(pattern.layout) == timesteps + max(delay) + 1 @pytest.mark.parametrize("n_q", [1, 4, 32]) @pytest.mark.parametrize("timesteps", [8, 16, 100]) def test_pattern_content(self, n_q: int, timesteps: int): provider = DelayedPatternProvider(n_q) pattern = provider.get_pattern(timesteps) for s, v in enumerate(pattern.layout): for i, code in enumerate(v): 
assert i == code.q assert code.t == max(0, s - code.q - 1) @pytest.mark.parametrize("timesteps", [8, 16, 100]) @pytest.mark.parametrize("delay", [[0, 1, 2, 3], [0, 1, 1, 1], [0, 3, 3, 3], [0, 3]]) def test_pattern_max_delay(self, timesteps: int, delay: list): provider = DelayedPatternProvider(len(delay), delay) pattern = provider.get_pattern(timesteps) assert pattern.max_delay == max(delay) assert len(pattern.valid_layout) == len(pattern.layout) - pattern.max_delay class TestUnrolledPatternProvider: @pytest.mark.parametrize("timesteps", [0, 1, 16]) @pytest.mark.parametrize("flattening", [[0, 1, 2], [0, 1, 1]]) @pytest.mark.parametrize("delays", [[0, 0, 0], [0, 5, 5]]) def test_get_pattern(self, timesteps: int, flattening: list, delays: list): n_q = len(flattening) max_delay = max(delays) provider = UnrolledPatternProvider(n_q, flattening, delays) pattern = provider.get_pattern(timesteps) assert len(pattern.layout) == provider.num_virtual_steps(timesteps) + max_delay @pytest.mark.parametrize("timesteps", [0, 1, 16]) @pytest.mark.parametrize("flattening", [[0, 1, 2], [0, 1, 1]]) @pytest.mark.parametrize("delays", [[0, 0, 0], [0, 5, 5]]) def test_pattern_max_delay(self, timesteps: int, flattening: list, delays: list): n_q = len(flattening) max_delay = max(delays) provider = UnrolledPatternProvider(n_q, flattening, delays) pattern = provider.get_pattern(timesteps) assert pattern.max_delay == max_delay class TestPattern: def ref_build_pattern_sequence(self, z: torch.Tensor, pattern: Pattern, special_token: int): """Reference method to build the sequence from the pattern without using fancy scatter.""" bs, n_q, T = z.shape z = z.cpu().numpy() assert n_q == pattern.n_q assert T <= pattern.timesteps inp = torch.full((bs, n_q, len(pattern.layout)), special_token, dtype=torch.long).numpy() inp[:] = special_token for s, v in enumerate(pattern.layout): for (t, q) in v: if t < T: inp[:, q, s] = z[:, q, t] return torch.from_numpy(inp) def ref_revert_pattern_sequence(self, z: 
torch.Tensor, pattern: Pattern, special_token: int): """Reference method to revert the sequence from the pattern without using fancy scatter.""" z = z.cpu().numpy() bs, n_q, S = z.shape assert pattern.n_q == n_q inp = torch.full((bs, pattern.n_q, pattern.timesteps), special_token, dtype=torch.long).numpy() inp[:] = special_token for s, v in enumerate(pattern.layout): for (t, q) in v: if t < pattern.timesteps: inp[:, q, t] = z[:, q, s] return torch.from_numpy(inp) def ref_revert_pattern_logits(self, z: torch.Tensor, pattern: Pattern, special_token: float): """Reference method to revert the logits from the pattern without using fancy scatter.""" z = z.cpu().numpy() bs, card, n_q, S = z.shape assert pattern.n_q == n_q ref_layout = pattern.layout inp = torch.full((bs, card, pattern.n_q, pattern.timesteps), special_token, dtype=torch.float).numpy() inp[:] = special_token for s, v in enumerate(ref_layout[1:]): if s < S: for (t, q) in v: if t < pattern.timesteps: inp[:, :, q, t] = z[:, :, q, s] return torch.from_numpy(inp) def _get_pattern_providers(self, n_q: int): pattern_provider_1 = ParallelPatternProvider(n_q) pattern_provider_2 = DelayedPatternProvider(n_q, list(range(n_q))) pattern_provider_3 = DelayedPatternProvider(n_q, [0] + [1] * (n_q - 1)) pattern_provider_4 = UnrolledPatternProvider( n_q, flattening=list(range(n_q)), delays=[0] * n_q ) pattern_provider_5 = UnrolledPatternProvider( n_q, flattening=[0] + [1] * (n_q - 1), delays=[0] * n_q ) pattern_provider_6 = UnrolledPatternProvider( n_q, flattening=[0] + [1] * (n_q - 1), delays=[0] + [5] * (n_q - 1) ) return [ pattern_provider_1, pattern_provider_2, pattern_provider_3, pattern_provider_4, pattern_provider_5, pattern_provider_6, ] @pytest.mark.parametrize("n_q", [1, 4, 32]) @pytest.mark.parametrize("timesteps", [16, 72]) def test_build_pattern_sequence(self, n_q: int, timesteps: int): bs = 2 card = 256 special_token = card pattern_providers = self._get_pattern_providers(n_q) for pattern_provider in 
pattern_providers: pattern = pattern_provider.get_pattern(timesteps) # we can correctly build the sequence from the pattern z = torch.randint(0, card, (bs, n_q, timesteps)) ref_res = self.ref_build_pattern_sequence(z, pattern, special_token) res, indexes, mask = pattern.build_pattern_sequence(z, special_token) assert (res == ref_res).float().mean() == 1.0 # expected assertion fails on the number of timesteps invalid_timesteps = [timesteps + 1] if pattern.num_sequence_steps != pattern.timesteps: invalid_timesteps.append(pattern.num_sequence_steps) for i_timesteps in invalid_timesteps: z2 = torch.randint(0, card, (bs, n_q, i_timesteps)) with pytest.raises(AssertionError): pattern.build_pattern_sequence(z2, special_token) # expected assertion fails on the number of codebooks invalid_qs = [0, n_q - 1, n_q + 1] for i_q in invalid_qs: z3 = torch.randint(0, card, (bs, i_q, timesteps)) with pytest.raises(AssertionError): pattern.build_pattern_sequence(z3, special_token) @pytest.mark.parametrize("n_q", [1, 4, 32]) @pytest.mark.parametrize("timesteps", [16, 72]) def test_revert_pattern_sequence(self, n_q: int, timesteps: int): bs = 2 card = 256 special_token = card pattern_providers = self._get_pattern_providers(n_q) for pattern_provider in pattern_providers: pattern = pattern_provider.get_pattern(timesteps) # this works assuming previous tests are successful z = torch.randint(0, card, (bs, n_q, timesteps)) s = self.ref_build_pattern_sequence(z, pattern, special_token) ref_out = self.ref_revert_pattern_sequence(s, pattern, special_token) # ensure our reference script retrieve the original sequence assert z.shape == ref_out.shape assert (z == ref_out).float().mean() == 1.0 # now we can test the scatter version out, indexes, mask = pattern.revert_pattern_sequence(s, special_token) assert out.shape == ref_out.shape assert (out == ref_out).float().mean() == 1.0 @pytest.mark.parametrize("n_q", [1, 4, 32]) @pytest.mark.parametrize("timesteps", [16, 72]) 
@pytest.mark.parametrize("card", [1, 2, 256, 1024]) def test_revert_pattern_logits(self, n_q: int, timesteps: int, card: int): bs = 2 special_token = card logits_special_token = float('nan') pattern_providers = self._get_pattern_providers(n_q) for pattern_provider in pattern_providers: pattern = pattern_provider.get_pattern(timesteps) # this works assuming previous tests are successful z = torch.randint(0, card, (bs, n_q, timesteps)) s = self.ref_build_pattern_sequence(z, pattern, special_token) logits = torch.randn((bs, card, n_q, s.shape[-1])) ref_out = self.ref_revert_pattern_logits(logits, pattern, logits_special_token) # ensure our reference script retrieve the original sequence assert ref_out.shape == torch.Size([bs, card, n_q, timesteps]) # now we can test the scatter version out, indexes, mask = pattern.revert_pattern_logits(logits, logits_special_token) assert out.shape == ref_out.shape assert (out == ref_out).float().mean() == 1.0 ================================================ FILE: tests/modules/test_conv.py ================================================ # Copyright (c) Meta Platforms, Inc. and affiliates. # All rights reserved. # # This source code is licensed under the license found in the # LICENSE file in the root directory of this source tree. from itertools import product import math import random import pytest import torch from torch import nn from audiocraft.modules import ( NormConv1d, NormConvTranspose1d, StreamableConv1d, StreamableConvTranspose1d, pad1d, unpad1d, ) def test_get_extra_padding_for_conv1d(): # TODO: Implement me! pass def test_pad1d_zeros(): x = torch.randn(1, 1, 20) xp1 = pad1d(x, (0, 5), mode='constant', value=0.) assert xp1.shape[-1] == 25 xp2 = pad1d(x, (5, 5), mode='constant', value=0.) assert xp2.shape[-1] == 30 xp3 = pad1d(x, (0, 0), mode='constant', value=0.) assert xp3.shape[-1] == 20 xp4 = pad1d(x, (10, 30), mode='constant', value=0.) 
assert xp4.shape[-1] == 60 with pytest.raises(AssertionError): pad1d(x, (-1, 0), mode='constant', value=0.) with pytest.raises(AssertionError): pad1d(x, (0, -1), mode='constant', value=0.) with pytest.raises(AssertionError): pad1d(x, (-1, -1), mode='constant', value=0.) def test_pad1d_reflect(): x = torch.randn(1, 1, 20) xp1 = pad1d(x, (0, 5), mode='reflect', value=0.) assert xp1.shape[-1] == 25 xp2 = pad1d(x, (5, 5), mode='reflect', value=0.) assert xp2.shape[-1] == 30 xp3 = pad1d(x, (0, 0), mode='reflect', value=0.) assert xp3.shape[-1] == 20 xp4 = pad1d(x, (10, 30), mode='reflect', value=0.) assert xp4.shape[-1] == 60 with pytest.raises(AssertionError): pad1d(x, (-1, 0), mode='reflect', value=0.) with pytest.raises(AssertionError): pad1d(x, (0, -1), mode='reflect', value=0.) with pytest.raises(AssertionError): pad1d(x, (-1, -1), mode='reflect', value=0.) def test_unpad1d(): x = torch.randn(1, 1, 20) u1 = unpad1d(x, (5, 5)) assert u1.shape[-1] == 10 u2 = unpad1d(x, (0, 5)) assert u2.shape[-1] == 15 u3 = unpad1d(x, (5, 0)) assert u3.shape[-1] == 15 u4 = unpad1d(x, (0, 0)) assert u4.shape[-1] == x.shape[-1] with pytest.raises(AssertionError): unpad1d(x, (-1, 0)) with pytest.raises(AssertionError): unpad1d(x, (0, -1)) with pytest.raises(AssertionError): unpad1d(x, (-1, -1)) class TestNormConv1d: def test_norm_conv1d_modules(self): N, C, T = 2, 2, random.randrange(1, 100_000) t0 = torch.randn(N, C, T) C_out, kernel_size, stride = 1, 4, 1 expected_out_length = int((T - kernel_size) / stride + 1) wn_conv = NormConv1d(C, 1, kernel_size=4, norm='weight_norm') gn_conv = NormConv1d(C, 1, kernel_size=4, norm='time_group_norm') nn_conv = NormConv1d(C, 1, kernel_size=4, norm='none') assert isinstance(wn_conv.norm, nn.Identity) assert isinstance(wn_conv.conv, nn.Conv1d) assert isinstance(gn_conv.norm, nn.GroupNorm) assert isinstance(gn_conv.conv, nn.Conv1d) assert isinstance(nn_conv.norm, nn.Identity) assert isinstance(nn_conv.conv, nn.Conv1d) for conv_layer in [wn_conv, 
gn_conv, nn_conv]: out = conv_layer(t0) assert isinstance(out, torch.Tensor) assert list(out.shape) == [N, C_out, expected_out_length] class TestNormConvTranspose1d: def test_normalizations(self): N, C, T = 2, 2, random.randrange(1, 100_000) t0 = torch.randn(N, C, T) C_out, kernel_size, stride = 1, 4, 1 expected_out_length = (T - 1) * stride + (kernel_size - 1) + 1 wn_convtr = NormConvTranspose1d(C, C_out, kernel_size=kernel_size, stride=stride, norm='weight_norm') gn_convtr = NormConvTranspose1d(C, C_out, kernel_size=kernel_size, stride=stride, norm='time_group_norm') nn_convtr = NormConvTranspose1d(C, C_out, kernel_size=kernel_size, stride=stride, norm='none') assert isinstance(wn_convtr.norm, nn.Identity) assert isinstance(wn_convtr.convtr, nn.ConvTranspose1d) assert isinstance(gn_convtr.norm, nn.GroupNorm) assert isinstance(gn_convtr.convtr, nn.ConvTranspose1d) assert isinstance(nn_convtr.norm, nn.Identity) assert isinstance(nn_convtr.convtr, nn.ConvTranspose1d) for convtr_layer in [wn_convtr, gn_convtr, nn_convtr]: out = convtr_layer(t0) assert isinstance(out, torch.Tensor) assert list(out.shape) == [N, C_out, expected_out_length] class TestStreamableConv1d: def get_streamable_conv1d_output_length(self, length, kernel_size, stride, dilation): # StreamableConv1d internally pads to make sure that the last window is full padding_total = (kernel_size - 1) * dilation - (stride - 1) n_frames = (length - kernel_size + padding_total) / stride + 1 ideal_length = (math.ceil(n_frames) - 1) * stride + (kernel_size - padding_total) return ideal_length // stride def test_streamable_conv1d(self): N, C, T = 2, 2, random.randrange(1, 100_000) t0 = torch.randn(N, C, T) C_out = 1 # conv params are [(kernel_size, stride, dilation)] conv_params = [(4, 1, 1), (4, 2, 1), (3, 1, 3), (10, 5, 1), (3, 2, 3)] for causal, (kernel_size, stride, dilation) in product([False, True], conv_params): expected_out_length = self.get_streamable_conv1d_output_length(T, kernel_size, stride, dilation) 
sconv = StreamableConv1d(C, C_out, kernel_size=kernel_size, stride=stride, dilation=dilation, causal=causal) out = sconv(t0) assert isinstance(out, torch.Tensor) print(list(out.shape), [N, C_out, expected_out_length]) assert list(out.shape) == [N, C_out, expected_out_length] class TestStreamableConvTranspose1d: def get_streamable_convtr1d_output_length(self, length, kernel_size, stride): padding_total = (kernel_size - stride) return (length - 1) * stride - padding_total + (kernel_size - 1) + 1 def test_streamable_convtr1d(self): N, C, T = 2, 2, random.randrange(1, 100_000) t0 = torch.randn(N, C, T) C_out = 1 with pytest.raises(AssertionError): StreamableConvTranspose1d(C, C_out, kernel_size=4, causal=False, trim_right_ratio=0.5) StreamableConvTranspose1d(C, C_out, kernel_size=4, causal=True, trim_right_ratio=-1.) StreamableConvTranspose1d(C, C_out, kernel_size=4, causal=True, trim_right_ratio=2) # causal params are [(causal, trim_right)] causal_params = [(False, 1.0), (True, 1.0), (True, 0.5), (True, 0.0)] # conv params are [(kernel_size, stride)] conv_params = [(4, 1), (4, 2), (3, 1), (10, 5)] for ((causal, trim_right_ratio), (kernel_size, stride)) in product(causal_params, conv_params): expected_out_length = self.get_streamable_convtr1d_output_length(T, kernel_size, stride) sconvtr = StreamableConvTranspose1d(C, C_out, kernel_size=kernel_size, stride=stride, causal=causal, trim_right_ratio=trim_right_ratio) out = sconvtr(t0) assert isinstance(out, torch.Tensor) assert list(out.shape) == [N, C_out, expected_out_length] ================================================ FILE: tests/modules/test_lstm.py ================================================ # Copyright (c) Meta Platforms, Inc. and affiliates. # All rights reserved. # # This source code is licensed under the license found in the # LICENSE file in the root directory of this source tree. 
import random import torch from audiocraft.modules.lstm import StreamableLSTM class TestStreamableLSTM: def test_lstm(self): B, C, T = 4, 2, random.randint(1, 100) lstm = StreamableLSTM(C, 3, skip=False) x = torch.randn(B, C, T) y = lstm(x) print(y.shape) assert y.shape == torch.Size([B, C, T]) def test_lstm_skip(self): B, C, T = 4, 2, random.randint(1, 100) lstm = StreamableLSTM(C, 3, skip=True) x = torch.randn(B, C, T) y = lstm(x) assert y.shape == torch.Size([B, C, T]) ================================================ FILE: tests/modules/test_rope.py ================================================ # Copyright (c) Meta Platforms, Inc. and affiliates. # All rights reserved. # # This source code is licensed under the license found in the # LICENSE file in the root directory of this source tree. import torch from audiocraft.modules.rope import RotaryEmbedding from audiocraft.modules.transformer import StreamingTransformer, set_efficient_attention_backend def test_rope(): set_efficient_attention_backend('xformers') B, T, H, C = 8, 75, 16, 128 rope = RotaryEmbedding(dim=C) xq = torch.rand((B, T, H, C)) xk = torch.rand((B, T, H, C)) xq_out, xk_out = rope.rotate_qk(xq, xk, start=7) assert list(xq_out.shape) == [B, T, H, C] assert list(xk_out.shape) == [B, T, H, C] def test_rope_io_dtypes(): set_efficient_attention_backend('xformers') B, T, H, C = 8, 75, 16, 128 rope_32 = RotaryEmbedding(dim=C, dtype=torch.float32) rope_64 = RotaryEmbedding(dim=C, dtype=torch.float64) # Test bfloat16 inputs w/ both 32 and 64 precision rope. xq_16 = torch.rand((B, T, H, C)).to(torch.bfloat16) xk_16 = torch.rand((B, T, H, C)).to(torch.bfloat16) xq_out, xk_out = rope_32.rotate_qk(xq_16, xk_16) assert xq_out.dtype == torch.bfloat16 xq_out, xk_out = rope_64.rotate_qk(xq_16, xk_16) assert xq_out.dtype == torch.bfloat16 # Test float32 inputs w/ both 32 and 64 precision rope. 
xq_32 = torch.rand((B, T, H, C)).to(torch.float32) xk_32 = torch.rand((B, T, H, C)).to(torch.float32) xq_out, xk_out = rope_32.rotate_qk(xq_32, xk_32) assert xq_out.dtype == torch.float32 xq_out, xk_out = rope_64.rotate_qk(xq_32, xk_32) assert xq_out.dtype == torch.float32 def test_transformer_with_rope(): set_efficient_attention_backend('xformers') torch.manual_seed(1234) for pos in ['rope', 'sin_rope']: tr = StreamingTransformer( 16, 4, 2, custom=True, dropout=0., layer_scale=0.1, positional_embedding=pos) tr.eval() steps = 12 x = torch.randn(3, steps, 16) out = tr(x) assert list(out.shape) == list(x.shape) @torch.no_grad() def test_rope_streaming(): set_efficient_attention_backend('xformers') torch.manual_seed(1234) tr = StreamingTransformer( 16, 4, 2, causal=True, dropout=0., custom=True, positional_embedding='rope') tr.eval() steps = 12 x = torch.randn(3, steps, 16) ref = tr(x) with tr.streaming(): outs = [] frame_sizes = [1] * steps for frame_size in frame_sizes: frame = x[:, :frame_size] x = x[:, frame_size:] outs.append(tr(frame)) out = torch.cat(outs, dim=1) assert list(out.shape) == [3, steps, 16] delta = torch.norm(out - ref) / torch.norm(out) assert delta < 1e-6, delta @torch.no_grad() def test_rope_streaming_past_context(): set_efficient_attention_backend('xformers') torch.manual_seed(1234) for context in [None, 10]: tr = StreamingTransformer( 16, 4, 1 if context else 2, causal=True, past_context=context, custom=True, dropout=0., positional_embedding='rope') tr.eval() steps = 20 x = torch.randn(3, steps, 16) ref = tr(x) with tr.streaming(): outs = [] frame_sizes = [1] * steps for frame_size in frame_sizes: frame = x[:, :frame_size] x = x[:, frame_size:] outs.append(tr(frame)) out = torch.cat(outs, dim=1) assert list(out.shape) == [3, steps, 16] delta = torch.norm(out - ref) / torch.norm(out) assert delta < 1e-6, delta def test_rope_memory_efficient(): set_efficient_attention_backend('xformers') torch.manual_seed(1234) tr = StreamingTransformer( 16, 4, 
2, custom=True, dropout=0., layer_scale=0.1, positional_embedding='rope') tr_mem_efficient = StreamingTransformer( 16, 4, 2, dropout=0., memory_efficient=True, layer_scale=0.1, positional_embedding='rope') tr_mem_efficient.load_state_dict(tr.state_dict()) tr.eval() steps = 12 x = torch.randn(3, steps, 16) with torch.no_grad(): y = tr(x) y2 = tr_mem_efficient(x) # Check at float precision b/c this is the rope default. assert torch.allclose(y, y2, atol=1e-7), (y - y2).norm() def test_rope_with_xpos(): set_efficient_attention_backend('xformers') B, T, H, C = 8, 75, 16, 128 rope = RotaryEmbedding(dim=C, xpos=True) xq = torch.rand((B, T, H, C)) xk = torch.rand((B, T, H, C)) xq_out, xk_out = rope.rotate_qk(xq, xk, start=7) assert list(xq_out.shape) == [B, T, H, C] assert list(xk_out.shape) == [B, T, H, C] def test_positional_scale(): set_efficient_attention_backend('xformers') B, T, H, C = 8, 75, 16, 128 rope = RotaryEmbedding(dim=C, xpos=True, scale=0.0) xq = torch.rand((B, T, H, C)) xk = torch.rand((B, T, H, C)) xq_out, xk_out = rope.rotate_qk(xq, xk, start=7) assert torch.allclose(xq, xq_out) assert torch.allclose(xk, xk_out) ================================================ FILE: tests/modules/test_seanet.py ================================================ # Copyright (c) Meta Platforms, Inc. and affiliates. # All rights reserved. # # This source code is licensed under the license found in the # LICENSE file in the root directory of this source tree. 
from itertools import product import pytest import torch from audiocraft.modules.seanet import SEANetEncoder, SEANetDecoder, SEANetResnetBlock from audiocraft.modules import StreamableConv1d, StreamableConvTranspose1d class TestSEANetModel: def test_base(self): encoder = SEANetEncoder() decoder = SEANetDecoder() x = torch.randn(1, 1, 24000) z = encoder(x) assert list(z.shape) == [1, 128, 75], z.shape y = decoder(z) assert y.shape == x.shape, (x.shape, y.shape) def test_causal(self): encoder = SEANetEncoder(causal=True) decoder = SEANetDecoder(causal=True) x = torch.randn(1, 1, 24000) z = encoder(x) assert list(z.shape) == [1, 128, 75], z.shape y = decoder(z) assert y.shape == x.shape, (x.shape, y.shape) def test_conv_skip_connection(self): encoder = SEANetEncoder(true_skip=False) decoder = SEANetDecoder(true_skip=False) x = torch.randn(1, 1, 24000) z = encoder(x) assert list(z.shape) == [1, 128, 75], z.shape y = decoder(z) assert y.shape == x.shape, (x.shape, y.shape) def test_seanet_encoder_decoder_final_act(self): encoder = SEANetEncoder(true_skip=False) decoder = SEANetDecoder(true_skip=False, final_activation='Tanh') x = torch.randn(1, 1, 24000) z = encoder(x) assert list(z.shape) == [1, 128, 75], z.shape y = decoder(z) assert y.shape == x.shape, (x.shape, y.shape) def _check_encoder_blocks_norm(self, encoder: SEANetEncoder, n_disable_blocks: int, norm: str): n_blocks = 0 for layer in encoder.model: if isinstance(layer, StreamableConv1d): n_blocks += 1 assert layer.conv.norm_type == 'none' if n_blocks <= n_disable_blocks else norm elif isinstance(layer, SEANetResnetBlock): for resnet_layer in layer.block: if isinstance(resnet_layer, StreamableConv1d): # here we add + 1 to n_blocks as we increment n_blocks just after the block assert resnet_layer.conv.norm_type == 'none' if (n_blocks + 1) <= n_disable_blocks else norm def test_encoder_disable_norm(self): n_residuals = [0, 1, 3] disable_blocks = [0, 1, 2, 3, 4, 5, 6] norms = ['weight_norm', 'none'] for n_res, 
disable_blocks, norm in product(n_residuals, disable_blocks, norms): encoder = SEANetEncoder(n_residual_layers=n_res, norm=norm, disable_norm_outer_blocks=disable_blocks) self._check_encoder_blocks_norm(encoder, disable_blocks, norm) def _check_decoder_blocks_norm(self, decoder: SEANetDecoder, n_disable_blocks: int, norm: str): n_blocks = 0 for layer in decoder.model: if isinstance(layer, StreamableConv1d): n_blocks += 1 assert layer.conv.norm_type == 'none' if (decoder.n_blocks - n_blocks) < n_disable_blocks else norm elif isinstance(layer, StreamableConvTranspose1d): n_blocks += 1 assert layer.convtr.norm_type == 'none' if (decoder.n_blocks - n_blocks) < n_disable_blocks else norm elif isinstance(layer, SEANetResnetBlock): for resnet_layer in layer.block: if isinstance(resnet_layer, StreamableConv1d): assert resnet_layer.conv.norm_type == 'none' \ if (decoder.n_blocks - n_blocks) < n_disable_blocks else norm def test_decoder_disable_norm(self): n_residuals = [0, 1, 3] disable_blocks = [0, 1, 2, 3, 4, 5, 6] norms = ['weight_norm', 'none'] for n_res, disable_blocks, norm in product(n_residuals, disable_blocks, norms): decoder = SEANetDecoder(n_residual_layers=n_res, norm=norm, disable_norm_outer_blocks=disable_blocks) self._check_decoder_blocks_norm(decoder, disable_blocks, norm) def test_disable_norm_raises_exception(self): # Invalid disable_norm_outer_blocks values raise exceptions with pytest.raises(AssertionError): SEANetEncoder(disable_norm_outer_blocks=-1) with pytest.raises(AssertionError): SEANetEncoder(ratios=[1, 1, 2, 2], disable_norm_outer_blocks=7) with pytest.raises(AssertionError): SEANetDecoder(disable_norm_outer_blocks=-1) with pytest.raises(AssertionError): SEANetDecoder(ratios=[1, 1, 2, 2], disable_norm_outer_blocks=7) ================================================ FILE: tests/modules/test_transformer.py ================================================ # Copyright (c) Meta Platforms, Inc. and affiliates. # All rights reserved. 
# # This source code is licensed under the license found in the # LICENSE file in the root directory of this source tree. from itertools import product import pytest import torch from audiocraft.modules.transformer import ( StreamingMultiheadAttention, StreamingTransformer, set_efficient_attention_backend) def test_transformer_causal_streaming(): torch.manual_seed(1234) for context, custom in product([None, 10], [False, True]): # Test that causality and receptive fields are properly handled. # looking at the gradients tr = StreamingTransformer( 16, 4, 1 if context else 2, causal=True, past_context=context, custom=custom, dropout=0.) steps = 20 for k in [0, 10, 15, 19]: x = torch.randn(4, steps, 16, requires_grad=True) y = tr(x) y[:, k].abs().sum().backward() if k + 1 < steps: assert torch.allclose(x.grad[:, k + 1:], torch.tensor(0.)), x.grad[:, k + 1:].norm() assert not torch.allclose(x.grad[:, :k + 1], torch.tensor(0.)), x.grad[:, :k + 1].norm() if context is not None and k > context: limit = k - context - 1 assert torch.allclose(x.grad[:, :limit], torch.tensor(0.)), x.grad[:, :limit].norm() # Now check that streaming gives the same result at batch eval. x = torch.randn(4, steps, 16) y = tr(x) ys = [] with tr.streaming(): for k in range(steps): chunk = x[:, k:k + 1, :] ys.append(tr(chunk)) y_stream = torch.cat(ys, dim=1) delta = torch.norm(y_stream - y) / torch.norm(y) assert delta < 1e-6, delta def test_transformer_vs_pytorch(): torch.manual_seed(1234) # Check that in the non causal setting, we get the same result as # PyTorch Transformer encoder. for custom in [False, True]: tr = StreamingTransformer( 16, 4, 2, causal=False, custom=custom, dropout=0., positional_scale=0.) 
layer = torch.nn.TransformerEncoderLayer(16, 4, dropout=0., batch_first=True) tr_ref = torch.nn.TransformerEncoder(layer, 2) tr.load_state_dict(tr_ref.state_dict()) x = torch.randn(4, 20, 16) y = tr(x) y2 = tr_ref(x) delta = torch.norm(y2 - y) / torch.norm(y) assert delta < 1e-6, delta def test_streaming_api(): tr = StreamingTransformer(16, 4, 2, causal=True, dropout=0.) tr.eval() steps = 12 x = torch.randn(1, steps, 16) with torch.no_grad(): with tr.streaming(): _ = tr(x[:, :1]) state = {k: v.clone() for k, v in tr.get_streaming_state().items()} y = tr(x[:, 1:2]) tr.set_streaming_state(state) y2 = tr(x[:, 1:2]) assert torch.allclose(y, y2), (y - y2).norm() assert tr.flush() is None def test_memory_efficient(): for backend in ['torch', 'xformers']: torch.manual_seed(1234) set_efficient_attention_backend(backend) tr = StreamingTransformer( 16, 4, 2, custom=True, dropout=0., layer_scale=0.1) tr_mem_efficient = StreamingTransformer( 16, 4, 2, dropout=0., memory_efficient=True, layer_scale=0.1) tr_mem_efficient.load_state_dict(tr.state_dict()) tr.eval() steps = 12 x = torch.randn(3, steps, 16) with torch.no_grad(): y = tr(x) y2 = tr_mem_efficient(x) assert torch.allclose(y, y2), ((y - y2).norm(), backend) def test_attention_as_float32(): torch.manual_seed(1234) cases = [ {'custom': True}, {'custom': False}, ] for case in cases: tr = StreamingTransformer(16, 4, 2, dropout=0., dtype=torch.bfloat16, **case) tr_float32 = StreamingTransformer( 16, 4, 2, dropout=0., attention_as_float32=True, dtype=torch.bfloat16, **case) if not case['custom']: # we are not using autocast here because it doesn't really # work as expected on CPU, so we have to manually cast the weights of the MHA. 
for layer in tr_float32.layers: layer.self_attn.mha.to(torch.float32) tr_float32.load_state_dict(tr.state_dict()) steps = 12 x = torch.randn(3, steps, 16, dtype=torch.bfloat16) with torch.no_grad(): y = tr(x) y2 = tr_float32(x) assert not torch.allclose(y, y2), (y - y2).norm() @torch.no_grad() def test_streaming_memory_efficient(): for backend in ['torch', 'xformers']: torch.manual_seed(1234) set_efficient_attention_backend(backend) tr = StreamingTransformer(16, 4, 2, causal=True, dropout=0., custom=True) tr_mem_efficient = StreamingTransformer( 16, 4, 2, dropout=0., memory_efficient=True, causal=True) tr.load_state_dict(tr_mem_efficient.state_dict()) tr.eval() tr_mem_efficient.eval() steps = 12 x = torch.randn(3, steps, 16) ref = tr(x) with tr_mem_efficient.streaming(): outs = [] # frame_sizes = [2] + [1] * (steps - 2) frame_sizes = [1] * steps for frame_size in frame_sizes: frame = x[:, :frame_size] x = x[:, frame_size:] outs.append(tr_mem_efficient(frame)) out = torch.cat(outs, dim=1) delta = torch.norm(out - ref) / torch.norm(out) assert delta < 1e-6, delta def test_cross_attention(): torch.manual_seed(1234) for norm_first in [True, False]: m = StreamingTransformer( 16, 4, 2, cross_attention=False, norm_first=norm_first, dropout=0., custom=True) m_cross = StreamingTransformer( 16, 4, 2, cross_attention=True, norm_first=norm_first, dropout=0., custom=True) m_cross.load_state_dict(m.state_dict(), strict=False) x = torch.randn(2, 5, 16) cross_x = torch.randn(2, 3, 16) y_ref = m(x) y_cross_zero = m_cross(x, cross_attention_src=0 * cross_x) # With norm_first, the two should be exactly the same, # but with norm_first=False, we get 2 normalization in a row # and the epsilon value leads to a tiny change. atol = 0. if norm_first else 1e-6 print((y_ref - y_cross_zero).norm() / y_ref.norm()) assert torch.allclose(y_ref, y_cross_zero, atol=atol) # We now expect a difference even with a generous atol of 1e-2. 
y_cross = m_cross(x, cross_attention_src=cross_x) assert not torch.allclose(y_cross, y_cross_zero, atol=1e-2) with pytest.raises(AssertionError): _ = m_cross(x) _ = m(x, cross_attention_src=cross_x) def test_cross_attention_compat(): torch.manual_seed(1234) num_heads = 2 dim = num_heads * 64 with pytest.raises(AssertionError): StreamingMultiheadAttention(dim, num_heads, causal=True, cross_attention=True) cross_attn = StreamingMultiheadAttention( dim, num_heads, dropout=0, cross_attention=True, custom=True) ref_attn = torch.nn.MultiheadAttention(dim, num_heads, dropout=0, batch_first=True) # We can load the regular attention state dict # so we have compat when loading old checkpoints. cross_attn.load_state_dict(ref_attn.state_dict()) queries = torch.randn(3, 7, dim) keys = torch.randn(3, 9, dim) values = torch.randn(3, 9, dim) y = cross_attn(queries, keys, values)[0] y_ref = ref_attn(queries, keys, values)[0] assert torch.allclose(y, y_ref, atol=1e-7), (y - y_ref).norm() / y_ref.norm() # Now let's check that streaming is working properly. 
with cross_attn.streaming(): ys = [] for step in range(queries.shape[1]): ys.append(cross_attn(queries[:, step: step + 1], keys, values)[0]) y_streaming = torch.cat(ys, dim=1) assert torch.allclose(y_streaming, y, atol=1e-7) def test_repeat_kv(): torch.manual_seed(1234) num_heads = 8 kv_repeat = 4 dim = num_heads * 64 with pytest.raises(AssertionError): mha = StreamingMultiheadAttention( dim, num_heads, causal=True, kv_repeat=kv_repeat, cross_attention=True) mha = StreamingMultiheadAttention( dim, num_heads, causal=True, kv_repeat=kv_repeat) mha = StreamingMultiheadAttention( dim, num_heads, causal=True, kv_repeat=kv_repeat, custom=True) x = torch.randn(4, 18, dim) y = mha(x, x, x)[0] assert x.shape == y.shape def test_qk_layer_norm(): torch.manual_seed(1234) tr = StreamingTransformer( 16, 4, 2, custom=True, dropout=0., qk_layer_norm=True, bias_attn=False) steps = 12 x = torch.randn(3, steps, 16) y = tr(x) tr = StreamingTransformer( 16, 4, 2, custom=True, dropout=0., qk_layer_norm=True, cross_attention=True) z = torch.randn(3, 21, 16) y = tr(x, cross_attention_src=z) assert y.shape == x.shape ================================================ FILE: tests/quantization/test_vq.py ================================================ # Copyright (c) Meta Platforms, Inc. and affiliates. # All rights reserved. # # This source code is licensed under the license found in the # LICENSE file in the root directory of this source tree. import torch from audiocraft.quantization.vq import ResidualVectorQuantizer class TestResidualVectorQuantizer: def test_rvq(self): x = torch.randn(1, 16, 2048) vq = ResidualVectorQuantizer(n_q=8, dimension=16, bins=8) res = vq(x, 1.) assert res.x.shape == torch.Size([1, 16, 2048]) ================================================ FILE: tests/utils/__init__.py ================================================ # Copyright (c) Meta Platforms, Inc. and affiliates. # All rights reserved. 
# # This source code is licensed under the license found in the # LICENSE file in the root directory of this source tree.