Repository: crlandsc/tiny-audio-diffusion Branch: main Commit: 9847b7ceab8a Files: 15 Total size: 64.2 KB Directory structure: gitextract_v0oqny1i/ ├── .gitattributes ├── .gitignore ├── Inference.ipynb ├── LICENSE ├── README.md ├── config.yaml ├── data/ │ └── wav_dataset/ │ └── .gitkeep ├── exp/ │ ├── drum_diffusion.yaml │ └── drum_diffusion_no_wandb.yaml ├── main/ │ ├── diffusion_module.py │ └── utils.py ├── saved_models/ │ └── .gitkeep ├── setup/ │ ├── environment.yml │ └── requirements.txt └── train.py ================================================ FILE CONTENTS ================================================ ================================================ FILE: .gitattributes ================================================ # Python .gitattributes (modified from: https://github.com/alexkaratarakis/gitattributes) # Source files # ============ *.pxd text diff=python *.py text diff=python *.py3 text diff=python *.pyw text diff=python *.pyx text diff=python *.pyz text diff=python *.pyi text diff=python # Binary files # ============ *.db binary *.p binary *.pkl binary *.pickle binary *.pyc binary export-ignore *.pyo binary export-ignore *.pyd binary # Python files # ============ *.py linguist-language=Python # Jupyter Notebook # ============ *.ipynb linguist-language=Jupyter Notebook *.ipynb text eol=lf ================================================ FILE: .gitignore ================================================ # Custom ignore __pycache__ .mypy_cache .env .DS_Store .DS_Store/ .hydra venv/ logs/ .vscode/ *Zone.Identifier kicks/ snares/ hihats/ claps/ snaps/ cymbals/ rides/ toms/ percussion/ archive/ ignore/ video_samples/ # Python Template # Byte-compiled / optimized / DLL files __pycache__/ *.py[cod] *$py.class # C extensions *.so # Distribution / packaging .Python build/ develop-eggs/ dist/ downloads/ eggs/ .eggs/ lib/ lib64/ parts/ sdist/ var/ wheels/ share/python-wheels/ *.egg-info/ .installed.cfg *.egg MANIFEST # PyInstaller # Usually these files are written by a python script from a template # before PyInstaller builds the exe, so as to inject date/other infos into it. *.manifest *.spec # Installer logs pip-log.txt pip-delete-this-directory.txt # Unit test / coverage reports htmlcov/ .tox/ .nox/ .coverage .coverage.* .cache nosetests.xml coverage.xml *.cover *.py,cover .hypothesis/ .pytest_cache/ cover/ # Translations *.mo *.pot # Django stuff: *.log local_settings.py db.sqlite3 db.sqlite3-journal # Flask stuff: instance/ .webassets-cache # Scrapy stuff: .scrapy # Sphinx documentation docs/_build/ # PyBuilder .pybuilder/ target/ # Jupyter Notebook .ipynb_checkpoints # IPython profile_default/ ipython_config.py # pyenv # For a library or package, you might want to ignore these files since the code is # intended to run in multiple environments; otherwise, check them in: # .python-version # pipenv # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. # However, in case of collaboration, if having platform-specific dependencies or dependencies # having no cross-platform support, pipenv may install dependencies that don't work, or not # install all needed dependencies. #Pipfile.lock # poetry # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. # This is especially recommended for binary packages to ensure reproducibility, and is more # commonly ignored for libraries. # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control #poetry.lock # pdm # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. #pdm.lock # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it # in version control. # https://pdm.fming.dev/#use-with-ide .pdm.toml # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm __pypackages__/ # Celery stuff celerybeat-schedule celerybeat.pid # SageMath parsed files *.sage.py # Environments .env .venv env/ venv/ ENV/ env.bak/ venv.bak/ # Spyder project settings .spyderproject .spyproject # Rope project settings .ropeproject # mkdocs documentation /site # mypy .mypy_cache/ .dmypy.json dmypy.json # Pyre type checker .pyre/ # pytype static type analyzer .pytype/ # Cython debug symbols cython_debug/ ================================================ FILE: Inference.ipynb ================================================ { "cells": [ { "attachments": {}, "cell_type": "markdown", "id": "1fc06181", "metadata": {}, "source": [ "# Inference Notebook\n", "- This notebook serves to import trained models and generate new samples.\n", "- You should only need to edits the [Checkpoint & Configs](#Checkpoint-\\&-Configs) and [Define Sample Parameters](#Define-Sample-Parameters) cells.\n", "- Currently this notebook only offers unconditional generation, but I plan to include more features in the future.\n", "- Have fun creating new sounds!\n", "\n", "*(NOTE: Don't do \"run all\" in Jupyter. For some reason it doesn't output anything when I used this option. It may work in other environments, but just a heads up!)*" ] }, { "attachments": {}, "cell_type": "markdown", "id": "7b67d821", "metadata": {}, "source": [ "#### Imports\n", "Import necessary libraries to run the notebook" ] }, { "cell_type": "code", "execution_count": null, "id": "3dc381cc", "metadata": { "scrolled": true }, "outputs": [], "source": [ "# Imports\n", "import matplotlib.pyplot as plt\n", "import torch\n", "import torchaudio\n", "from torch import nn\n", "import pytorch_lightning as pl\n", "from ema_pytorch import EMA\n", "import IPython.display as ipd\n", "import yaml\n", "from audio_diffusion_pytorch import DiffusionModel, UNetV0, VDiffusion, VSampler\n", "from diffusion import sampling, utils" ] }, { "attachments": {}, "cell_type": "markdown", "id": "49d620cb", "metadata": {}, "source": [ "### Checkpoint & Configs\n", "- Replace these paths with the path to your model's checkpoint and configs.\n", "- Pre-trained models are availlable to download on Hugging Face.\n", "\n", "|Model|Link|\n", "|---|---|\n", "|Kicks|[crlandsc/tiny-audio-diffusion-kicks](https://huggingface.co/crlandsc/tiny-audio-diffusion-kicks)|\n", "|Snares|[crlandsc/tiny-audio-diffusion-snares](https://huggingface.co/crlandsc/tiny-audio-diffusion-snares)|\n", "|Hi-hats|[crlandsc/tiny-audio-diffusion-hihats](https://huggingface.co/crlandsc/tiny-audio-diffusion-hihats)|\n", "|Percussion (all drum types)|[crlandsc/tiny-audio-diffusion-percussion](https://huggingface.co/crlandsc/tiny-audio-diffusion-percussion)|" ] }, { "cell_type": "code", "execution_count": null, "id": "5037ead6", "metadata": {}, "outputs": [], "source": [ "# Load model checkpoint\n", "ckpt_path = \"./saved_models/kicks/kicks_v7.ckpt\" # path to model checkpoint\n", "config_path = \"./saved_models/kicks/config.yaml\" # path to model config" ] }, { "attachments": {}, "cell_type": "markdown", "id": "f2f61804", "metadata": {}, "source": [ "### Functions & Models\n", "- Functions and models definitions" ] }, { "cell_type": "code", "execution_count": null, "id": "46eec999", "metadata": {}, "outputs": [], "source": [ "# Load configs\n", "with open(config_path, 'r') as file:\n", " config = yaml.safe_load(file)\n", "pl_configs = config['model']\n", "model_configs = config['model']['model']" ] }, { "cell_type": "code", "execution_count": null, "id": "f4797122", "metadata": {}, "outputs": [], "source": [ "def plot_mel_spectrogram(sample):\n", " transform = torchaudio.transforms.MelSpectrogram(\n", " sample_rate=sr,\n", " n_fft=1024,\n", " hop_length=512,\n", " n_mels=80,\n", " center=True,\n", " norm=\"slaney\",\n", " )\n", "\n", " spectrogram = transform(torch.mean(sample, dim=0)) # downmix and cal spectrogram\n", " spectrogram = torchaudio.functional.amplitude_to_DB(spectrogram, 1.0, 1e-10, 80.0)\n", "\n", " # Plot the Mel spectrogram\n", " fig = plt.figure(figsize=(7, 4))\n", " plt.imshow(spectrogram, aspect='auto', origin='lower')\n", " plt.colorbar(format='%+2.0f dB')\n", " plt.xlabel('Frame')\n", " plt.ylabel('Mel Bin')\n", " plt.title('Mel Spectrogram')\n", " plt.tight_layout()\n", " \n", " return fig" ] }, { "cell_type": "code", "execution_count": null, "id": "e246c0e2", "metadata": {}, "outputs": [], "source": [ "# Define PyTorch Lightning model\n", "class Model(pl.LightningModule):\n", " def __init__(\n", " self,\n", " lr: float,\n", " lr_beta1: float,\n", " lr_beta2: float,\n", " lr_eps: float,\n", " lr_weight_decay: float,\n", " ema_beta: float,\n", " ema_power: float,\n", " model: nn.Module,\n", " ):\n", " super().__init__()\n", " self.lr = lr\n", " self.lr_beta1 = lr_beta1\n", " self.lr_beta2 = lr_beta2\n", " self.lr_eps = lr_eps\n", " self.lr_weight_decay = lr_weight_decay\n", " self.model = model\n", " self.model_ema = EMA(self.model, beta=ema_beta, power=ema_power)" ] }, { "attachments": {}, "cell_type": "markdown", "id": "5b2aecab", "metadata": {}, "source": [ "### Instantiate model\n", "*NOTE: This model setup needs to exactly match the model that was trained*\n", "\n", "- This cell instantiates the model (no weights) using the config.yaml file. This is structure is critical to make sure that the model weights can be loaded in correctly." ] }, { "cell_type": "code", "execution_count": null, "id": "d626c6e7", "metadata": {}, "outputs": [], "source": [ "# Instantiate model (must match model that was trained)\n", "\n", "# Diffusion model\n", "model = DiffusionModel(\n", " net_t=UNetV0, # The model type used for diffusion (U-Net V0 in this case)\n", " in_channels=model_configs['in_channels'], # U-Net: number of input/output (audio) channels\n", " channels=model_configs['channels'], # U-Net: channels at each layer\n", " factors=model_configs['factors'], # U-Net: downsampling and upsampling factors at each layer\n", " items=model_configs['items'], # U-Net: number of repeating items at each layer\n", " attentions=model_configs['attentions'], # U-Net: attention enabled/disabled at each layer\n", " attention_heads=model_configs['attention_heads'], # U-Net: number of attention heads per attention item\n", " attention_features=model_configs['attention_features'], # U-Net: number of attention features per attention item\n", " diffusion_t=VDiffusion, # The diffusion method used\n", " sampler_t=VSampler # The diffusion sampler used\n", ")\n", "\n", "# pl model\n", "model = Model(\n", " lr=pl_configs['lr'],\n", " lr_beta1=pl_configs['lr_beta1'],\n", " lr_beta2=pl_configs['lr_beta2'],\n", " lr_eps=pl_configs['lr_eps'],\n", " lr_weight_decay=pl_configs['lr_weight_decay'],\n", " ema_beta=pl_configs['ema_beta'],\n", " ema_power=pl_configs['ema_power'],\n", " model=model\n", ")" ] }, { "attachments": {}, "cell_type": "markdown", "id": "c2d2f702", "metadata": {}, "source": [ "### Check if GPU available\n", "- This checks to see if a CUDe capable GPU is available to utilize. If so, the model is assigned to the GPU. If not, the model simply remains on the CPU." ] }, { "cell_type": "code", "execution_count": null, "id": "9ce84487", "metadata": {}, "outputs": [], "source": [ "# Assign to GPU\n", "if torch.cuda.is_available():\n", " model = model.to('cuda')\n", " print(f\"Device: {model.device}\")" ] }, { "attachments": {}, "cell_type": "markdown", "id": "825d96d9", "metadata": {}, "source": [ "### Load model\n", "- This cell loads the checkpoint weights into the model. It should return `\"\"` if successfully loaded." ] }, { "cell_type": "code", "execution_count": null, "id": "845993e4", "metadata": {}, "outputs": [], "source": [ "# Load model checkpoint\n", "checkpoint = torch.load(ckpt_path, map_location='cpu')['state_dict']\n", "model.load_state_dict(checkpoint) # should output \"\"" ] }, { "attachments": {}, "cell_type": "markdown", "id": "aaaa996f", "metadata": {}, "source": [ "## Unconditional Sample Generation\n", "Generate new sounds from noise with no conditioning." ] }, { "attachments": {}, "cell_type": "markdown", "id": "c93dc280", "metadata": {}, "source": [ "#### Define Sample Parameters\n", "- sample_length: how long to make the output (measured in samples). Recommended $2^{15}=32768$ (~0.75 sec), as that is what the model was trained on.\n", "- sr (sample rate): sampling rate to output. Recommended industry standard 44.1kHz (44100Hz).\n", "- num_samples: number of new samples that will be generated.\n", "- num_steps: number of diffusion steps - tradeoff inference speed for sample quality (10-100 is a good range).\n", " - 10+ steps - quick generation, alright samples but noticeable high-freq hiss.\n", " - 50+ steps - moderate generation speed, good tradeoff for speed and qualiy (less high-freq hiss).\n", " - 100+ steps - slow generation speed, high quality samples.\n", "\n", "Have fun playing around with these parameters! Note that sometimes the model outputs some wild things. This is likely due to the small size of the model as well as the limited training data. Larger models and/or larger and more diverse datasets would improve this." ] }, { "cell_type": "code", "execution_count": null, "id": "6f0438d4", "metadata": {}, "outputs": [], "source": [ "# Define diffusion paramters\n", "sample_length = 2**15 # 32768 samples @ 44100 = .75 sec\n", "sr = 44100\n", "num_samples = 3 # number of samples to generate\n", "num_steps = 50 # number of diffusion steps, tradeoff inference speed for sample quality (10-100 is a good range)" ] }, { "attachments": {}, "cell_type": "markdown", "id": "4952b429", "metadata": {}, "source": [ "#### Generate samples\n", "Run the following cell to generate samples based on previously defined parameters" ] }, { "cell_type": "code", "execution_count": null, "id": "bf88849c", "metadata": { "scrolled": false }, "outputs": [], "source": [ "with torch.no_grad():\n", " all_samples = torch.zeros(2, 0)\n", " for i in range(num_samples):\n", " noise = torch.randn((1, 2, sample_length), device=model.device) # [batch_size, in_channels, length]\n", " generated_sample = model.model_ema.ema_model.sample(noise, num_steps=num_steps).squeeze(0).cpu() # Suggested num_steps 10-100\n", "\n", " print(f\"Generated Sample {i+1}\")\n", " display(ipd.Audio(generated_sample, rate=sr))\n", " \n", " # concatenate all samples:\n", " all_samples = torch.concat((all_samples, generated_sample), dim=1)\n", " \n", " fig = plot_mel_spectrogram(generated_sample)\n", " plt.title(f\"Mel Spectrogram (Sample {i+1})\")\n", " plt.show()\n", " \n", " torch.cuda.empty_cache()" ] }, { "attachments": {}, "cell_type": "markdown", "id": "c18bda2c", "metadata": {}, "source": [ "#### Combine all samples\n", "- Option to combine all samples into a single sample" ] }, { "cell_type": "code", "execution_count": null, "id": "f8003140", "metadata": {}, "outputs": [], "source": [ "# Optional concatenate all samples\n", "print(f\"All Samples\")\n", "display(ipd.Audio(all_samples, rate=sr))\n", "fig = plot_mel_spectrogram(all_samples)\n", "plt.title(f\"Mel Spectrogram)\")\n", "plt.show()" ] }, { "attachments": {}, "cell_type": "markdown", "id": "9b5cc6ca", "metadata": {}, "source": [ "## Conditional \"Style-Transfer\" Generation\n", "Generate new sounds conditioned on input audio.\n" ] }, { "attachments": {}, "cell_type": "markdown", "id": "f122e399", "metadata": {}, "source": [ "#### Define Sample Parameters\n", "- audio_file_path: Path to audio file for conditioning the model.\n", "- sample_with_noise: Option to output the conditioning sample with noise added to listen, or suppress it.\n", "- trim_sample: Option to trim/pad sample if it is too long/short.\n", "- sample_length: how long to make the output (measured in samples). Recommended $2^{15}=32768$ (~0.75 sec), as that is what the model was trained on.\n", "- sr (sample rate): sampling rate to output. Recommended industry standard 44.1kHz (44100Hz).\n", "- num_samples: number of new samples that will be generated.\n", "- noise_level: The amount of noise to be added to the input sample.\n", "- num_steps: number of diffusion steps - tradeoff inference speed for sample quality.\n", " - The number of steps for conditional diffusion varies more compared to unconditional diffusion. For example, if you input a transient sound (like a snare hit) and want to transfer it to the `kicks` model, then you may not want to add any noise and keep the steps below 10 for an interesting sound. But, if you want to transfer something like a guitar to the percussion model, you may want to add some more noise and increase the number of steps.\n", "\n", "*NOTE:* The less noise that is added to a sample, the less varied the outputs will be. For example, if you ad 0 noise to a sample and generate it 3 times, it will produce the exact same thing 3 times (because the input remains consistent). As you increase the noise added, increasing the variation of the inputs, the outputs will vary more widely as well." ] }, { "cell_type": "code", "execution_count": null, "id": "38a62339", "metadata": {}, "outputs": [], "source": [ "# Define diffusion paramters\n", "audio_file_path = \"samples/snare1.wav\"\n", "\n", "# Listen to noised sample\n", "sample_with_noise = False # True to listen to sample + noise, false to not output\n", "\n", "# If sample too long\n", "trim_sample = False # True - if sample too long / False does not trim\n", "sample_length = 2**15 # NA\n", "\n", "sr = 44100 # Sampling rate\n", "num_samples = 1 # number of samples to generate\n", "noise_level = 0 # between 0 and 1\n", "num_steps = 6 # number of diffusion steps, tradeoff inference speed for sample quality (10-100 is a good range)" ] }, { "attachments": {}, "cell_type": "markdown", "id": "15298646", "metadata": {}, "source": [ "#### Generate samples\n", "Run the following cell to generate samples based on previously defined parameters" ] }, { "cell_type": "code", "execution_count": null, "id": "5bf5203d", "metadata": {}, "outputs": [], "source": [ "# Generate samples\n", "with torch.no_grad():\n", "\n", " # load audio sample\n", " audio_sample = torchaudio.load(audio_file_path)[0].unsqueeze(0).to(model.device) # unsqueeze for correct tensor shape\n", "\n", " # Trim audio\n", " if trim_sample:\n", " og_shape = audio_sample.shape\n", " if sample_length < og_shape[2]:\n", " audio_sample = audio_sample[:,:,:sample_length]\n", " elif sample_length > og_shape[2]:\n", " # Pad tensor with zeros to match sample length\n", " audio_sample = torch.concat((audio_sample, torch.zeros(og_shape[0], og_shape[1], sample_length - og_shape[2]).to(model.device)), dim=2)\n", "\n", "\n", " original_audio = audio_sample.squeeze(0).squeeze(0).cpu()\n", "\n", " # Display original audio sample\n", " print(f\"Original Sample\")\n", " display(ipd.Audio(original_audio, rate=sr))\n", "\n", " # Plot original audio\n", " fig = plot_mel_spectrogram(original_audio)\n", " plt.title(f\"Mel Spectrogram (Original Sample)\")\n", " plt.show()\n", "\n", "\n", " # Display original audio sample + noise\n", " if sample_with_noise:\n", " noise = torch.randn_like(audio_sample, device=model.device) * noise_level # combine input signal and noise\n", " noised_sample = (audio_sample + noise).squeeze(0).cpu() # normalize?\n", " print(f\"Original Noised Sample\")\n", " display(ipd.Audio(noised_sample, rate=sr))\n", "\n", " # Plot original audio + noise\n", " fig = plot_mel_spectrogram(noised_sample)\n", " plt.title(f\"Mel Spectrogram (Noised Sample)\")\n", " plt.show()\n", "\n", "\n", " all_samples = torch.zeros(2, 0)\n", " for i in range(num_samples):\n", " noise = torch.randn_like(audio_sample, device=model.device) * noise_level # combine input signal and noise\n", " audio = audio_sample + noise # normalize?\n", " generated_sample = model.model_ema.ema_model.sample(audio, num_steps=num_steps).squeeze(0).cpu()\n", "\n", " print(f\"Generated Sample {i+1}\")\n", " display(ipd.Audio(generated_sample, rate=sr))\n", " \n", " # concatenate all samples:\n", " all_samples = torch.concat((all_samples, generated_sample), dim=1)\n", " \n", " fig = plot_mel_spectrogram(generated_sample)\n", " plt.title(f\"Mel Spectrogram (Sample {i+1})\")\n", " plt.show()\n", " \n", " torch.cuda.empty_cache()" ] }, { "attachments": {}, "cell_type": "markdown", "id": "4762cc6b", "metadata": {}, "source": [ "#### Combine all samples\n", "- Option to combine all samples into a single sample" ] }, { "cell_type": "code", "execution_count": null, "id": "5501ce6a", "metadata": {}, "outputs": [], "source": [ "# Optional concatenate all samples\n", "print(f\"All Samples\")\n", "display(ipd.Audio(all_samples, rate=sr))\n", "fig = plot_mel_spectrogram(all_samples)\n", "plt.title(f\"Mel Spectrogram)\")\n", "plt.show()" ] }, { "cell_type": "code", "execution_count": null, "id": "0ef83b4b", "metadata": {}, "outputs": [], "source": [ "# TODO: Add normalization?\n", "# TODO: Add other smapling methods (currently only DDIM)\n", "# TODO: clean cell (make functions)" ] } ], "metadata": { "kernelspec": { "display_name": "tiny-audio-diffusion (Python 3.10)", "language": "python", "name": "tiny-audio-diffusion" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.10.11" } }, "nbformat": 4, "nbformat_minor": 5 } ================================================ FILE: LICENSE ================================================ MIT License Copyright (c) 2023 Christopher Landschoot Copyright (c) 2022 archinet.ai Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions: The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software. THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. ================================================ FILE: README.md ================================================

Tiny Audio Diffusion

Tiny Audio Diffusion Logo

[![Hugging Face Spaces Badge](https://img.shields.io/badge/%F0%9F%A4%97_Spaces_Demo-blue)](https://huggingface.co/spaces/crlandsc/tiny-audio-diffusion) [![YouTube Tutorial Badge](https://img.shields.io/badge/Repo_Tutorial-red?logo=YouTube)](https://youtu.be/m6Eh2srtTro) [![Towards Data Science Badge](https://img.shields.io/badge/Towards_Data_Science-red?logo=Medium&color=black)](https://medium.com/towards-data-science/tiny-audio-diffusion-ddc19e90af9b) [![GitHub License](https://img.shields.io/github/license/crlandsc/tiny-audio-diffusion)](https://github.com/crlandsc/tiny-audio-diffusion/blob/main/LICENSE) [![GitHub Repo stars](https://img.shields.io/github/stars/crlandsc/tiny-audio-diffusion?color=gold)](https://github.com/crlandsc/tiny-audio-diffusion/stargazers) [![GitHub forks](https://img.shields.io/github/forks/crlandsc/tiny-audio-diffusion?color=green)](https://github.com/crlandsc/tiny-audio-diffusion/forks) This is a repository for generating short audio samples and training waveform diffusion models on a consumer-grade GPU with less than 2GB VRAM. ## Motivation The purpose of this project is to provide access to stereo high-resolution (44.1kHz) conditional and unconditional audio waveform (1D U-Net) diffusion code for those interested in exploration but who have limited resources. There are many methods for audio generation on low-level hardware, but less so specifically for waveform-based diffusion. The repository is built heavily adapting code from Archinet's [audio-diffusion-pytorch](https://github.com/archinetai/audio-diffusion-pytorch) libary. A huge thank you to [Flavio Schneider](https://github.com/flavioschneider) for his incredible open-source work in this field! ## Background Direct waveform diffusion is inherently computationally intensive. For example, an audio sample with the industry standard 44.1kHz sampling rate requires 44,100 samples for just 1 second of audio. Now multiply that by 2 for a stereo file. However, it has a significant advantage over many methods that reduce audio into spectrograms or downsample - the network retains and learns from *phase* information. Phase is challenging to represent on its own in visual methods, such as spectrograms, as it appears similar to that of random noise. Because of this, many generative methods discard phase information and then implement ways of estimating and regenerating it. However, it plays a key role in defining the timbral qualities of sounds and should not be dispensed with so easily. Waveform diffusion is able to retain this important feature as it does not perform any transforms on the audio before feeding it into the network. This is how humans perceive sounds, with both amplitude and phase information bundled together in a single signal. As mentioned previously, this comes at the expense of computational requirements and is often reserved for training on a cluster of GPUs with high speeds and lots of memory. Because of this, it is hard to begin to experiment with waveform diffusion with limited resources. This repository seeks to offer some base code to those looking to experiment with and learn more about waveform diffusion on their own computer without having to purchase cloud resources or upgrade hardware. This goes for not only *inference*, but *training* your own models as well! To make this feasible, however, there must be a tradeoff of quality, speed, and sample length. Because of this, I have focused on training base models for one-shot drum samples - as they are inherently short in sample length. The current configuration is set up to be able to train ~0.75 second stereo samples at 44.1kHz, allowing for the generation of high-quality one-shot audio samples. The network configuration can be adjusted to improve the resolution, sample rate, training and inference speed, sample length, etc. but, of course, more hardware resources will be required. Other methods of diffusion, such as diffusion in the latent space ([Stable Diffusion's](https://stability.ai/stablediffusion) secret sauce), compared to this repo's raw waveform diffusion can offer an improvement and other tradeoffs between quality, memory requirements, speed, etc. I recommend this repo to remain up-to-date with the latest research in generative audio: https://github.com/archinetai/audio-ai-timeline Also recommended is [Harmonai's](https://www.harmonai.org/) community project, [Dance Diffusion](https://github.com/Harmonai-org/sample-generator), which implements similar functionality to this repo on a larger scale with several pre-trained models. [Colab notebook](https://colab.research.google.com/github/Harmonai-org/sample-generator/blob/main/Dance_Diffusion.ipynb) available. **April 2024 update:** Some additional useful generative audio tools/repos: - [Stable Audio Tools](https://github.com/Stability-AI/stable-audio-tools) (used in [Stable Audio](https://www.stableaudio.com/)) - Useful audio tools for building and training models. - [audiocraft](https://github.com/facebookresearch/audiocraft) (used in [MusicGen](https://audiocraft.metademolab.com/musicgen.html) & [AudioGen](https://audiocraft.metademolab.com/audiogen.html)) - Useful audio tools for building and training models. - [audiomentations](https://github.com/iver56/audiomentations) - Good library for implementing audio augmentations on CPU for training. See [torch-audiomentations](https://github.com/asteroid-team/torch-audiomentations) for GPU implementation. --- ## Setup Follow these steps to set up an environment for both generating audio samples and training models. *NOTE:* To use this repo with a GPU, you must have a CUDA-capable GPU and have the CUDA toolkit installed for your specific to your system (ex. Linux, x86_64, WSL-Ubuntu). More information can be found [here](https://developer.nvidia.com/cuda-toolkit). #### 1. Create a Virtual Environment: Ensure that [Anaconda (or Miniconda)](https://docs.anaconda.com/free/anaconda/install/index.html) is installed and activated. From the command line, `cd` into the [`setup/`](setup/) folder and run the following lines: ```bash conda env create -f environment.yml conda activate tiny-audio-diffusion ``` This will create and activate a conda environment from the [`setup/environment.yml`](setup/environment.yml) file and install the dependencies in [`setup/requirements.txt`](setup/requirements.txt). #### 2. Install Python Kernel For Jupyter Notebook Run the following line to create a kernel for the current environment to run the inference notebook. ```bash python -m ipykernel install --user --name tiny-audio-diffusion --display-name "tiny-audio-diffusion (Python 3.10)" ``` #### 3. Define Environment Variables Rename [`.env.tmp`](.env.tmp) to `.env` and replace the entries with your own variables (example values are random). ```bash DIR_LOGS=/logs DIR_DATA=/data # Required if using Weights & Biases (W&B) logger WANDB_PROJECT=tiny_drum_diffusion # Custom W&B name for current project WANDB_ENTITY=johnsmith # W&B username WANDB_API_KEY=a21dzbqlybbzccqla4txa21dzbqlybbzccqla4tx # W&B API key ``` *NOTE:* Sign up for a [Weights & Biases](https://wandb.ai/site) account to log audio samples, spectrograms, and other metrics while training (it's free!). W&B logging example for this repo [here](https://wandb.ai/crlandsc/unconditional-drum-diffusion?workspace=user-crlandsc). --- ## Pre-trained Models Pretrained models can be found on Hugging Face (each model contains a `.ckpt` and `.yaml` file): |Model|Link| |---|---| |Kicks|[crlandsc/tiny-audio-diffusion-kicks](https://huggingface.co/crlandsc/tiny-audio-diffusion-kicks)| |Snares|[crlandsc/tiny-audio-diffusion-snares](https://huggingface.co/crlandsc/tiny-audio-diffusion-snares)| |Hi-hats|[crlandsc/tiny-audio-diffusion-hihats](https://huggingface.co/crlandsc/tiny-audio-diffusion-hihats)| |Percussion (all drum types)|[crlandsc/tiny-audio-diffusion-percussion](https://huggingface.co/crlandsc/tiny-audio-diffusion-percussion)| *See W&B model training metrics [here](https://wandb.ai/crlandsc/unconditional-drum-diffusion?workspace=user-crlandsc).* Pre-trained models can be downloaded to generate samples via the [inference notebook](Inference.ipynb). They can also be used as a base model to fine-tune on custom data. It is recommended to create subfolders within the [`saved_models`](saved_models/) folder to store each model's `.ckpt` and `.yaml` files. --- ## Inference ### Hugging Face Spaces Generate samples without code on [🤗 Hugging Face Spaces](https://huggingface.co/spaces/crlandsc/tiny-audio-diffusion)! ### Jupyter Notebook #### Audio Sample Generation Current Capabilities: - Unconditional Generation - Conditional "Style-transfer" Generation Open the [`Inference.ipynb`](Inference.ipynb) in Jupyter Notebook and follow the instructions to generate new audio samples. Ensure that the `"tiny-audio-diffusion (Python 3.10)"` kernel is active in Jupyter to run the notebook and you have downloaded the [pre-trained model](#Pre\-trained-Models) of interest from Hugging Face. --- ## Train The model architecture has been constructed with [PyTorch Lightning](https://lightning.ai/docs/pytorch/latest/) and [Hydra](https://hydra.cc/docs/intro/) frameworks. All configurations for the model are contained within `.yaml` files and should be edited there rather than hardcoded. [`exp/drum_diffusion.yaml`](exp/drum_diffusion.yaml) contains the default model configuration. Additional custom model configurations can be added to the [`exp`](exp/) folder. Custom models can be trained or fine-tuned on custom datasets. Datasets should consist of a folder of `.wav` audio files with a 44.1kHz sampling rate. To train or finetune models, run one of the following commands in the terminal from the repo's root folder and replace `` with the path to your custom training set. **Train model from scratch (on CPU):** *(not recommended)* ```bash python train.py exp=drum_diffusion datamodule.dataset.path= ``` **Train model from scratch (on GPU):** ```bash python train.py exp=drum_diffusion trainer.gpus=1 datamodule.dataset.path= ``` *NOTE:* To use this repo with a GPU, you must have a CUDA-capable GPU and have the CUDA toolkit installed specific to your system (ex. Linux, x86_64, WSL-Ubuntu). More information can be found [here](https://developer.nvidia.com/cuda-toolkit). **Resume run from a checkpoint (with GPU):** ```bash python train.py exp=drum_diffusion trainer.gpus=1 +ckpt= datamodule.dataset.path= ``` --- ## Dataset The data used to train the checkpoints listed above can be found on [🤗 Hugging Face](https://huggingface.co/datasets/crlandsc/tiny-audio-diffusion-drums). ***Note:*** *This is a small and unbalanced dataset consisting of free samples that I had from my music production. These samples are not covered under the MIT license of this repository and cannot be used to train any commercial models, but can be used in personal and research contexts.* ***Note:*** *For appropriately diverse models, larger datasets should be used to avoid memorization of training data.* --- ## Repository Structure The structure of this repository is as follows: ``` ├── main │ ├── diffusion_module.py - contains pl model, data loading, and logging functionalities for training │ └── utils.py - contains utility functions for training ├── exp │ └── *.yaml - Hydra configuration files ├── setup │ ├── environment.yml - file to set up conda environment │ └── requirements.txt - contains repo dependencies ├── images - directory containing images for README.md │ └── *.png ├── samples - directory containing sample outputs from tiny-audio-diffusion models │ └── *.wav ├── .env.tmp - temporary environment variables (rename to .env) ├── .gitignore ├── README.md ├── Inference.ipynb - Jupyter notebook for running inference to generate new samples ├── config.yaml - Hydra base configs ├── train.py - script for training ├── data - directory to host custom training data │ └── wav_dataset │ └── (*.wav) └── saved_models - directory to host model checkpoints and hyper-parameters for inference └── (kicks/snare/etc.) ├── (*.ckpt) - pl model checkpoint file └── (config.yaml) - pl model hydra hyperparameters (required for inference) ``` ================================================ FILE: config.yaml ================================================ defaults: - _self_ - exp: null # config to load - override hydra/hydra_logging: colorlog - override hydra/job_logging: colorlog seed: 12345 train: True ignore_warnings: True print_config: False # Prints tree with all configurations work_dir: ${hydra:runtime.cwd} # This is the root of the project logs_dir: ${work_dir}${oc.env:DIR_LOGS} # This is the root for all logs data_dir: ${work_dir}${oc.env:DIR_DATA} # This is the root for all data ckpt_dir: ${logs_dir}/runs/${now:%Y-%m-%d-%H-%M-%S} # Hydra experiment configs log dir hydra: run: dir: ${ckpt_dir} # save in same dir as ckpts ================================================ FILE: data/wav_dataset/.gitkeep ================================================ ================================================ FILE: exp/drum_diffusion.yaml ================================================ # @package _global_ # Unconditional Audio Waveform Diffusion # To execute this experiment on a single GPU, run: # python train.py exp=drum_diffusion trainer.gpus=1 datamodule.dataset.path= module: main.diffusion_module batch_size: 1 # mini-batch size (increase to speed up at the cost of memory) accumulate_grad_batches: 32 # use to increase batch size on single GPU -> effective batch size = (batch_size * accumulate_grad_batches) num_workers: 8 # num workers for data loading sampling_rate: 44100 # sampling rate (44.1kHz is the music industry standard) length: 32768 # Length of audio in samples (32768 samples @ 44.1kHz ~ 0.75 seconds) channels: 2 # stereo audio val_log_every_n_steps: 1000 # Logging interval (Validation and audio generation every n steps) # ckpt_every_n_steps: 4000 # Use if multiple checkpoints wanted model: _target_: ${module}.Model # pl model wrapper lr: 1e-4 # optimizer learning rate lr_beta1: 0.95 # beta1 param for Adam optimizer lr_beta2: 0.999 # beta2 param for Adam optimzer lr_eps: 1e-6 # epsilon for optimizer (to avoid div by 0) lr_weight_decay: 1e-3 # weight decay regularization param ema_beta: 0.995 # EMA model (exponential-moving-average) beta ema_power: 0.7 # EMA model gradiaent norm param model: _target_: audio_diffusion_pytorch.DiffusionModel # Waveform diffusion model net_t: _target_: ${module}.UNetT # The model type used for diffusion (U-Net V0 in this case) in_channels: 2 # U-Net: number of input/output (audio) channels channels: [32, 32, 64, 64, 128, 128, 256, 256] # U-Net: channels at each layer factors: [1, 2, 2, 2, 2, 2, 2, 2] # U-Net: downsampling and upsampling factors at each layer items: [2, 2, 2, 2, 2, 2, 4, 4] # U-Net: number of repeating items at each layer attentions: [0, 0, 0, 0, 0, 1, 1, 1] # U-Net: attention enabled/disabled at each layer attention_heads: 8 # U-Net: number of attention heads per attention item attention_features: 64 # U-Net: number of attention features per attention item # To specify train-valid datasets, datamodule must be reconfigured datamodule: _target_: main.diffusion_module.Datamodule dataset: _target_: audio_data_pytorch.WAVDataset path: ./data/wav_dataset # can overried when calling train.py recursive: True sample_rate: ${sampling_rate} transforms: _target_: audio_data_pytorch.AllTransform crop_size: ${length} # One-shots, so no random crop stereo: True source_rate: ${sampling_rate} target_rate: ${sampling_rate} loudness: -20 # normalize loudness val_split: 0.1 # split data into validation batch_size: ${batch_size} num_workers: ${num_workers} pin_memory: True callbacks: rich_progress_bar: _target_: pytorch_lightning.callbacks.RichProgressBar # _target_: pytorch_lightning.callbacks.TQDMProgressBar # use if RichProgressBar creates issues model_checkpoint: _target_: pytorch_lightning.callbacks.ModelCheckpoint monitor: "valid_loss" # name of the logged metric which determines when model is improving save_top_k: 1 # save k best models (determined by above metric) save_last: True # additionaly always save model from last epoch mode: "min" # can be "max" or "min" verbose: False dirpath: ${logs_dir}/ckpts/${now:%Y-%m-%d-%H-%M-%S} filename: '{epoch:02d}-{valid_loss:.3f}' # every_n_train_steps: ${ckpt_every_n_steps} # Use if multiple checkpoints wanted model_summary: _target_: pytorch_lightning.callbacks.RichModelSummary max_depth: 2 audio_samples_logger: _target_: main.diffusion_module.SampleLogger num_items: 4 # number of separate samples to be generated channels: ${channels} # number of audio channels sampling_rate: ${sampling_rate} # audio sampling rate length: ${length} # length of generated sample sampling_steps: [50] # number of steps per sample use_ema_model: True # Use EMA for logger inference loggers: wandb: _target_: pytorch_lightning.loggers.wandb.WandbLogger project: ${oc.env:WANDB_PROJECT} # defined in env var entity: ${oc.env:WANDB_ENTITY} # defined in env var name: unconditional_diffusion # name of run # offline: False # set True to store all logs only locally job_type: "train" group: "" # Set a group name if desired save_dir: ${logs_dir} trainer: _target_: pytorch_lightning.Trainer gpus: 0 # Set `1` to train on GPU, `0` to train on CPU only, and `-1` to train on all GPUs, default `0` precision: 16 # Precision used for tensors (`32` offers higher precision, but `16` is used to save memory) min_epochs: 0 # minimum number of epochs max_epochs: -1 # max number of epochs (-1 = infinite run) enable_model_summary: False log_every_n_steps: 1 # Logs training metrics every n steps # limit_val_batches: 10 # Use to limit the number of valid batches run (e.g. 10 stops training at 10 batches) check_val_every_n_epoch: null val_check_interval: ${val_log_every_n_steps} # Validation interval (check valid set and generate audio every n steps) accumulate_grad_batches: ${accumulate_grad_batches} # use to increase batch size on single GPU ================================================ FILE: exp/drum_diffusion_no_wandb.yaml ================================================ # @package _global_ # Unconditional Audio Waveform Diffusion # To execute this experiment on a single GPU, run: # python train.py exp=drum_diffusion trainer.gpus=1 datamodule.dataset.path= module: main.diffusion_module batch_size: 1 # mini-batch size (increase to speed up at the cost of memory) accumulate_grad_batches: 32 # use to increase batch size on single GPU -> effective batch size = (batch_size * accumulate_grad_batches) num_workers: 8 # num workers for data loading sampling_rate: 44100 # sampling rate (44.1kHz is the music industry standard) length: 32768 # Length of audio in samples (32768 samples @ 44.1kHz ~ 0.75 seconds) channels: 2 # stereo audio val_log_every_n_steps: 1000 # Logging interval (Validation and audio generation every n steps) # ckpt_every_n_steps: 4000 # Use if multiple checkpoints wanted model: _target_: ${module}.Model # pl model wrapper lr: 1e-4 # optimizer learning rate lr_beta1: 0.95 # beta1 param for Adam optimizer lr_beta2: 0.999 # beta2 param for Adam optimzer lr_eps: 1e-6 # epsilon for optimizer (to avoid div by 0) lr_weight_decay: 1e-3 # weight decay regularization param ema_beta: 0.995 # EMA model (exponential-moving-average) beta ema_power: 0.7 # EMA model gradiaent norm param model: _target_: audio_diffusion_pytorch.DiffusionModel # Waveform diffusion model net_t: _target_: ${module}.UNetT # The model type used for diffusion (U-Net V0 in this case) in_channels: 2 # U-Net: number of input/output (audio) channels channels: [32, 32, 64, 64, 128, 128, 256, 256] # U-Net: channels at each layer factors: [1, 2, 2, 2, 2, 2, 2, 2] # U-Net: downsampling and upsampling factors at each layer items: [2, 2, 2, 2, 2, 2, 4, 4] # U-Net: number of repeating items at each layer attentions: [0, 0, 0, 0, 0, 1, 1, 1] # U-Net: attention enabled/disabled at each layer attention_heads: 8 # U-Net: number of attention heads per attention item attention_features: 64 # U-Net: number of attention features per attention item # To specify train-valid datasets, datamodule must be reconfigured datamodule: _target_: main.diffusion_module.Datamodule dataset: _target_: audio_data_pytorch.WAVDataset path: ./data/wav_dataset # can overried when calling train.py recursive: True sample_rate: ${sampling_rate} transforms: _target_: audio_data_pytorch.AllTransform crop_size: ${length} # One-shots, so no random crop stereo: True source_rate: ${sampling_rate} target_rate: ${sampling_rate} loudness: -20 # normalize loudness val_split: 0.1 # split data into validation batch_size: ${batch_size} num_workers: ${num_workers} pin_memory: True callbacks: rich_progress_bar: _target_: pytorch_lightning.callbacks.RichProgressBar # _target_: pytorch_lightning.callbacks.TQDMProgressBar # use if RichProgressBar creates issues model_checkpoint: _target_: pytorch_lightning.callbacks.ModelCheckpoint monitor: "valid_loss" # name of the logged metric which determines when model is improving save_top_k: 1 # save k best models (determined by above metric) save_last: True # additionaly always save model from last epoch mode: "min" # can be "max" or "min" verbose: False dirpath: ${logs_dir}/ckpts/${now:%Y-%m-%d-%H-%M-%S} filename: '{epoch:02d}-{valid_loss:.3f}' # every_n_train_steps: ${ckpt_every_n_steps} # Use if multiple checkpoints wanted model_summary: _target_: pytorch_lightning.callbacks.RichModelSummary max_depth: 2 trainer: _target_: pytorch_lightning.Trainer gpus: 0 # Set `1` to train on GPU, `0` to train on CPU only, and `-1` to train on all GPUs, default `0` precision: 16 # Precision used for tensors (`32` offers higher precision, but `16` is used to save memory) min_epochs: 0 # minimum number of epochs max_epochs: -1 # max number of epochs (-1 = infinite run) enable_model_summary: False log_every_n_steps: 1 # Logs training metrics every n steps # limit_val_batches: 10 # Use to limit the number of valid batches run (e.g. 10 stops training at 10 batches) check_val_every_n_epoch: null val_check_interval: ${val_log_every_n_steps} # Validation interval (check valid set and generate audio every n steps) accumulate_grad_batches: ${accumulate_grad_batches} # use to increase batch size on single GPU ================================================ FILE: main/diffusion_module.py ================================================ # This code has been adapted from Flavio Schneider's work with Archinet. # (https://github.com/archinetai/audio-diffusion-pytorch-trainer) from audio_data_pytorch.utils import fractional_random_split from pytorch_lightning.loggers import LoggerCollection, WandbLogger from audio_diffusion_pytorch import UNetV0, VDiffusion, VSampler, LTPlugin import random from typing import Any, List, Optional import plotly.graph_objs as go import pytorch_lightning as pl import torch import torchaudio import wandb from einops import rearrange from ema_pytorch import EMA from pytorch_lightning import Callback, Trainer from torch import Tensor, nn from torch.utils.data import DataLoader """ Model """ # Option to use learned transform to downsample (by stride length) input data (not recommended). # Can reduce computational load, but introduces undesirable high freq artifacts. UNetT_LT = lambda: LTPlugin(UNetV0, num_filters=32, window_length=16, stride=16) UNetT = lambda: UNetV0 # define Unet to be used (from audio_diffusion_pytorch) DiffusionT = VDiffusion # define diffusion method to be used (from audio_diffusion_pytorch) SamplerT = VSampler # define diffusion sampler to be used (from audio_diffusion_pytorch) def dropout(proba: float): return random.random() < proba class Model(pl.LightningModule): def __init__( self, lr: float, lr_beta1: float, lr_beta2: float, lr_eps: float, lr_weight_decay: float, ema_beta: float, ema_power: float, model: nn.Module, ): super().__init__() self.lr = lr self.lr_beta1 = lr_beta1 self.lr_beta2 = lr_beta2 self.lr_eps = lr_eps self.lr_weight_decay = lr_weight_decay self.model = model self.model_ema = EMA(self.model, beta=ema_beta, power=ema_power) @property def device(self): return next(self.model.parameters()).device def configure_optimizers(self): optimizer = torch.optim.AdamW( list(self.model.parameters()), lr=self.lr, betas=(self.lr_beta1, self.lr_beta2), eps=self.lr_eps, weight_decay=self.lr_weight_decay, ) return optimizer def training_step(self, batch, batch_idx): wave = batch loss = self.model(wave) self.log("train_loss", loss, sync_dist=True) # Update EMA model and log decay self.model_ema.update() self.log("ema_decay", self.model_ema.get_current_decay(), sync_dist=True) return loss def validation_step(self, batch, batch_idx): wave = batch loss = self.model_ema(wave) self.log("valid_loss", loss, sync_dist=True) return loss """ Datamodule """ class Datamodule(pl.LightningDataModule): def __init__( self, dataset, *, val_split: float, batch_size: int, num_workers: int, pin_memory: bool = False, **kwargs: int, ) -> None: super().__init__() self.dataset = dataset self.val_split = val_split self.batch_size = batch_size self.num_workers = num_workers self.pin_memory = pin_memory self.data_train: Any = None self.data_val: Any = None def setup(self, stage: Any = None) -> None: split = [1.0 - self.val_split, self.val_split] self.data_train, self.data_val = fractional_random_split(self.dataset, split) def get_dataloader(self, dataset) -> DataLoader: return DataLoader( dataset=dataset, batch_size=self.batch_size, num_workers=self.num_workers, pin_memory=self.pin_memory, shuffle=True, prefetch_factor=2, ) def train_dataloader(self) -> DataLoader: return self.get_dataloader(self.data_train) def val_dataloader(self) -> DataLoader: return self.get_dataloader(self.data_val) """ Callbacks """ def get_wandb_logger(trainer: Trainer) -> Optional[WandbLogger]: """Safely get Weights&Biases logger from Trainer.""" if isinstance(trainer.logger, WandbLogger): return trainer.logger if isinstance(trainer.logger, LoggerCollection): for logger in trainer.logger: if isinstance(logger, WandbLogger): return logger print("WandbLogger not found.") return None def log_wandb_audio_batch( logger: WandbLogger, id: str, samples: Tensor, sampling_rate: int, caption: str = "" ): num_items = samples.shape[0] samples = rearrange(samples, "b c t -> b t c").detach().cpu().numpy() logger.log( { f"sample_{idx}_{id}": wandb.Audio( samples[idx], caption=caption, sample_rate=sampling_rate, ) for idx in range(num_items) } ) def log_wandb_audio_spectrogram( logger: WandbLogger, id: str, samples: Tensor, sampling_rate: int, caption: str = "" ): num_items = samples.shape[0] samples = samples.detach().cpu() transform = torchaudio.transforms.MelSpectrogram( sample_rate=sampling_rate, n_fft=1024, hop_length=512, n_mels=80, center=True, norm="slaney", ) def get_spectrogram_image(x): spectrogram = transform(x[0]) image = torchaudio.functional.amplitude_to_DB(spectrogram, 1.0, 1e-10, 80.0) trace = [go.Heatmap(z=image, colorscale="viridis")] layout = go.Layout( yaxis=dict(title="Mel Bin (Log Frequency)"), xaxis=dict(title="Frame"), title_font_size=10, ) fig = go.Figure(data=trace, layout=layout) return fig logger.log( { f"mel_spectrogram_{idx}_{id}": get_spectrogram_image(samples[idx]) for idx in range(num_items) } ) class SampleLogger(Callback): def __init__( self, num_items: int, channels: int, sampling_rate: int, sampling_steps: List[int], use_ema_model: bool, length: int, ) -> None: self.num_items = num_items self.channels = channels self.sampling_rate = sampling_rate self.sampling_steps = sampling_steps self.use_ema_model = use_ema_model self.log_next = False self.length = length def on_validation_epoch_start(self, trainer, pl_module): self.log_next = True def on_validation_batch_start( self, trainer, pl_module, batch, batch_idx, dataloader_idx ): if self.log_next and trainer.logger: # only log if logger present in config self.log_sample(trainer, pl_module, batch) self.log_next = False @torch.no_grad() def log_sample(self, trainer, pl_module, batch): is_train = pl_module.training if is_train: pl_module.eval() # Get wandb logger wandb_logger = get_wandb_logger(trainer).experiment model = pl_module.model if self.use_ema_model: model = pl_module.model_ema.ema_model # Get noise for diffusion inference noise = torch.randn( (self.num_items, self.channels, self.length), device=pl_module.device ) for steps in self.sampling_steps: samples = model.sample( noise, num_steps=steps, ) log_wandb_audio_batch( logger=wandb_logger, id="sample", samples=samples, sampling_rate=self.sampling_rate, caption=f"Sampled in {steps} steps", ) log_wandb_audio_spectrogram( logger=wandb_logger, id="sample", samples=samples, sampling_rate=self.sampling_rate, caption=f"Sampled in {steps} steps", ) if is_train: pl_module.train() ================================================ FILE: main/utils.py ================================================ import logging import os import warnings from typing import Callable, List, Optional, Sequence import pkg_resources # type: ignore import pytorch_lightning as pl import rich.syntax import rich.tree import torch from omegaconf import DictConfig, OmegaConf from pytorch_lightning import Callback from pytorch_lightning.utilities import rank_zero_only """ Training Utils""" def get_logger(name=__name__) -> logging.Logger: """Initializes multi-GPU-friendly python command line logger.""" logger = logging.getLogger(name) # this ensures all logging levels get marked with the rank zero decorator # otherwise logs would get multiplied for each GPU process in multi-GPU setup for level in ( "debug", "info", "warning", "error", "exception", "fatal", "critical", ): setattr(logger, level, rank_zero_only(getattr(logger, level))) return logger log = get_logger(__name__) def extras(config: DictConfig) -> None: """Applies optional utilities, controlled by config flags. Utilities: - Ignoring python warnings - Rich config printing """ # disable python warnings if if config.get("ignore_warnings"): log.info("Disabling python warnings! ") warnings.filterwarnings("ignore") # pretty print config tree using Rich library if if config.get("print_config"): log.info("Printing config tree with Rich! ") print_config(config, resolve=True) @rank_zero_only def print_config( config: DictConfig, print_order: Sequence[str] = ( "datamodule", "model", "callbacks", "logger", "trainer", ), resolve: bool = True, ) -> None: """Prints content of DictConfig using Rich library and its tree structure. Args: config (DictConfig): Configuration composed by Hydra. print_order (Sequence[str], optional): Determines in what order config components are printed. resolve (bool, optional): Whether to resolve reference fields of DictConfig. """ style = "dim" tree = rich.tree.Tree("CONFIG", style=style, guide_style=style) quee = [] for field in print_order: quee.append(field) if field in config else log.info( f"Field '{field}' not found in config" ) for field in config: if field not in quee: quee.append(field) for field in quee: branch = tree.add(field, style=style, guide_style=style) config_group = config[field] if isinstance(config_group, DictConfig): branch_content = OmegaConf.to_yaml(config_group, resolve=resolve) else: branch_content = str(config_group) branch.add(rich.syntax.Syntax(branch_content, "yaml")) rich.print(tree) with open("config_tree.log", "w") as file: rich.print(tree, file=file) @rank_zero_only def log_hyperparameters( config: DictConfig, model: pl.LightningModule, datamodule: pl.LightningDataModule, trainer: pl.Trainer, callbacks: List[pl.Callback], logger: List[pl.loggers.LightningLoggerBase], ) -> None: """Controls which config parts are saved by Lightning loggers. Additionaly saves: - number of model parameters """ if not trainer.logger: return hparams = {} # choose which parts of hydra config will be saved to loggers hparams["model"] = config["model"] # save number of model parameters hparams["model/params/total"] = sum(p.numel() for p in model.parameters()) hparams["model/params/trainable"] = sum( p.numel() for p in model.parameters() if p.requires_grad ) hparams["model/params/non_trainable"] = sum( p.numel() for p in model.parameters() if not p.requires_grad ) hparams["datamodule"] = config["datamodule"] hparams["trainer"] = config["trainer"] if "seed" in config: hparams["seed"] = config["seed"] if "callbacks" in config: hparams["callbacks"] = config["callbacks"] hparams["pacakges"] = get_packages_list() # send hparams to all loggers trainer.logger.log_hyperparams(hparams) def finish( config: DictConfig, model: pl.LightningModule, datamodule: pl.LightningDataModule, trainer: pl.Trainer, callbacks: List[pl.Callback], logger: List[pl.loggers.LightningLoggerBase], ) -> None: """Makes sure everything closed properly.""" # without this sweeps with wandb logger might crash! for lg in logger: if isinstance(lg, pl.loggers.wandb.WandbLogger): import wandb wandb.finish() def get_packages_list() -> List[str]: return [f"{p.project_name}=={p.version}" for p in pkg_resources.working_set] def retry_if_error(fn: Callable, num_attemps: int = 10): for attempt in range(num_attemps): try: return fn() except: print(f"Retrying, attempt {attempt+1}") pass return fn() class SavePytorchModelAndStopCallback(Callback): def __init__(self, path: str, attribute: Optional[str] = None): self.path = path self.attribute = attribute def on_train_start(self, trainer, pl_module): model, path = pl_module, self.path if self.attribute is not None: assert_message = "provided model attribute not found in pl_module" assert hasattr(pl_module, self.attribute), assert_message model = getattr( pl_module, self.attribute, hasattr(pl_module, self.attribute) ) # Make dir if not existent os.makedirs(os.path.split(path)[0], exist_ok=True) # Save model torch.save(model, path) log.info(f"PyTorch model saved at: {path}") # Stop trainer trainer.should_stop = True ================================================ FILE: saved_models/.gitkeep ================================================ ================================================ FILE: setup/environment.yml ================================================ name: tiny-audio-diffusion dependencies: - python=3.10 - pip - pip: - -r requirements.txt ================================================ FILE: setup/requirements.txt ================================================ torch>=2.0.1 torchaudio>=2.0.2 pytorch-lightning==1.7.7 torchmetrics==0.11.4 python-dotenv hydra-core hydra-colorlog wandb auraloss yt-dlp datasets pyloudnorm einops omegaconf rich plotly librosa transformers eng-to-ipa ema-pytorch py7zr notebook matplotlib ipykernel gradio # k-diffusion # v-diffusion-pytorch audio-diffusion-pytorch==0.1.3 audio-encoders-pytorch audio-data-pytorch quantizer-pytorch difformer-pytorch a-transformers-pytorch ================================================ FILE: train.py ================================================ import os import dotenv import hydra import pytorch_lightning as pl from main import utils from omegaconf import DictConfig, open_dict # import torch # use if direct checkpoint load required (see line 87) # Load environment variables from `.env`. dotenv.load_dotenv(override=True) log = utils.get_logger(__name__) @hydra.main(config_path=".", config_name="config.yaml", version_base=None) def main(config: DictConfig) -> None: # Logs config tree utils.extras(config) # Apply seed for reproducibility pl.seed_everything(config.seed) # Initialize datamodule log.info(f"Instantiating datamodule <{config.datamodule._target_}>.") datamodule = hydra.utils.instantiate(config.datamodule, _convert_="partial") # Initialize model log.info(f"Instantiating model <{config.model._target_}>.") model = hydra.utils.instantiate(config.model, _convert_="partial") # Initialize all callbacks (e.g. checkpoints, early stopping) callbacks = [] # If save is provided add callback that saves and stops, to be used with +ckpt if "save" in config: # Ignore loggers and other callbacks with open_dict(config): config.pop("loggers") config.pop("callbacks") config.trainer.num_sanity_val_steps = 0 attribute, path = config.get("save"), config.get("ckpt_dir") filename = os.path.join(path, f"{attribute}.pt") callbacks += [utils.SavePytorchModelAndStopCallback(filename, attribute)] if "callbacks" in config: for _, cb_conf in config["callbacks"].items(): if "_target_" in cb_conf: log.info(f"Instantiating callback <{cb_conf._target_}>.") callbacks.append(hydra.utils.instantiate(cb_conf, _convert_="partial")) # Initialize loggers (e.g. wandb) loggers = [] if "loggers" in config: for _, lg_conf in config["loggers"].items(): if "_target_" in lg_conf: log.info(f"Instantiating logger <{lg_conf._target_}>.") # Sometimes wandb throws error if slow connection... logger = utils.retry_if_error( lambda: hydra.utils.instantiate(lg_conf, _convert_="partial") ) loggers.append(logger) # Initialize trainer log.info(f"Instantiating trainer <{config.trainer._target_}>.") trainer = hydra.utils.instantiate( config.trainer, callbacks=callbacks, logger=loggers, _convert_="partial" ) # Send some parameters from config to all lightning loggers log.info("Logging hyperparameters!") utils.log_hyperparameters( config=config, model=model, datamodule=datamodule, trainer=trainer, callbacks=callbacks, logger=loggers, ) # Train with checkpoint if present, otherwise from start if "ckpt" in config: ckpt = config.get("ckpt") log.info(f"Starting training from {ckpt}") trainer.fit(model=model, datamodule=datamodule, ckpt_path=ckpt) # # Alternative model load method # # Use if loading from checkpoint with pl trainer causes GPU memory spike (CUDA out of memory). # checkpoint = torch.load(config.get("ckpt"), map_location='cpu')['state_dict'] # model.load_state_dict(checkpoint) # trainer.fit(model=model, datamodule=datamodule) else: log.info("Starting training.") trainer.fit(model=model, datamodule=datamodule) # Make sure everything closed properly log.info("Finalizing!") utils.finish( config=config, model=model, datamodule=datamodule, trainer=trainer, callbacks=callbacks, logger=loggers, ) # Print path to best checkpoint if ( not config.trainer.get("fast_dev_run") and config.get("train") and not config.get("save") ): log.info(f"Best model ckpt at {trainer.checkpoint_callback.best_model_path}") if __name__ == "__main__": main()