[
  {
    "path": ".gitattributes",
    "content": "# Python .gitattributes (modified from: https://github.com/alexkaratarakis/gitattributes)\n\n# Source files\n# ============\n*.pxd    text diff=python\n*.py     text diff=python\n*.py3    text diff=python\n*.pyw    text diff=python\n*.pyx    text diff=python\n*.pyz    text diff=python\n*.pyi    text diff=python\n\n# Binary files\n# ============\n*.db     binary\n*.p      binary\n*.pkl    binary\n*.pickle binary\n*.pyc    binary export-ignore\n*.pyo    binary export-ignore\n*.pyd    binary\n\n# Python files\n# ============\n*.py     linguist-language=Python\n\n# Jupyter Notebook\n# ============\n*.ipynb  linguist-language=Jupyter Notebook\n*.ipynb  text eol=lf"
  },
  {
    "path": ".gitignore",
    "content": "# Custom ignore\n__pycache__\n.mypy_cache\n.env\n.DS_Store\n.DS_Store/\n.hydra\nvenv/\nlogs/\n.vscode/\n*Zone.Identifier\nkicks/\nsnares/\nhihats/\nclaps/\nsnaps/\ncymbals/\nrides/\ntoms/\npercussion/\narchive/\nignore/\nvideo_samples/\n\n# Python Template\n# Byte-compiled / optimized / DLL files\n__pycache__/\n*.py[cod]\n*$py.class\n\n# C extensions\n*.so\n\n# Distribution / packaging\n.Python\nbuild/\ndevelop-eggs/\ndist/\ndownloads/\neggs/\n.eggs/\nlib/\nlib64/\nparts/\nsdist/\nvar/\nwheels/\nshare/python-wheels/\n*.egg-info/\n.installed.cfg\n*.egg\nMANIFEST\n\n# PyInstaller\n#  Usually these files are written by a python script from a template\n#  before PyInstaller builds the exe, so as to inject date/other infos into it.\n*.manifest\n*.spec\n\n# Installer logs\npip-log.txt\npip-delete-this-directory.txt\n\n# Unit test / coverage reports\nhtmlcov/\n.tox/\n.nox/\n.coverage\n.coverage.*\n.cache\nnosetests.xml\ncoverage.xml\n*.cover\n*.py,cover\n.hypothesis/\n.pytest_cache/\ncover/\n\n# Translations\n*.mo\n*.pot\n\n# Django stuff:\n*.log\nlocal_settings.py\ndb.sqlite3\ndb.sqlite3-journal\n\n# Flask stuff:\ninstance/\n.webassets-cache\n\n# Scrapy stuff:\n.scrapy\n\n# Sphinx documentation\ndocs/_build/\n\n# PyBuilder\n.pybuilder/\ntarget/\n\n# Jupyter Notebook\n.ipynb_checkpoints\n\n# IPython\nprofile_default/\nipython_config.py\n\n# pyenv\n#   For a library or package, you might want to ignore these files since the code is\n#   intended to run in multiple environments; otherwise, check them in:\n# .python-version\n\n# pipenv\n#   According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.\n#   However, in case of collaboration, if having platform-specific dependencies or dependencies\n#   having no cross-platform support, pipenv may install dependencies that don't work, or not\n#   install all needed dependencies.\n#Pipfile.lock\n\n# poetry\n#   Similar to Pipfile.lock, it is generally recommended to include 
poetry.lock in version control.\n#   This is especially recommended for binary packages to ensure reproducibility, and is more\n#   commonly ignored for libraries.\n#   https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control\n#poetry.lock\n\n# pdm\n#   Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.\n#pdm.lock\n#   pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it\n#   in version control.\n#   https://pdm.fming.dev/#use-with-ide\n.pdm.toml\n\n# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm\n__pypackages__/\n\n# Celery stuff\ncelerybeat-schedule\ncelerybeat.pid\n\n# SageMath parsed files\n*.sage.py\n\n# Environments\n.env\n.venv\nenv/\nvenv/\nENV/\nenv.bak/\nvenv.bak/\n\n# Spyder project settings\n.spyderproject\n.spyproject\n\n# Rope project settings\n.ropeproject\n\n# mkdocs documentation\n/site\n\n# mypy\n.mypy_cache/\n.dmypy.json\ndmypy.json\n\n# Pyre type checker\n.pyre/\n\n# pytype static type analyzer\n.pytype/\n\n# Cython debug symbols\ncython_debug/\n"
  },
  {
    "path": "Inference.ipynb",
    "content": "{\n \"cells\": [\n  {\n   \"attachments\": {},\n   \"cell_type\": \"markdown\",\n   \"id\": \"1fc06181\",\n   \"metadata\": {},\n   \"source\": [\n    \"# Inference Notebook\\n\",\n    \"- This notebook serves to import trained models and generate new samples.\\n\",\n    \"- You should only need to edits the [Checkpoint & Configs](#Checkpoint-\\\\&-Configs) and [Define Sample Parameters](#Define-Sample-Parameters) cells.\\n\",\n    \"- Currently this notebook only offers unconditional generation, but I plan to include more features in the future.\\n\",\n    \"- Have fun creating new sounds!\\n\",\n    \"\\n\",\n    \"*(NOTE: Don't do \\\"run all\\\" in Jupyter. For some reason it doesn't output anything when I used this option. It may work in other environments, but just a heads up!)*\"\n   ]\n  },\n  {\n   \"attachments\": {},\n   \"cell_type\": \"markdown\",\n   \"id\": \"7b67d821\",\n   \"metadata\": {},\n   \"source\": [\n    \"#### Imports\\n\",\n    \"Import necessary libraries to run the notebook\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"id\": \"3dc381cc\",\n   \"metadata\": {\n    \"scrolled\": true\n   },\n   \"outputs\": [],\n   \"source\": [\n    \"# Imports\\n\",\n    \"import matplotlib.pyplot as plt\\n\",\n    \"import torch\\n\",\n    \"import torchaudio\\n\",\n    \"from torch import nn\\n\",\n    \"import pytorch_lightning as pl\\n\",\n    \"from ema_pytorch import EMA\\n\",\n    \"import IPython.display as ipd\\n\",\n    \"import yaml\\n\",\n    \"from audio_diffusion_pytorch import DiffusionModel, UNetV0, VDiffusion, VSampler\\n\",\n    \"from diffusion import sampling, utils\"\n   ]\n  },\n  {\n   \"attachments\": {},\n   \"cell_type\": \"markdown\",\n   \"id\": \"49d620cb\",\n   \"metadata\": {},\n   \"source\": [\n    \"### Checkpoint & Configs\\n\",\n    \"- Replace these paths with the path to your model's checkpoint and configs.\\n\",\n    \"- Pre-trained models are availlable to 
download on Hugging Face.\\n\",\n    \"\\n\",\n    \"|Model|Link|\\n\",\n    \"|---|---|\\n\",\n    \"|Kicks|[crlandsc/tiny-audio-diffusion-kicks](https://huggingface.co/crlandsc/tiny-audio-diffusion-kicks)|\\n\",\n    \"|Snares|[crlandsc/tiny-audio-diffusion-snares](https://huggingface.co/crlandsc/tiny-audio-diffusion-snares)|\\n\",\n    \"|Hi-hats|[crlandsc/tiny-audio-diffusion-hihats](https://huggingface.co/crlandsc/tiny-audio-diffusion-hihats)|\\n\",\n    \"|Percussion (all drum types)|[crlandsc/tiny-audio-diffusion-percussion](https://huggingface.co/crlandsc/tiny-audio-diffusion-percussion)|\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"id\": \"5037ead6\",\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"# Load model checkpoint\\n\",\n    \"ckpt_path = \\\"./saved_models/kicks/kicks_v7.ckpt\\\" # path to model checkpoint\\n\",\n    \"config_path = \\\"./saved_models/kicks/config.yaml\\\" # path to model config\"\n   ]\n  },\n  {\n   \"attachments\": {},\n   \"cell_type\": \"markdown\",\n   \"id\": \"f2f61804\",\n   \"metadata\": {},\n   \"source\": [\n    \"### Functions & Models\\n\",\n    \"- Functions and models definitions\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"id\": \"46eec999\",\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"# Load configs\\n\",\n    \"with open(config_path, 'r') as file:\\n\",\n    \"    config = yaml.safe_load(file)\\n\",\n    \"pl_configs = config['model']\\n\",\n    \"model_configs = config['model']['model']\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"id\": \"f4797122\",\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"def plot_mel_spectrogram(sample):\\n\",\n    \"    transform = torchaudio.transforms.MelSpectrogram(\\n\",\n    \"        sample_rate=sr,\\n\",\n    \"        n_fft=1024,\\n\",\n    \"        hop_length=512,\\n\",\n    \"        
n_mels=80,\\n\",\n    \"        center=True,\\n\",\n    \"        norm=\\\"slaney\\\",\\n\",\n    \"    )\\n\",\n    \"\\n\",\n    \"    spectrogram = transform(torch.mean(sample, dim=0)) # downmix and cal spectrogram\\n\",\n    \"    spectrogram = torchaudio.functional.amplitude_to_DB(spectrogram, 1.0, 1e-10, 80.0)\\n\",\n    \"\\n\",\n    \"    # Plot the Mel spectrogram\\n\",\n    \"    fig = plt.figure(figsize=(7, 4))\\n\",\n    \"    plt.imshow(spectrogram, aspect='auto', origin='lower')\\n\",\n    \"    plt.colorbar(format='%+2.0f dB')\\n\",\n    \"    plt.xlabel('Frame')\\n\",\n    \"    plt.ylabel('Mel Bin')\\n\",\n    \"    plt.title('Mel Spectrogram')\\n\",\n    \"    plt.tight_layout()\\n\",\n    \"    \\n\",\n    \"    return fig\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"id\": \"e246c0e2\",\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"# Define PyTorch Lightning model\\n\",\n    \"class Model(pl.LightningModule):\\n\",\n    \"    def __init__(\\n\",\n    \"        self,\\n\",\n    \"        lr: float,\\n\",\n    \"        lr_beta1: float,\\n\",\n    \"        lr_beta2: float,\\n\",\n    \"        lr_eps: float,\\n\",\n    \"        lr_weight_decay: float,\\n\",\n    \"        ema_beta: float,\\n\",\n    \"        ema_power: float,\\n\",\n    \"        model: nn.Module,\\n\",\n    \"    ):\\n\",\n    \"        super().__init__()\\n\",\n    \"        self.lr = lr\\n\",\n    \"        self.lr_beta1 = lr_beta1\\n\",\n    \"        self.lr_beta2 = lr_beta2\\n\",\n    \"        self.lr_eps = lr_eps\\n\",\n    \"        self.lr_weight_decay = lr_weight_decay\\n\",\n    \"        self.model = model\\n\",\n    \"        self.model_ema = EMA(self.model, beta=ema_beta, power=ema_power)\"\n   ]\n  },\n  {\n   \"attachments\": {},\n   \"cell_type\": \"markdown\",\n   \"id\": \"5b2aecab\",\n   \"metadata\": {},\n   \"source\": [\n    \"### Instantiate model\\n\",\n    \"*NOTE: This model setup needs to 
exactly match the model that was trained*\\n\",\n    \"\\n\",\n    \"- This cell instantiates the model (no weights) using the config.yaml file. This structure is critical to make sure that the model weights can be loaded correctly.\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"id\": \"d626c6e7\",\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"# Instantiate model (must match model that was trained)\\n\",\n    \"\\n\",\n    \"# Diffusion model\\n\",\n    \"model = DiffusionModel(\\n\",\n    \"    net_t=UNetV0, # The model type used for diffusion (U-Net V0 in this case)\\n\",\n    \"    in_channels=model_configs['in_channels'], # U-Net: number of input/output (audio) channels\\n\",\n    \"    channels=model_configs['channels'], # U-Net: channels at each layer\\n\",\n    \"    factors=model_configs['factors'], # U-Net: downsampling and upsampling factors at each layer\\n\",\n    \"    items=model_configs['items'], # U-Net: number of repeating items at each layer\\n\",\n    \"    attentions=model_configs['attentions'], # U-Net: attention enabled/disabled at each layer\\n\",\n    \"    attention_heads=model_configs['attention_heads'], # U-Net: number of attention heads per attention item\\n\",\n    \"    attention_features=model_configs['attention_features'], # U-Net: number of attention features per attention item\\n\",\n    \"    diffusion_t=VDiffusion, # The diffusion method used\\n\",\n    \"    sampler_t=VSampler # The diffusion sampler used\\n\",\n    \")\\n\",\n    \"\\n\",\n    \"# pl model\\n\",\n    \"model = Model(\\n\",\n    \"    lr=pl_configs['lr'],\\n\",\n    \"    lr_beta1=pl_configs['lr_beta1'],\\n\",\n    \"    lr_beta2=pl_configs['lr_beta2'],\\n\",\n    \"    lr_eps=pl_configs['lr_eps'],\\n\",\n    \"    lr_weight_decay=pl_configs['lr_weight_decay'],\\n\",\n    \"    ema_beta=pl_configs['ema_beta'],\\n\",\n    \"    ema_power=pl_configs['ema_power'],\\n\",\n    \"    model=model\\n\",\n 
   \")\"\n   ]\n  },\n  {\n   \"attachments\": {},\n   \"cell_type\": \"markdown\",\n   \"id\": \"c2d2f702\",\n   \"metadata\": {},\n   \"source\": [\n    \"### Check if GPU available\\n\",\n    \"- This checks to see if a CUDe capable GPU is available to utilize. If so, the model is assigned to the GPU. If not, the model simply remains on the CPU.\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"id\": \"9ce84487\",\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"# Assign to GPU\\n\",\n    \"if torch.cuda.is_available():\\n\",\n    \"    model = model.to('cuda')\\n\",\n    \"    print(f\\\"Device: {model.device}\\\")\"\n   ]\n  },\n  {\n   \"attachments\": {},\n   \"cell_type\": \"markdown\",\n   \"id\": \"825d96d9\",\n   \"metadata\": {},\n   \"source\": [\n    \"### Load model\\n\",\n    \"- This cell loads the checkpoint weights into the model. It should return `\\\"<All keys matched successfully>\\\"` if successfully loaded.\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"id\": \"845993e4\",\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"# Load model checkpoint\\n\",\n    \"checkpoint = torch.load(ckpt_path, map_location='cpu')['state_dict']\\n\",\n    \"model.load_state_dict(checkpoint) # should output \\\"<All keys matched successfully>\\\"\"\n   ]\n  },\n  {\n   \"attachments\": {},\n   \"cell_type\": \"markdown\",\n   \"id\": \"aaaa996f\",\n   \"metadata\": {},\n   \"source\": [\n    \"## Unconditional Sample Generation\\n\",\n    \"Generate new sounds from noise with no conditioning.\"\n   ]\n  },\n  {\n   \"attachments\": {},\n   \"cell_type\": \"markdown\",\n   \"id\": \"c93dc280\",\n   \"metadata\": {},\n   \"source\": [\n    \"#### Define Sample Parameters\\n\",\n    \"- sample_length: how long to make the output (measured in samples). 
Recommended $2^{15}=32768$ (~0.75 sec), as that is what the model was trained on.\\n\",\n    \"- sr (sample rate): sampling rate to output. Recommended industry standard 44.1kHz (44100Hz).\\n\",\n    \"- num_samples: number of new samples that will be generated.\\n\",\n    \"- num_steps: number of diffusion steps - tradeoff inference speed for sample quality (10-100 is a good range).\\n\",\n    \"    - 10+ steps - quick generation, alright samples but noticeable high-freq hiss.\\n\",\n    \"    - 50+ steps - moderate generation speed, good tradeoff for speed and quality (less high-freq hiss).\\n\",\n    \"    - 100+ steps - slow generation speed, high quality samples.\\n\",\n    \"\\n\",\n    \"Have fun playing around with these parameters! Note that sometimes the model outputs some wild things. This is likely due to the small size of the model as well as the limited training data. Larger models and/or larger and more diverse datasets would improve this.\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"id\": \"6f0438d4\",\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"# Define diffusion parameters\\n\",\n    \"sample_length = 2**15 # 32768 samples @ 44100 = .75 sec\\n\",\n    \"sr = 44100\\n\",\n    \"num_samples = 3 # number of samples to generate\\n\",\n    \"num_steps = 50 # number of diffusion steps, tradeoff inference speed for sample quality (10-100 is a good range)\"\n   ]\n  },\n  {\n   \"attachments\": {},\n   \"cell_type\": \"markdown\",\n   \"id\": \"4952b429\",\n   \"metadata\": {},\n   \"source\": [\n    \"#### Generate samples\\n\",\n    \"Run the following cell to generate samples based on previously defined parameters\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"id\": \"bf88849c\",\n   \"metadata\": {\n    \"scrolled\": false\n   },\n   \"outputs\": [],\n   \"source\": [\n    \"with torch.no_grad():\\n\",\n    \"    all_samples = torch.zeros(2, 0)\\n\",\n  
  \"    for i in range(num_samples):\\n\",\n    \"        noise = torch.randn((1, 2, sample_length), device=model.device) # [batch_size, in_channels, length]\\n\",\n    \"        generated_sample = model.model_ema.ema_model.sample(noise, num_steps=num_steps).squeeze(0).cpu() # Suggested num_steps 10-100\\n\",\n    \"\\n\",\n    \"        print(f\\\"Generated Sample {i+1}\\\")\\n\",\n    \"        display(ipd.Audio(generated_sample, rate=sr))\\n\",\n    \"        \\n\",\n    \"        # concatenate all samples:\\n\",\n    \"        all_samples = torch.concat((all_samples, generated_sample), dim=1)\\n\",\n    \"        \\n\",\n    \"        fig = plot_mel_spectrogram(generated_sample)\\n\",\n    \"        plt.title(f\\\"Mel Spectrogram (Sample {i+1})\\\")\\n\",\n    \"        plt.show()\\n\",\n    \"        \\n\",\n    \"        torch.cuda.empty_cache()\"\n   ]\n  },\n  {\n   \"attachments\": {},\n   \"cell_type\": \"markdown\",\n   \"id\": \"c18bda2c\",\n   \"metadata\": {},\n   \"source\": [\n    \"#### Combine all samples\\n\",\n    \"- Option to combine all samples into a single sample\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"id\": \"f8003140\",\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"# Optional concatenate all samples\\n\",\n    \"print(f\\\"All Samples\\\")\\n\",\n    \"display(ipd.Audio(all_samples, rate=sr))\\n\",\n    \"fig = plot_mel_spectrogram(all_samples)\\n\",\n    \"plt.title(f\\\"Mel Spectrogram)\\\")\\n\",\n    \"plt.show()\"\n   ]\n  },\n  {\n   \"attachments\": {},\n   \"cell_type\": \"markdown\",\n   \"id\": \"9b5cc6ca\",\n   \"metadata\": {},\n   \"source\": [\n    \"## Conditional \\\"Style-Transfer\\\" Generation\\n\",\n    \"Generate new sounds conditioned on input audio.\\n\"\n   ]\n  },\n  {\n   \"attachments\": {},\n   \"cell_type\": \"markdown\",\n   \"id\": \"f122e399\",\n   \"metadata\": {},\n   \"source\": [\n    \"#### Define Sample Parameters\\n\",\n    \"- 
audio_file_path: Path to audio file for conditioning the model.\\n\",\n    \"- sample_with_noise: Option to play back the conditioning sample with noise added so you can listen to it, or suppress it.\\n\",\n    \"- trim_sample: Option to trim/pad sample if it is too long/short.\\n\",\n    \"- sample_length: how long to make the output (measured in samples). Recommended $2^{15}=32768$ (~0.75 sec), as that is what the model was trained on.\\n\",\n    \"- sr (sample rate): sampling rate to output. Recommended industry standard 44.1kHz (44100Hz).\\n\",\n    \"- num_samples: number of new samples that will be generated.\\n\",\n    \"- noise_level: The amount of noise to be added to the input sample.\\n\",\n    \"- num_steps: number of diffusion steps - tradeoff inference speed for sample quality.\\n\",\n    \"  - The number of steps for conditional diffusion varies more compared to unconditional diffusion. For example, if you input a transient sound (like a snare hit) and want to transfer it to the `kicks` model, then you may not want to add any noise and keep the steps below 10 for an interesting sound. But, if you want to transfer something like a guitar to the percussion model, you may want to add some more noise and increase the number of steps.\\n\",\n    \"\\n\",\n    \"*NOTE:* The less noise that is added to a sample, the less varied the outputs will be. For example, if you add 0 noise to a sample and generate it 3 times, it will produce the exact same thing 3 times (because the input remains consistent). 
As you increase the noise added, increasing the variation of the inputs, the outputs will vary more widely as well.\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"id\": \"38a62339\",\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"# Define diffusion paramters\\n\",\n    \"audio_file_path = \\\"samples/snare1.wav\\\"\\n\",\n    \"\\n\",\n    \"# Listen to noised sample\\n\",\n    \"sample_with_noise = False # True to listen to sample + noise, false to not output\\n\",\n    \"\\n\",\n    \"# If sample too long\\n\",\n    \"trim_sample = False # True - if sample too long / False does not trim\\n\",\n    \"sample_length = 2**15 # NA\\n\",\n    \"\\n\",\n    \"sr = 44100 # Sampling rate\\n\",\n    \"num_samples = 1 # number of samples to generate\\n\",\n    \"noise_level = 0 # between 0 and 1\\n\",\n    \"num_steps = 6 # number of diffusion steps, tradeoff inference speed for sample quality (10-100 is a good range)\"\n   ]\n  },\n  {\n   \"attachments\": {},\n   \"cell_type\": \"markdown\",\n   \"id\": \"15298646\",\n   \"metadata\": {},\n   \"source\": [\n    \"#### Generate samples\\n\",\n    \"Run the following cell to generate samples based on previously defined parameters\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"id\": \"5bf5203d\",\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"# Generate samples\\n\",\n    \"with torch.no_grad():\\n\",\n    \"\\n\",\n    \"    # load audio sample\\n\",\n    \"    audio_sample = torchaudio.load(audio_file_path)[0].unsqueeze(0).to(model.device) # unsqueeze for correct tensor shape\\n\",\n    \"\\n\",\n    \"    # Trim audio\\n\",\n    \"    if trim_sample:\\n\",\n    \"        og_shape = audio_sample.shape\\n\",\n    \"        if sample_length < og_shape[2]:\\n\",\n    \"            audio_sample = audio_sample[:,:,:sample_length]\\n\",\n    \"        elif sample_length > og_shape[2]:\\n\",\n    \"            # 
Pad tensor with zeros to match sample length\\n\",\n    \"            audio_sample = torch.concat((audio_sample, torch.zeros(og_shape[0], og_shape[1], sample_length - og_shape[2]).to(model.device)), dim=2)\\n\",\n    \"\\n\",\n    \"\\n\",\n    \"    original_audio = audio_sample.squeeze(0).squeeze(0).cpu()\\n\",\n    \"\\n\",\n    \"    # Display original audio sample\\n\",\n    \"    print(f\\\"Original Sample\\\")\\n\",\n    \"    display(ipd.Audio(original_audio, rate=sr))\\n\",\n    \"\\n\",\n    \"    # Plot original audio\\n\",\n    \"    fig = plot_mel_spectrogram(original_audio)\\n\",\n    \"    plt.title(f\\\"Mel Spectrogram (Original Sample)\\\")\\n\",\n    \"    plt.show()\\n\",\n    \"\\n\",\n    \"\\n\",\n    \"    # Display original audio sample + noise\\n\",\n    \"    if sample_with_noise:\\n\",\n    \"        noise = torch.randn_like(audio_sample, device=model.device) * noise_level # combine input signal and noise\\n\",\n    \"        noised_sample = (audio_sample + noise).squeeze(0).cpu() # normalize?\\n\",\n    \"        print(f\\\"Original Noised Sample\\\")\\n\",\n    \"        display(ipd.Audio(noised_sample, rate=sr))\\n\",\n    \"\\n\",\n    \"        # Plot original audio + noise\\n\",\n    \"        fig = plot_mel_spectrogram(noised_sample)\\n\",\n    \"        plt.title(f\\\"Mel Spectrogram (Noised Sample)\\\")\\n\",\n    \"        plt.show()\\n\",\n    \"\\n\",\n    \"\\n\",\n    \"    all_samples = torch.zeros(2, 0)\\n\",\n    \"    for i in range(num_samples):\\n\",\n    \"        noise = torch.randn_like(audio_sample, device=model.device) * noise_level # combine input signal and noise\\n\",\n    \"        audio = audio_sample + noise # normalize?\\n\",\n    \"        generated_sample = model.model_ema.ema_model.sample(audio, num_steps=num_steps).squeeze(0).cpu()\\n\",\n    \"\\n\",\n    \"        print(f\\\"Generated Sample {i+1}\\\")\\n\",\n    \"        display(ipd.Audio(generated_sample, rate=sr))\\n\",\n    \"        \\n\",\n    
\"        # concatenate all samples:\\n\",\n    \"        all_samples = torch.concat((all_samples, generated_sample), dim=1)\\n\",\n    \"        \\n\",\n    \"        fig = plot_mel_spectrogram(generated_sample)\\n\",\n    \"        plt.title(f\\\"Mel Spectrogram (Sample {i+1})\\\")\\n\",\n    \"        plt.show()\\n\",\n    \"        \\n\",\n    \"        torch.cuda.empty_cache()\"\n   ]\n  },\n  {\n   \"attachments\": {},\n   \"cell_type\": \"markdown\",\n   \"id\": \"4762cc6b\",\n   \"metadata\": {},\n   \"source\": [\n    \"#### Combine all samples\\n\",\n    \"- Option to combine all samples into a single sample\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"id\": \"5501ce6a\",\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"# Optional concatenate all samples\\n\",\n    \"print(f\\\"All Samples\\\")\\n\",\n    \"display(ipd.Audio(all_samples, rate=sr))\\n\",\n    \"fig = plot_mel_spectrogram(all_samples)\\n\",\n    \"plt.title(f\\\"Mel Spectrogram)\\\")\\n\",\n    \"plt.show()\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"id\": \"0ef83b4b\",\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"# TODO: Add normalization?\\n\",\n    \"# TODO: Add other smapling methods (currently only DDIM)\\n\",\n    \"# TODO: clean cell (make functions)\"\n   ]\n  }\n ],\n \"metadata\": {\n  \"kernelspec\": {\n   \"display_name\": \"tiny-audio-diffusion (Python 3.10)\",\n   \"language\": \"python\",\n   \"name\": \"tiny-audio-diffusion\"\n  },\n  \"language_info\": {\n   \"codemirror_mode\": {\n    \"name\": \"ipython\",\n    \"version\": 3\n   },\n   \"file_extension\": \".py\",\n   \"mimetype\": \"text/x-python\",\n   \"name\": \"python\",\n   \"nbconvert_exporter\": \"python\",\n   \"pygments_lexer\": \"ipython3\",\n   \"version\": \"3.10.11\"\n  }\n },\n \"nbformat\": 4,\n \"nbformat_minor\": 5\n}\n"
  },
  {
    "path": "LICENSE",
    "content": "MIT License\n\nCopyright (c) 2023 Christopher Landschoot\nCopyright (c) 2022 archinet.ai\n\nPermission is hereby granted, free of charge, to any person obtaining a copy\nof this software and associated documentation files (the \"Software\"), to deal\nin the Software without restriction, including without limitation the rights\nto use, copy, modify, merge, publish, distribute, sublicense, and/or sell\ncopies of the Software, and to permit persons to whom the Software is\nfurnished to do so, subject to the following conditions:\n\nThe above copyright notice and this permission notice shall be included in all\ncopies or substantial portions of the Software.\n\nTHE SOFTWARE IS PROVIDED \"AS IS\", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR\nIMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,\nFITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE\nAUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER\nLIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,\nOUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE\nSOFTWARE.\n"
  },
  {
    "path": "README.md",
    "content": "<div align=\"center\">\n  <h1 style=\"font-size: 36px;\">Tiny Audio Diffusion</h1>\n  <img src=\"./images/tiny-audio-diffusion.png\" width=\"250px\" alt=\"Tiny Audio Diffusion Logo\" />\n</div>\n<br>\n\n[![Hugging Face Spaces Badge](https://img.shields.io/badge/%F0%9F%A4%97_Spaces_Demo-blue)](https://huggingface.co/spaces/crlandsc/tiny-audio-diffusion) [![YouTube Tutorial Badge](https://img.shields.io/badge/Repo_Tutorial-red?logo=YouTube)](https://youtu.be/m6Eh2srtTro) [![Towards Data Science Badge](https://img.shields.io/badge/Towards_Data_Science-red?logo=Medium&color=black)](https://medium.com/towards-data-science/tiny-audio-diffusion-ddc19e90af9b) [![GitHub License](https://img.shields.io/github/license/crlandsc/tiny-audio-diffusion)](https://github.com/crlandsc/tiny-audio-diffusion/blob/main/LICENSE) [![GitHub Repo stars](https://img.shields.io/github/stars/crlandsc/tiny-audio-diffusion?color=gold)](https://github.com/crlandsc/tiny-audio-diffusion/stargazers) [![GitHub forks](https://img.shields.io/github/forks/crlandsc/tiny-audio-diffusion?color=green)](https://github.com/crlandsc/tiny-audio-diffusion/forks)\n\nThis is a repository for generating short audio samples and training waveform diffusion models on a consumer-grade GPU with less than 2GB VRAM.\n\n## Motivation\n\nThe purpose of this project is to provide access to stereo high-resolution (44.1kHz) conditional and unconditional audio waveform (1D U-Net) diffusion code for those interested in exploration but who have limited resources. There are many methods for audio generation on low-level hardware, but less so specifically for waveform-based diffusion.\n\nThe repository is built heavily adapting code from Archinet's [audio-diffusion-pytorch](https://github.com/archinetai/audio-diffusion-pytorch) libary. 
A huge thank you to [Flavio Schneider](https://github.com/flavioschneider) for his incredible open-source work in this field!\n\n\n## Background\n\nDirect waveform diffusion is inherently computationally intensive. For example, an audio sample with the industry standard 44.1kHz sampling rate requires 44,100 samples for just 1 second of audio. Now multiply that by 2 for a stereo file. However, it has a significant advantage over many methods that reduce audio into spectrograms or downsample - the network retains and learns from *phase* information. Phase is challenging to represent on its own in visual methods, such as spectrograms, as it appears similar to that of random noise. Because of this, many generative methods discard phase information and then implement ways of estimating and regenerating it. However, it plays a key role in defining the timbral qualities of sounds and should not be dispensed with so easily.\n\nWaveform diffusion is able to retain this important feature as it does not perform any transforms on the audio before feeding it into the network. This is how humans perceive sounds, with both amplitude and phase information bundled together in a single signal. As mentioned previously, this comes at the expense of computational requirements and is often reserved for training on a cluster of GPUs with high speeds and lots of memory. Because of this, it is hard to begin to experiment with waveform diffusion with limited resources.\n\nThis repository seeks to offer some base code to those looking to experiment with and learn more about waveform diffusion on their own computer without having to purchase cloud resources or upgrade hardware. This goes for not only *inference*, but *training* your own models as well!\n\nTo make this feasible, however, there must be a tradeoff of quality, speed, and sample length. 
Because of this, I have focused on training base models for one-shot drum samples - as they are inherently short in sample length.\n\nThe current configuration is set up to be able to train ~0.75 second stereo samples at 44.1kHz, allowing for the generation of high-quality one-shot audio samples. The network configuration can be adjusted to improve the resolution, sample rate, training and inference speed, sample length, etc. but, of course, more hardware resources will be required.\n\nOther methods of diffusion, such as diffusion in the latent space ([Stable Diffusion's](https://stability.ai/stablediffusion) secret sauce), compared to this repo's raw waveform diffusion can offer an improvement and other tradeoffs between quality, memory requirements, speed, etc. I recommend this repo to remain up-to-date with the latest research in generative audio: https://github.com/archinetai/audio-ai-timeline\n\nAlso recommended is [Harmonai's](https://www.harmonai.org/) community project, [Dance Diffusion](https://github.com/Harmonai-org/sample-generator), which implements similar functionality to this repo on a larger scale with several pre-trained models. [Colab notebook](https://colab.research.google.com/github/Harmonai-org/sample-generator/blob/main/Dance_Diffusion.ipynb) available.\n\n**April 2024 update:**\n\nSome additional useful generative audio tools/repos:\n- [Stable Audio Tools](https://github.com/Stability-AI/stable-audio-tools) (used in [Stable Audio](https://www.stableaudio.com/)) - Useful audio tools for building and training models.\n- [audiocraft](https://github.com/facebookresearch/audiocraft) (used in [MusicGen](https://audiocraft.metademolab.com/musicgen.html) & [AudioGen](https://audiocraft.metademolab.com/audiogen.html)) - Useful audio tools for building and training models.\n- [audiomentations](https://github.com/iver56/audiomentations) - Good library for implementing audio augmentations on CPU for training. 
See [torch-audiomentations](https://github.com/asteroid-team/torch-audiomentations) for a GPU implementation.\n\n---\n\n## Setup\n\nFollow these steps to set up an environment for both generating audio samples and training models.\n\n*NOTE:* To use this repo with a GPU, you must have a CUDA-capable GPU and have the CUDA toolkit installed specific to your system (ex. Linux, x86_64, WSL-Ubuntu). More information can be found [here](https://developer.nvidia.com/cuda-toolkit).\n\n#### 1. Create a Virtual Environment:\n\nEnsure that [Anaconda (or Miniconda)](https://docs.anaconda.com/free/anaconda/install/index.html) is installed and activated. From the command line, `cd` into the [`setup/`](setup/) folder and run the following lines:\n```bash\nconda env create -f environment.yml\nconda activate tiny-audio-diffusion\n```\n\nThis will create and activate a conda environment from the [`setup/environment.yml`](setup/environment.yml) file and install the dependencies in [`setup/requirements.txt`](setup/requirements.txt).\n\n#### 2. Install Python Kernel For Jupyter Notebook\n\nRun the following line to create a kernel for the current environment to run the inference notebook.\n\n```bash\npython -m ipykernel install --user --name tiny-audio-diffusion --display-name \"tiny-audio-diffusion (Python 3.10)\"\n```\n\n#### 3. 
Define Environment Variables\n\nRename [`.env.tmp`](.env.tmp) to `.env` and replace the entries with your own variables (example values are random).\n\n```bash\nDIR_LOGS=/logs\nDIR_DATA=/data\n\n# Required if using Weights & Biases (W&B) logger\nWANDB_PROJECT=tiny_drum_diffusion # Custom W&B name for current project\nWANDB_ENTITY=johnsmith # W&B username\nWANDB_API_KEY=a21dzbqlybbzccqla4txa21dzbqlybbzccqla4tx # W&B API key\n```\n\n*NOTE:* Sign up for a [Weights & Biases](https://wandb.ai/site) account to log audio samples, spectrograms, and other metrics while training (it's free!).\n\nW&B logging example for this repo [here](https://wandb.ai/crlandsc/unconditional-drum-diffusion?workspace=user-crlandsc).\n\n---\n\n## Pre-trained Models\n\nPretrained models can be found on Hugging Face (each model contains a `.ckpt` and `.yaml` file):\n\n|Model|Link|\n|---|---|\n|Kicks|[crlandsc/tiny-audio-diffusion-kicks](https://huggingface.co/crlandsc/tiny-audio-diffusion-kicks)|\n|Snares|[crlandsc/tiny-audio-diffusion-snares](https://huggingface.co/crlandsc/tiny-audio-diffusion-snares)|\n|Hi-hats|[crlandsc/tiny-audio-diffusion-hihats](https://huggingface.co/crlandsc/tiny-audio-diffusion-hihats)|\n|Percussion (all drum types)|[crlandsc/tiny-audio-diffusion-percussion](https://huggingface.co/crlandsc/tiny-audio-diffusion-percussion)|\n\n*See W&B model training metrics [here](https://wandb.ai/crlandsc/unconditional-drum-diffusion?workspace=user-crlandsc).*\n\nPre-trained models can be downloaded to generate samples via the [inference notebook](Inference.ipynb). They can also be used as a base model to fine-tune on custom data. 
It is recommended to create subfolders within the [`saved_models`](saved_models/) folder to store each model's `.ckpt` and `.yaml` files.\n\n---\n\n## Inference\n### Hugging Face Spaces\nGenerate samples without code on [🤗 Hugging Face Spaces](https://huggingface.co/spaces/crlandsc/tiny-audio-diffusion)!\n\n### Jupyter Notebook\n#### Audio Sample Generation\nCurrent Capabilities:\n- Unconditional Generation\n- Conditional \"Style-transfer\" Generation\n\nOpen the [`Inference.ipynb`](Inference.ipynb) in Jupyter Notebook and follow the instructions to generate new audio samples. Ensure that the `\"tiny-audio-diffusion (Python 3.10)\"` kernel is active in Jupyter to run the notebook and you have downloaded the [pre-trained model](#Pre\\-trained-Models) of interest from Hugging Face.\n\n---\n\n## Train\n\nThe model architecture has been constructed with [PyTorch Lightning](https://lightning.ai/docs/pytorch/latest/) and [Hydra](https://hydra.cc/docs/intro/) frameworks. All configurations for the model are contained within `.yaml` files and should be edited there rather than hardcoded.\n\n[`exp/drum_diffusion.yaml`](exp/drum_diffusion.yaml) contains the default model configuration. Additional custom model configurations can be added to the [`exp`](exp/) folder.\n\nCustom models can be trained or fine-tuned on custom datasets. 
Datasets should consist of a folder of `.wav` audio files with a 44.1kHz sampling rate.\n\nTo train or finetune models, run one of the following commands in the terminal from the repo's root folder and replace `<path/to/your/train/data>` with the path to your custom training set.\n\n\n**Train model from scratch (on CPU):**\n*(not recommended)*\n\n```bash\npython train.py exp=drum_diffusion datamodule.dataset.path=<path/to/your/train/data>\n```\n\n\n**Train model from scratch (on GPU):**\n\n```bash\npython train.py exp=drum_diffusion trainer.gpus=1 datamodule.dataset.path=<path/to/your/train/data>\n```\n\n*NOTE:* To use this repo with a GPU, you must have a CUDA-capable GPU and have the CUDA toolkit installed specific to your system (ex. Linux, x86_64, WSL-Ubuntu). More information can be found [here](https://developer.nvidia.com/cuda-toolkit).\n\n\n**Resume run from a checkpoint (with GPU):**\n\n```bash\npython train.py exp=drum_diffusion trainer.gpus=1 +ckpt=</path/to/checkpoint.ckpt> datamodule.dataset.path=<path/to/your/train/data>\n```\n\n---\n\n## Dataset\n\nThe data used to train the checkpoints listed above can be found on [🤗 Hugging Face](https://huggingface.co/datasets/crlandsc/tiny-audio-diffusion-drums).\n\n***Note:*** *This is a small and unbalanced dataset consisting of free samples that I had from my music production. 
These samples are not covered under the MIT license of this repository and cannot be used to train any commercial models, but can be used in personal and research contexts.*\n\n***Note:*** *For appropriately diverse models, larger datasets should be used to avoid memorization of training data.*\n\n---\n\n## Repository Structure\n\nThe structure of this repository is as follows:\n```\n├── main\n│   ├── diffusion_module.py     - contains pl model, data loading, and logging functionalities for training\n│   └── utils.py                - contains utility functions for training\n├── exp\n│   └── *.yaml                  - Hydra configuration files\n├── setup\n│   ├── environment.yml         - file to set up conda environment\n│   └── requirements.txt        - contains repo dependencies\n├── images                      - directory containing images for README.md\n│   └── *.png\n├── samples                     - directory containing sample outputs from tiny-audio-diffusion models\n│   └── *.wav\n├── .env.tmp                    - temporary environment variables (rename to .env)\n├── .gitignore\n├── README.md\n├── Inference.ipynb             - Jupyter notebook for running inference to generate new samples\n├── config.yaml                 - Hydra base configs\n├── train.py                    - script for training\n├── data                        - directory to host custom training data\n│   └── wav_dataset\n│       └── (*.wav)\n└── saved_models                - directory to host model checkpoints and hyper-parameters for inference\n    └── (kicks/snare/etc.)\n        ├── (*.ckpt)            - pl model checkpoint file\n        └── (config.yaml)       - pl model hydra hyperparameters (required for inference)\n```\n"
  },
  {
    "path": "config.yaml",
    "content": "defaults:\n  - _self_\n  - exp: null # config to load\n  - override hydra/hydra_logging: colorlog\n  - override hydra/job_logging: colorlog\n\nseed: 12345\ntrain: True\nignore_warnings: True\nprint_config: False # Prints tree with all configurations\nwork_dir: ${hydra:runtime.cwd}  # This is the root of the project\nlogs_dir: ${work_dir}${oc.env:DIR_LOGS}  # This is the root for all logs\ndata_dir: ${work_dir}${oc.env:DIR_DATA} # This is the root for all data\nckpt_dir: ${logs_dir}/runs/${now:%Y-%m-%d-%H-%M-%S}\n\n# Hydra experiment configs log dir\nhydra:\n  run:\n    dir: ${ckpt_dir} # save in same dir as ckpts\n"
  },
  {
    "path": "data/wav_dataset/.gitkeep",
    "content": ""
  },
  {
    "path": "exp/drum_diffusion.yaml",
    "content": "# @package _global_\n\n# Unconditional Audio Waveform Diffusion\n\n# To execute this experiment on a single GPU, run:\n# python train.py exp=drum_diffusion trainer.gpus=1 datamodule.dataset.path=<path/to/your/train/data>\n\nmodule: main.diffusion_module\nbatch_size: 1 # mini-batch size (increase to speed up at the cost of memory)\naccumulate_grad_batches: 32 # use to increase batch size on single GPU -> effective batch size = (batch_size * accumulate_grad_batches)\nnum_workers: 8 # num workers for data loading\n\nsampling_rate: 44100 # sampling rate (44.1kHz is the music industry standard)\nlength: 32768 # Length of audio in samples (32768 samples @ 44.1kHz ~ 0.75 seconds)\nchannels: 2 # stereo audio\nval_log_every_n_steps: 1000 # Logging interval (Validation and audio generation every n steps)\n# ckpt_every_n_steps: 4000 # Use if multiple checkpoints wanted\n\nmodel:\n  _target_: ${module}.Model # pl model wrapper\n  lr: 1e-4 # optimizer learning rate\n  lr_beta1: 0.95 # beta1 param for AdamW optimizer\n  lr_beta2: 0.999 # beta2 param for AdamW optimizer\n  lr_eps: 1e-6 # epsilon for optimizer (to avoid div by 0)\n  lr_weight_decay: 1e-3 # weight decay regularization param\n  ema_beta: 0.995 # EMA model (exponential-moving-average) beta\n  ema_power: 0.7 # EMA warmup exponent (controls how quickly the EMA decay ramps up toward ema_beta)\n\n  model:\n    _target_: audio_diffusion_pytorch.DiffusionModel # Waveform diffusion model\n    net_t:\n      _target_: ${module}.UNetT # The model type used for diffusion (U-Net V0 in this case)\n    in_channels: 2 # U-Net: number of input/output (audio) channels\n    channels: [32, 32, 64, 64, 128, 128, 256, 256] # U-Net: channels at each layer\n    factors: [1, 2, 2, 2, 2, 2, 2, 2] # U-Net: downsampling and upsampling factors at each layer\n    items: [2, 2, 2, 2, 2, 2, 4, 4] # U-Net: number of repeating items at each layer\n    attentions: [0, 0, 0, 0, 0, 1, 1, 1] # U-Net: attention enabled/disabled at each layer\n    attention_heads: 8 # U-Net: number of 
attention heads per attention item\n    attention_features: 64 # U-Net: number of attention features per attention item    \n\n# To specify train-valid datasets, datamodule must be reconfigured\ndatamodule:\n  _target_: main.diffusion_module.Datamodule\n  dataset:\n    _target_: audio_data_pytorch.WAVDataset\n    path: ./data/wav_dataset # can be overridden when calling train.py\n    recursive: True\n    sample_rate: ${sampling_rate}\n    transforms:\n      _target_: audio_data_pytorch.AllTransform\n      crop_size: ${length} # One-shots, so no random crop\n      stereo: True\n      source_rate: ${sampling_rate}\n      target_rate: ${sampling_rate}\n      loudness: -20 # normalize loudness\n  val_split: 0.1 # fraction of the dataset held out for validation\n  batch_size: ${batch_size}\n  num_workers: ${num_workers}\n  pin_memory: True\n\n\ncallbacks:\n  rich_progress_bar:\n    _target_: pytorch_lightning.callbacks.RichProgressBar\n    # _target_: pytorch_lightning.callbacks.TQDMProgressBar # use if RichProgressBar creates issues\n\n  model_checkpoint:\n    _target_: pytorch_lightning.callbacks.ModelCheckpoint\n    monitor: \"valid_loss\"   # name of the logged metric which determines when model is improving\n    save_top_k: 1           # save k best models (determined by above metric)\n    save_last: True         # additionally always save model from last epoch\n    mode: \"min\"             # can be \"max\" or \"min\"\n    verbose: False\n    dirpath: ${logs_dir}/ckpts/${now:%Y-%m-%d-%H-%M-%S}\n    filename: '{epoch:02d}-{valid_loss:.3f}'\n    # every_n_train_steps: ${ckpt_every_n_steps} # Use if multiple checkpoints wanted\n\n  model_summary:\n    _target_: pytorch_lightning.callbacks.RichModelSummary\n    max_depth: 2\n\n  audio_samples_logger:\n    _target_: main.diffusion_module.SampleLogger\n    num_items: 4 # number of separate samples to be generated\n    channels: ${channels} # number of audio channels\n    sampling_rate: ${sampling_rate} # audio sampling rate\n    length: ${length} 
# length of generated sample\n    sampling_steps: [50] # number of steps per sample\n    use_ema_model: True # Use EMA for logger inference\n\nloggers:\n  wandb:\n    _target_: pytorch_lightning.loggers.wandb.WandbLogger\n    project: ${oc.env:WANDB_PROJECT} # defined in env var\n    entity: ${oc.env:WANDB_ENTITY} # defined in env var\n    name: unconditional_diffusion # name of run\n    # offline: False  # set True to store all logs only locally\n    job_type: \"train\"\n    group: \"\" # Set a group name if desired\n    save_dir: ${logs_dir}\n\ntrainer:\n  _target_: pytorch_lightning.Trainer\n  gpus: 0 # Set `1` to train on GPU, `0` to train on CPU only, and `-1` to train on all GPUs, default `0`\n  precision: 16 # Precision used for tensors (`32` offers higher precision, but `16` is used to save memory)\n  min_epochs: 0 # minimum number of epochs\n  max_epochs: -1 # max number of epochs (-1 = infinite run)\n  enable_model_summary: False\n  log_every_n_steps: 1 # Logs training metrics every n steps\n  # limit_val_batches: 10 # Use to limit the number of valid batches run (e.g. 10 stops training at 10 batches)\n  check_val_every_n_epoch: null\n  val_check_interval: ${val_log_every_n_steps} # Validation interval (check valid set and generate audio every n steps)\n  accumulate_grad_batches: ${accumulate_grad_batches} # use to increase batch size on single GPU"
  },
  {
    "path": "exp/drum_diffusion_no_wandb.yaml",
    "content": "# @package _global_\n\n# Unconditional Audio Waveform Diffusion (without the Weights & Biases logger)\n\n# To execute this experiment on a single GPU, run:\n# python train.py exp=drum_diffusion_no_wandb trainer.gpus=1 datamodule.dataset.path=<path/to/your/train/data>\n\nmodule: main.diffusion_module\nbatch_size: 1 # mini-batch size (increase to speed up at the cost of memory)\naccumulate_grad_batches: 32 # use to increase batch size on single GPU -> effective batch size = (batch_size * accumulate_grad_batches)\nnum_workers: 8 # num workers for data loading\n\nsampling_rate: 44100 # sampling rate (44.1kHz is the music industry standard)\nlength: 32768 # Length of audio in samples (32768 samples @ 44.1kHz ~ 0.75 seconds)\nchannels: 2 # stereo audio\nval_log_every_n_steps: 1000 # Logging interval (Validation and audio generation every n steps)\n# ckpt_every_n_steps: 4000 # Use if multiple checkpoints wanted\n\nmodel:\n  _target_: ${module}.Model # pl model wrapper\n  lr: 1e-4 # optimizer learning rate\n  lr_beta1: 0.95 # beta1 param for AdamW optimizer\n  lr_beta2: 0.999 # beta2 param for AdamW optimizer\n  lr_eps: 1e-6 # epsilon for optimizer (to avoid div by 0)\n  lr_weight_decay: 1e-3 # weight decay regularization param\n  ema_beta: 0.995 # EMA model (exponential-moving-average) beta\n  ema_power: 0.7 # EMA warmup exponent (controls how quickly the EMA decay ramps up toward ema_beta)\n\n  model:\n    _target_: audio_diffusion_pytorch.DiffusionModel # Waveform diffusion model\n    net_t:\n      _target_: ${module}.UNetT # The model type used for diffusion (U-Net V0 in this case)\n    in_channels: 2 # U-Net: number of input/output (audio) channels\n    channels: [32, 32, 64, 64, 128, 128, 256, 256] # U-Net: channels at each layer\n    factors: [1, 2, 2, 2, 2, 2, 2, 2] # U-Net: downsampling and upsampling factors at each layer\n    items: [2, 2, 2, 2, 2, 2, 4, 4] # U-Net: number of repeating items at each layer\n    attentions: [0, 0, 0, 0, 0, 1, 1, 1] # U-Net: attention enabled/disabled at each layer\n    attention_heads: 8 # U-Net: number of 
attention heads per attention item\n    attention_features: 64 # U-Net: number of attention features per attention item    \n\n# To specify train-valid datasets, datamodule must be reconfigured\ndatamodule:\n  _target_: main.diffusion_module.Datamodule\n  dataset:\n    _target_: audio_data_pytorch.WAVDataset\n    path: ./data/wav_dataset # can be overridden when calling train.py\n    recursive: True\n    sample_rate: ${sampling_rate}\n    transforms:\n      _target_: audio_data_pytorch.AllTransform\n      crop_size: ${length} # One-shots, so no random crop\n      stereo: True\n      source_rate: ${sampling_rate}\n      target_rate: ${sampling_rate}\n      loudness: -20 # normalize loudness\n  val_split: 0.1 # fraction of the dataset held out for validation\n  batch_size: ${batch_size}\n  num_workers: ${num_workers}\n  pin_memory: True\n\n\ncallbacks:\n  rich_progress_bar:\n    _target_: pytorch_lightning.callbacks.RichProgressBar\n    # _target_: pytorch_lightning.callbacks.TQDMProgressBar # use if RichProgressBar creates issues\n\n  model_checkpoint:\n    _target_: pytorch_lightning.callbacks.ModelCheckpoint\n    monitor: \"valid_loss\"   # name of the logged metric which determines when model is improving\n    save_top_k: 1           # save k best models (determined by above metric)\n    save_last: True         # additionally always save model from last epoch\n    mode: \"min\"             # can be \"max\" or \"min\"\n    verbose: False\n    dirpath: ${logs_dir}/ckpts/${now:%Y-%m-%d-%H-%M-%S}\n    filename: '{epoch:02d}-{valid_loss:.3f}'\n    # every_n_train_steps: ${ckpt_every_n_steps} # Use if multiple checkpoints wanted\n\n  model_summary:\n    _target_: pytorch_lightning.callbacks.RichModelSummary\n    max_depth: 2\n\ntrainer:\n  _target_: pytorch_lightning.Trainer\n  gpus: 0 # Set `1` to train on GPU, `0` to train on CPU only, and `-1` to train on all GPUs, default `0`\n  precision: 16 # Precision used for tensors (`32` offers higher precision, but `16` is used to save memory)\n  
min_epochs: 0 # minimum number of epochs\n  max_epochs: -1 # max number of epochs (-1 = infinite run)\n  enable_model_summary: False\n  log_every_n_steps: 1 # Logs training metrics every n steps\n  # limit_val_batches: 10 # Use to limit the number of valid batches run (e.g. 10 stops training at 10 batches)\n  check_val_every_n_epoch: null\n  val_check_interval: ${val_log_every_n_steps} # Validation interval (check valid set and generate audio every n steps)\n  accumulate_grad_batches: ${accumulate_grad_batches} # use to increase batch size on single GPU"
  },
  {
    "path": "main/diffusion_module.py",
    "content": "# This code has been adapted from Flavio Schneider's work with Archinet.\n# (https://github.com/archinetai/audio-diffusion-pytorch-trainer)\n\nfrom audio_data_pytorch.utils import fractional_random_split\nfrom pytorch_lightning.loggers import LoggerCollection, WandbLogger\nfrom audio_diffusion_pytorch import UNetV0, VDiffusion, VSampler, LTPlugin\n\nimport random\nfrom typing import Any, List, Optional\n\nimport plotly.graph_objs as go\nimport pytorch_lightning as pl\nimport torch\nimport torchaudio\nimport wandb\n\nfrom einops import rearrange\nfrom ema_pytorch import EMA\nfrom pytorch_lightning import Callback, Trainer\nfrom torch import Tensor, nn\nfrom torch.utils.data import DataLoader\n\n\n\"\"\" Model \"\"\"\n\n# Option to use learned transform to downsample (by stride length) input data (not recommended).\n# Can reduce computational load, but introduces undesirable high freq artifacts.\nUNetT_LT = lambda: LTPlugin(UNetV0, num_filters=32, window_length=16, stride=16)\n\nUNetT = lambda: UNetV0 # define Unet to be used (from audio_diffusion_pytorch)\nDiffusionT = VDiffusion # define diffusion method to be used (from audio_diffusion_pytorch)\nSamplerT = VSampler # define diffusion sampler to be used (from audio_diffusion_pytorch)\n\ndef dropout(proba: float):\n    return random.random() < proba\n\nclass Model(pl.LightningModule):\n    def __init__(\n        self,\n        lr: float,\n        lr_beta1: float,\n        lr_beta2: float,\n        lr_eps: float,\n        lr_weight_decay: float,\n        ema_beta: float,\n        ema_power: float,\n        model: nn.Module,\n    ):\n        super().__init__()\n        self.lr = lr\n        self.lr_beta1 = lr_beta1\n        self.lr_beta2 = lr_beta2\n        self.lr_eps = lr_eps\n        self.lr_weight_decay = lr_weight_decay\n        self.model = model\n        self.model_ema = EMA(self.model, beta=ema_beta, power=ema_power)\n\n    @property\n    def device(self):\n        return 
next(self.model.parameters()).device\n\n    def configure_optimizers(self):\n        optimizer = torch.optim.AdamW(\n            list(self.model.parameters()),\n            lr=self.lr,\n            betas=(self.lr_beta1, self.lr_beta2),\n            eps=self.lr_eps,\n            weight_decay=self.lr_weight_decay,\n        )\n        return optimizer\n\n    def training_step(self, batch, batch_idx):\n        wave = batch\n        loss = self.model(wave)\n        self.log(\"train_loss\", loss, sync_dist=True)\n        \n        # Update EMA model and log decay\n        self.model_ema.update()\n        self.log(\"ema_decay\", self.model_ema.get_current_decay(), sync_dist=True)\n        return loss\n\n    def validation_step(self, batch, batch_idx):\n        wave = batch\n        loss = self.model_ema(wave)\n        self.log(\"valid_loss\", loss, sync_dist=True)\n        return loss\n\n\n\"\"\" Datamodule \"\"\"\n\nclass Datamodule(pl.LightningDataModule):\n    def __init__(\n        self,\n        dataset,\n        *,\n        val_split: float,\n        batch_size: int,\n        num_workers: int,\n        pin_memory: bool = False,\n        **kwargs: int,\n    ) -> None:\n        super().__init__()\n        self.dataset = dataset\n        self.val_split = val_split\n        self.batch_size = batch_size\n        self.num_workers = num_workers\n        self.pin_memory = pin_memory\n        self.data_train: Any = None\n        self.data_val: Any = None\n\n    def setup(self, stage: Any = None) -> None:\n        split = [1.0 - self.val_split, self.val_split]\n        self.data_train, self.data_val = fractional_random_split(self.dataset, split)\n\n    def get_dataloader(self, dataset) -> DataLoader:\n        return DataLoader(\n            dataset=dataset,            \n            batch_size=self.batch_size,\n            num_workers=self.num_workers,\n            pin_memory=self.pin_memory,\n            shuffle=True,\n            prefetch_factor=2,\n        )\n\n    def 
train_dataloader(self) -> DataLoader:\n        return self.get_dataloader(self.data_train)\n\n    def val_dataloader(self) -> DataLoader:\n        return self.get_dataloader(self.data_val)\n\n\n\"\"\" Callbacks \"\"\"\n\ndef get_wandb_logger(trainer: Trainer) -> Optional[WandbLogger]:\n    \"\"\"Safely get Weights&Biases logger from Trainer.\"\"\"\n\n    if isinstance(trainer.logger, WandbLogger):\n        return trainer.logger\n\n    if isinstance(trainer.logger, LoggerCollection):\n        for logger in trainer.logger:\n            if isinstance(logger, WandbLogger):\n                return logger\n\n    print(\"WandbLogger not found.\")\n    return None\n\n\ndef log_wandb_audio_batch(\n    logger: WandbLogger, id: str, samples: Tensor, sampling_rate: int, caption: str = \"\"\n):\n    num_items = samples.shape[0]\n    samples = rearrange(samples, \"b c t -> b t c\").detach().cpu().numpy()\n    logger.log(\n        {\n            f\"sample_{idx}_{id}\": wandb.Audio(\n                samples[idx],\n                caption=caption,\n                sample_rate=sampling_rate,\n            )\n            for idx in range(num_items)\n        }\n    )\n\n\ndef log_wandb_audio_spectrogram(\n    logger: WandbLogger, id: str, samples: Tensor, sampling_rate: int, caption: str = \"\"\n):\n    num_items = samples.shape[0]\n    samples = samples.detach().cpu()\n    transform = torchaudio.transforms.MelSpectrogram(\n        sample_rate=sampling_rate,\n        n_fft=1024,\n        hop_length=512,\n        n_mels=80,\n        center=True,\n        norm=\"slaney\",\n    )\n\n    def get_spectrogram_image(x):\n        spectrogram = transform(x[0])\n        image = torchaudio.functional.amplitude_to_DB(spectrogram, 1.0, 1e-10, 80.0)\n        trace = [go.Heatmap(z=image, colorscale=\"viridis\")]\n        layout = go.Layout(\n            yaxis=dict(title=\"Mel Bin (Log Frequency)\"),\n            xaxis=dict(title=\"Frame\"),\n            title_font_size=10,\n        )\n        fig = 
go.Figure(data=trace, layout=layout)\n        return fig\n\n    logger.log(\n        {\n            f\"mel_spectrogram_{idx}_{id}\": get_spectrogram_image(samples[idx])\n            for idx in range(num_items)\n        }\n    )\n\n\nclass SampleLogger(Callback):\n    def __init__(\n        self,\n        num_items: int,\n        channels: int,\n        sampling_rate: int,\n        sampling_steps: List[int],\n        use_ema_model: bool,\n        length: int,\n    ) -> None:\n        self.num_items = num_items\n        self.channels = channels\n        self.sampling_rate = sampling_rate\n        self.sampling_steps = sampling_steps\n        self.use_ema_model = use_ema_model\n        self.log_next = False\n        self.length = length\n\n\n    def on_validation_epoch_start(self, trainer, pl_module):\n        self.log_next = True\n\n    def on_validation_batch_start(\n        self, trainer, pl_module, batch, batch_idx, dataloader_idx\n    ):\n        if self.log_next and trainer.logger: # only log if logger present in config\n            self.log_sample(trainer, pl_module, batch)\n            self.log_next = False\n\n    @torch.no_grad()\n    def log_sample(self, trainer, pl_module, batch):\n        is_train = pl_module.training\n        if is_train:\n            pl_module.eval()\n\n        # Get wandb logger\n        wandb_logger = get_wandb_logger(trainer).experiment\n\n        model = pl_module.model\n\n        if self.use_ema_model:\n            model = pl_module.model_ema.ema_model\n\n\n        # Get noise for diffusion inference\n        noise = torch.randn(\n            (self.num_items, self.channels, self.length), device=pl_module.device\n        )\n\n        for steps in self.sampling_steps:\n            samples = model.sample(\n                noise,\n                num_steps=steps,\n            )\n            log_wandb_audio_batch(\n                logger=wandb_logger,\n                id=\"sample\",\n                samples=samples,\n                
sampling_rate=self.sampling_rate,\n                caption=f\"Sampled in {steps} steps\",\n            )\n            log_wandb_audio_spectrogram(\n                logger=wandb_logger,\n                id=\"sample\",\n                samples=samples,\n                sampling_rate=self.sampling_rate,\n                caption=f\"Sampled in {steps} steps\",\n            )\n\n        if is_train:\n            pl_module.train()\n"
  },
  {
    "path": "main/utils.py",
    "content": "import logging\nimport os\nimport warnings\nfrom typing import Callable, List, Optional, Sequence\n\nimport pkg_resources  # type: ignore\nimport pytorch_lightning as pl\nimport rich.syntax\nimport rich.tree\nimport torch\nfrom omegaconf import DictConfig, OmegaConf\nfrom pytorch_lightning import Callback\nfrom pytorch_lightning.utilities import rank_zero_only\n\n\"\"\" Training Utils\"\"\"\n\ndef get_logger(name=__name__) -> logging.Logger:\n    \"\"\"Initializes multi-GPU-friendly python command line logger.\"\"\"\n\n    logger = logging.getLogger(name)\n\n    # this ensures all logging levels get marked with the rank zero decorator\n    # otherwise logs would get multiplied for each GPU process in multi-GPU setup\n    for level in (\n        \"debug\",\n        \"info\",\n        \"warning\",\n        \"error\",\n        \"exception\",\n        \"fatal\",\n        \"critical\",\n    ):\n        setattr(logger, level, rank_zero_only(getattr(logger, level)))\n\n    return logger\n\n\nlog = get_logger(__name__)\n\n\ndef extras(config: DictConfig) -> None:\n    \"\"\"Applies optional utilities, controlled by config flags.\n    Utilities:\n    - Ignoring python warnings\n    - Rich config printing\n    \"\"\"\n\n    # disable python warnings if <config.ignore_warnings=True>\n    if config.get(\"ignore_warnings\"):\n        log.info(\"Disabling python warnings! <config.ignore_warnings=True>\")\n        warnings.filterwarnings(\"ignore\")\n\n    # pretty print config tree using Rich library if <config.print_config=True>\n    if config.get(\"print_config\"):\n        log.info(\"Printing config tree with Rich! 
<config.print_config=True>\")\n        print_config(config, resolve=True)\n\n\n@rank_zero_only\ndef print_config(\n    config: DictConfig,\n    print_order: Sequence[str] = (\n        \"datamodule\",\n        \"model\",\n        \"callbacks\",\n        \"logger\",\n        \"trainer\",\n    ),\n    resolve: bool = True,\n) -> None:\n    \"\"\"Prints content of DictConfig using Rich library and its tree structure.\n    Args:\n        config (DictConfig): Configuration composed by Hydra.\n        print_order (Sequence[str], optional): Determines in what order config components are printed.\n        resolve (bool, optional): Whether to resolve reference fields of DictConfig.\n    \"\"\"\n\n    style = \"dim\"\n    tree = rich.tree.Tree(\"CONFIG\", style=style, guide_style=style)\n\n    quee = []\n\n    for field in print_order:\n        quee.append(field) if field in config else log.info(\n            f\"Field '{field}' not found in config\"\n        )\n\n    for field in config:\n        if field not in quee:\n            quee.append(field)\n\n    for field in quee:\n        branch = tree.add(field, style=style, guide_style=style)\n\n        config_group = config[field]\n        if isinstance(config_group, DictConfig):\n            branch_content = OmegaConf.to_yaml(config_group, resolve=resolve)\n        else:\n            branch_content = str(config_group)\n\n        branch.add(rich.syntax.Syntax(branch_content, \"yaml\"))\n\n    rich.print(tree)\n\n    with open(\"config_tree.log\", \"w\") as file:\n        rich.print(tree, file=file)\n\n\n@rank_zero_only\ndef log_hyperparameters(\n    config: DictConfig,\n    model: pl.LightningModule,\n    datamodule: pl.LightningDataModule,\n    trainer: pl.Trainer,\n    callbacks: List[pl.Callback],\n    logger: List[pl.loggers.LightningLoggerBase],\n) -> None:\n    \"\"\"Controls which config parts are saved by Lightning loggers.\n    Additionaly saves:\n    - number of model parameters\n    \"\"\"\n\n    if not 
trainer.logger:\n        return\n\n    hparams = {}\n\n    # choose which parts of hydra config will be saved to loggers\n    hparams[\"model\"] = config[\"model\"]\n\n    # save number of model parameters\n    hparams[\"model/params/total\"] = sum(p.numel() for p in model.parameters())\n    hparams[\"model/params/trainable\"] = sum(\n        p.numel() for p in model.parameters() if p.requires_grad\n    )\n    hparams[\"model/params/non_trainable\"] = sum(\n        p.numel() for p in model.parameters() if not p.requires_grad\n    )\n\n    hparams[\"datamodule\"] = config[\"datamodule\"]\n    hparams[\"trainer\"] = config[\"trainer\"]\n\n    if \"seed\" in config:\n        hparams[\"seed\"] = config[\"seed\"]\n    if \"callbacks\" in config:\n        hparams[\"callbacks\"] = config[\"callbacks\"]\n\n    hparams[\"pacakges\"] = get_packages_list()\n\n    # send hparams to all loggers\n    trainer.logger.log_hyperparams(hparams)\n\n\ndef finish(\n    config: DictConfig,\n    model: pl.LightningModule,\n    datamodule: pl.LightningDataModule,\n    trainer: pl.Trainer,\n    callbacks: List[pl.Callback],\n    logger: List[pl.loggers.LightningLoggerBase],\n) -> None:\n    \"\"\"Makes sure everything closed properly.\"\"\"\n\n    # without this sweeps with wandb logger might crash!\n    for lg in logger:\n        if isinstance(lg, pl.loggers.wandb.WandbLogger):\n            import wandb\n\n            wandb.finish()\n\n\ndef get_packages_list() -> List[str]:\n    return [f\"{p.project_name}=={p.version}\" for p in pkg_resources.working_set]\n\n\ndef retry_if_error(fn: Callable, num_attemps: int = 10):\n    for attempt in range(num_attemps):\n        try:\n            return fn()\n        except:\n            print(f\"Retrying, attempt {attempt+1}\")\n            pass\n    return fn()\n\n\nclass SavePytorchModelAndStopCallback(Callback):\n    def __init__(self, path: str, attribute: Optional[str] = None):\n        self.path = path\n        self.attribute = attribute\n\n    
def on_train_start(self, trainer, pl_module):\n        model, path = pl_module, self.path\n        if self.attribute is not None:\n            assert_message = \"provided model attribute not found in pl_module\"\n            assert hasattr(pl_module, self.attribute), assert_message\n            model = getattr(\n                pl_module, self.attribute, hasattr(pl_module, self.attribute)\n            )\n        # Make dir if not existent\n        os.makedirs(os.path.split(path)[0], exist_ok=True)\n        # Save model\n        torch.save(model, path)\n        log.info(f\"PyTorch model saved at: {path}\")\n        # Stop trainer\n        trainer.should_stop = True"
  },
  {
    "path": "saved_models/.gitkeep",
    "content": ""
  },
  {
    "path": "setup/environment.yml",
    "content": "name: tiny-audio-diffusion\n\ndependencies:\n  - python=3.10\n  - pip\n  - pip:\n    - -r requirements.txt"
  },
  {
    "path": "setup/requirements.txt",
    "content": "torch>=2.0.1\ntorchaudio>=2.0.2\npytorch-lightning==1.7.7\ntorchmetrics==0.11.4\npython-dotenv\nhydra-core\nhydra-colorlog\nwandb\nauraloss\nyt-dlp\ndatasets\npyloudnorm\neinops\nomegaconf\nrich\nplotly\nlibrosa\ntransformers\neng-to-ipa\nema-pytorch\npy7zr\nnotebook\nmatplotlib\nipykernel\ngradio\n\n# k-diffusion\n# v-diffusion-pytorch\n\naudio-diffusion-pytorch==0.1.3\naudio-encoders-pytorch\naudio-data-pytorch\nquantizer-pytorch\ndifformer-pytorch\na-transformers-pytorch\n"
  },
  {
    "path": "train.py",
    "content": "import os\nimport dotenv\nimport hydra\nimport pytorch_lightning as pl\nfrom main import utils\nfrom omegaconf import DictConfig, open_dict\n# import torch # use if a direct checkpoint load is required (see \"Alternative model load method\" in main())\n\n\n# Load environment variables from `.env`.\ndotenv.load_dotenv(override=True)\nlog = utils.get_logger(__name__)\n\n\n@hydra.main(config_path=\".\", config_name=\"config.yaml\", version_base=None)\ndef main(config: DictConfig) -> None:\n    \"\"\"Build the datamodule, model, callbacks, loggers and trainer from the\n    Hydra config, then train (starting from `ckpt` when it is set).\n\n    If `save` is set, the config's loggers and callbacks are dropped and a\n    `SavePytorchModelAndStopCallback` is added instead: it writes the model\n    (or its attribute named by `save`) to `<ckpt_dir>/<save>.pt` and stops.\n    \"\"\"\n\n    # Logs config tree\n    utils.extras(config)\n\n    # Apply seed for reproducibility\n    pl.seed_everything(config.seed)\n\n    # Initialize datamodule\n    log.info(f\"Instantiating datamodule <{config.datamodule._target_}>.\")\n    datamodule = hydra.utils.instantiate(config.datamodule, _convert_=\"partial\")\n\n    # Initialize model\n    log.info(f\"Instantiating model <{config.model._target_}>.\")\n    model = hydra.utils.instantiate(config.model, _convert_=\"partial\")\n\n    # Initialize all callbacks (e.g. checkpoints, early stopping)\n    callbacks = []\n\n    # If `save` is provided, add a callback that saves the model and stops (use with +ckpt)\n    if \"save\" in config:\n        # Ignore loggers and other callbacks; the None default keeps pop() from\n        # raising when the config does not define these keys\n        with open_dict(config):\n            config.pop(\"loggers\", None)\n            config.pop(\"callbacks\", None)\n            config.trainer.num_sanity_val_steps = 0\n        attribute, path = config.get(\"save\"), config.get(\"ckpt_dir\")\n        filename = os.path.join(path, f\"{attribute}.pt\")\n        callbacks += [utils.SavePytorchModelAndStopCallback(filename, attribute)]\n\n    if \"callbacks\" in config:\n        for _, cb_conf in config[\"callbacks\"].items():\n            if \"_target_\" in cb_conf:\n                log.info(f\"Instantiating callback <{cb_conf._target_}>.\")\n                callbacks.append(hydra.utils.instantiate(cb_conf, _convert_=\"partial\"))\n\n    # Initialize loggers (e.g. wandb)\n    loggers = []\n    if \"loggers\" in config:\n        for _, lg_conf in config[\"loggers\"].items():\n            if \"_target_\" in lg_conf:\n                log.info(f\"Instantiating logger <{lg_conf._target_}>.\")\n                # wandb sometimes throws an error on slow connections, so retry\n                logger = utils.retry_if_error(\n                    lambda: hydra.utils.instantiate(lg_conf, _convert_=\"partial\")\n                )\n                loggers.append(logger)\n\n    # Initialize trainer\n    log.info(f\"Instantiating trainer <{config.trainer._target_}>.\")\n    trainer = hydra.utils.instantiate(\n        config.trainer, callbacks=callbacks, logger=loggers, _convert_=\"partial\"\n    )\n\n    # Send some parameters from config to all lightning loggers\n    log.info(\"Logging hyperparameters!\")\n    utils.log_hyperparameters(\n        config=config,\n        model=model,\n        datamodule=datamodule,\n        trainer=trainer,\n        callbacks=callbacks,\n        logger=loggers,\n    )\n\n    # Train with checkpoint if present, otherwise from start\n    if \"ckpt\" in config:\n        ckpt = config.get(\"ckpt\")\n        log.info(f\"Starting training from {ckpt}\")\n        trainer.fit(model=model, datamodule=datamodule, ckpt_path=ckpt)\n\n        # # Alternative model load method\n        # # Use if loading from checkpoint with pl trainer causes GPU memory spike (CUDA out of memory).\n        # checkpoint = torch.load(config.get(\"ckpt\"), map_location='cpu')['state_dict']\n        # model.load_state_dict(checkpoint)\n        # trainer.fit(model=model, datamodule=datamodule)\n    else:\n        log.info(\"Starting training.\")\n        trainer.fit(model=model, datamodule=datamodule)\n\n    # Make sure everything is closed properly\n    log.info(\"Finalizing!\")\n    utils.finish(\n        config=config,\n        model=model,\n        datamodule=datamodule,\n        trainer=trainer,\n        callbacks=callbacks,\n        logger=loggers,\n    )\n\n    # Print path to best checkpoint; checkpoint_callback can be None when the\n    # trainer has no ModelCheckpoint callback\n    if (\n        not config.trainer.get(\"fast_dev_run\")\n        and config.get(\"train\")\n        and not config.get(\"save\")\n        and trainer.checkpoint_callback is not None\n    ):\n        log.info(f\"Best model ckpt at {trainer.checkpoint_callback.best_model_path}\")\n\n\nif __name__ == \"__main__\":\n    main()\n"
  }
]