Repository: CPJKU/beat_this Branch: main Commit: 9d787b9797ea Files: 29 Total size: 172.9 KB Directory structure: gitextract_r0d2sw0_/ ├── .github/ │ └── workflows/ │ └── pypi.yml ├── .gitignore ├── CHANGELOG.md ├── LICENSE ├── README.md ├── beat_this/ │ ├── __init__.py │ ├── cli.py │ ├── dataset/ │ │ ├── __init__.py │ │ ├── augment.py │ │ ├── dataset.py │ │ └── mmnpz.py │ ├── inference.py │ ├── model/ │ │ ├── __init__.py │ │ ├── beat_tracker.py │ │ ├── loss.py │ │ ├── pl_module.py │ │ ├── postprocessor.py │ │ └── roformer.py │ ├── preprocessing.py │ └── utils.py ├── beat_this_example.ipynb ├── hubconf.py ├── launch_scripts/ │ ├── clean_checkpoints.py │ ├── compute_paper_metrics.py │ ├── preprocess_audio.py │ └── train.py ├── pyproject.toml ├── requirements.txt └── tests/ └── test_inference.py ================================================ FILE CONTENTS ================================================ ================================================ FILE: .github/workflows/pypi.yml ================================================ # This workflow will upload a Python Package to PyPI when a release is created # For more information see: https://docs.github.com/en/actions/automating-builds-and-tests/building-and-testing-python#publishing-to-package-registries # This workflow uses actions that are not certified by GitHub. # They are provided by a third-party and are governed by # separate terms of service, privacy policy, and support # documentation. 
name: Upload Python Package on: release: types: [published] permissions: contents: read jobs: release-build: runs-on: ubuntu-latest steps: - uses: actions/checkout@v5 - uses: actions/setup-python@v5 with: python-version: "3.x" - name: Build release distributions run: | python -m pip install build python -m build - name: Upload distributions uses: actions/upload-artifact@v4 with: name: release-dists path: dist/ pypi-publish: runs-on: ubuntu-latest needs: - release-build permissions: # IMPORTANT: this permission is mandatory for trusted publishing id-token: write # Dedicated environments with protections for publishing are strongly recommended. # For more information, see: https://docs.github.com/en/actions/deployment/targeting-different-environments/using-environments-for-deployment#deployment-protection-rules environment: name: pypi url: https://pypi.org/project/beat-this/${{ github.event.release.name }} steps: - name: Retrieve release distributions uses: actions/download-artifact@v5 with: name: release-dists path: dist/ - name: Publish release distributions to PyPI uses: pypa/gh-action-pypi-publish@release/v1 with: packages-dir: dist/ ================================================ FILE: .gitignore ================================================ __pycache__/ *.py[cod] *$py.class data/ checkpoints/ lightning_logs/ wandb/ .vscode/ beat_this.egg-info/ build/ ================================================ FILE: CHANGELOG.md ================================================ # Changelog All notable changes to this project are documented below. The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.1.0/), and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html). 
## [1.1.0] - 2026-04-14 - Clarified installation instructions for madmom and mir_eval - Load checkpoints with `weights_only=True` when supported - Fix checkpoint downloads after server-side update - Provide separate `infer_beat_numbers()` function - Command-line tool: Support saving raw activations / logits - Training script: Support resuming from previous checkpoint - Migrate to pyproject.toml (thanks to @JacobLinCool) - Support non-CUDA accelerator chips (thanks to @tillt) - Published on PyPI (thanks to @MarvinSchenkel) ## [1.0] - 2024-10-18 - Initial release ================================================ FILE: LICENSE ================================================ MIT License Copyright (c) 2024 Institute of Computational Perception, JKU Linz, Austria Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions: The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software. THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. ================================================ FILE: README.md ================================================ # Beat This! Official implementation of the beat tracker from the ISMIR 2024 paper "[Beat This! 
Accurate Beat Tracking Without DBN Postprocessing](https://arxiv.org/abs/2407.21658)" by Francesco Foscarin, Jan Schlüter and Gerhard Widmer. * [Inference](#inference) * [Available models](#available-models) * [Data](#data) * [Reproducing metrics from the paper](#reproducing-metrics-from-the-paper) * [Training](#training) * [Reusing the loss](#reusing-the-loss) * [Reusing the model](#reusing-the-model) * [Citation](#citation) ## Inference To predict beats for audio files, you can either use our command line tool or call the beat tracker from Python. Both have the same requirements unless you go for the online demo. ### Online demo To process a small set of audio files without installing anything, [open our example notebook in Google Colab](https://colab.research.google.com/github/CPJKU/beat_this/blob/main/beat_this_example.ipynb) and follow the instructions. ### Requirements The beat tracker requires Python with a set of packages installed: 1. [Install PyTorch](https://pytorch.org/get-started/locally/) 2.0 or later following the instructions for your platform. 2. Install further modules with `pip install tqdm einops soxr rotary-embedding-torch`. (If using conda, we still recommend pip. You may try installing `soxr-python` and `einops` from conda-forge, but `rotary-embedding-torch` is only on PyPI.) 3. To read other audio formats than `.wav`, install `ffmpeg` or another supported backend for `torchaudio`. (`ffmpeg` can be installed via conda or via your operating system.) Finally, install our beat tracker with: ```bash pip install beat-this ``` For the development version, use: ```bash pip install https://github.com/CPJKU/beat_this/archive/main.zip ``` ### Command line Along with the python package, a command line application called `beat_this` is installed. 
For full documentation of the command line options, run: ```bash beat_this --help ``` The basic usage is: ```bash beat_this path/to/audio.file -o path/to/output.beats ``` To process multiple files, specify multiple input files or directories, and give an output directory instead: ```bash beat_this path/to/*.mp3 path/to/whole_directory/ -o path/to/output_directory ``` The beat tracker will use the first GPU in your system by default, and fall back to CPU if PyTorch does not have CUDA access. With `--gpu=2`, it will use the third GPU, and with `--gpu=-1` it will force the CPU. For recent GPUs, passing `--float16` may improve speed. If you have a lot of files to process, you can distribute the load over multiple processes by running the same command multiple times with `--touch-first`, `--skip-existing` and potentially different options for `--gpu`: ```bash for gpu in {0..3}; do beat_this input_dir -o output_dir --touch-first --skip-existing --gpu=$gpu & done ``` If you want to use the DBN for postprocessing, add `--dbn`. The DBN parameters are the default ones from madmom. Using the DBN requires installing the `madmom` package (with `pip install git+https://github.com/CPJKU/madmom.git`, as the current version on PyPI only supports Python<3.10 and numpy<1.20).
First, create an instance of the `File2Beats` class that encapsulates the model along with pre- and postprocessing: ```python from beat_this.inference import File2Beats file2beats = File2Beats(checkpoint_path="final0", device="cuda", dbn=False) ``` To obtain a list of beats and downbeats for an audio file, run: ```python audio_path = "path/to/audio.file" beats, downbeats = file2beats(audio_path) ``` Optionally, you can produce a `.beats` file (e.g., for importing into [Sonic Visualiser](https://www.sonicvisualiser.org/)): ```python from beat_this.utils import save_beat_tsv outpath = "path/to/output.beats" save_beat_tsv(beats, downbeats, outpath) ``` If you already have an audio tensor loaded, instead of `File2Beats`, use `Audio2Beats` and pass the tensor and its sample rate. We also provide `Audio2Frames` for framewise logits and `Spect2Frames` for spectrogram inputs. ## Available models Models are available for manual download at [our cloud space](https://cloud.cp.jku.at/index.php/s/7ik4RrBKTS273gp), but will also be downloaded automatically by the above inference code. By default, the inference will use `final0`, but it is possible to select another model via a command line option (`--model`) or Python parameter (`checkpoint_path`). Main models: * `final0`, `final1`, `final2`: Our main model, trained on all data except the GTZAN dataset, with three different seeds. This corresponds to "Our system" in Table 2 of the paper. About 78 MB per model. * `small0`, `small1`, `small2`: A smaller model, again trained on all data except GTZAN, with three different seeds. This corresponds to "smaller model" in Table 2 of the paper. About 8.1 MB per model. * `single_final0`, `single_final1`, `single_final2`: Our main model, trained on the single split described in Section 4.1 of the paper, with three different seeds. This corresponds to "Our system" in Table 3 of the paper. About 78 MB per model.
* `fold0`, `fold1`, `fold2`, `fold3`, `fold4`, `fold5`, `fold6`, `fold7`: Our main model, trained in the 8-fold cross-validation setting with a single seed per fold. This corresponds to "Our" in Table 1 of the paper. About 78 MB per model. Other models, available mainly for result reproducibility: * `hung0`, `hung1`, `hung2`: A model trained on all the data used by the "Modeling Beats and Downbeats with a Time-Frequency Transformer" system by Hung et al. (except the GTZAN dataset), with three different seeds. This corresponds to "limited to data of [10]" in Table 2 of the paper. * The other models used for the ablation studies in Table 3, all trained with 3 seeds on the single split described in Section 4.1 of the paper: * `single_notempoaug0`, `single_notempoaug1`, `single_notempoaug2` * `single_nosumhead0`, `single_nosumhead1`, `single_nosumhead2` * `single_nomaskaug0`, `single_nomaskaug1`, `single_nomaskaug2` * `single_nopartialt0`, `single_nopartialt1`, `single_nopartialt2` * `single_noshifttol0`, `single_noshifttol1`, `single_noshifttol2` * `single_nopitchaug0`, `single_nopitchaug1`, `single_nopitchaug2` * `single_noshifttolnoweights0`, `single_noshifttolnoweights1`, `single_noshifttolnoweights2` Please be aware that the results may be unfairly good if you run inference on any file from the training datasets. For example, an evaluation with `final*` or `small*` can only be performed fairly on GTZAN or other datasets we didn't consider in our paper. If you need to run an evaluation on some datasets we used other than GTZAN, consider targeting the validation part of the single split (with `single_final*`), or of the 8-fold cross-validation (with `fold*`). All the models are provided as PyTorch Lightning checkpoints, stripped of the optimizer state to reduce their size. This is useful for reproducing the paper results or verifying the hyperparameters (stored in the checkpoint under `hyper_parameters` and `datamodule_hyper_parameters`).
During inference, PyTorch Lightning is not used, and the checkpoints are converted and loaded into vanilla PyTorch modules. ## Data ### Annotations All annotations we used to train our models are available [in a separate GitHub repo](https://github.com/CPJKU/beat_this_annotations). Note that if you want to obtain the exact paper results, you should use [version 1.0](https://github.com/CPJKU/beat_this_annotations/releases/tag/v1.0). Other releases with corrected annotations may be published in the future. To use the annotations for training or evaluation, you first need to download and extract or clone the annotations repo to `data/annotations`: ```bash mkdir -p data git clone https://github.com/CPJKU/beat_this_annotations data/annotations # cd data/annotations; git checkout v1.0 # optional ``` ### Spectrograms The spectrograms used for training are released [as a Zenodo dataset](https://zenodo.org/records/13922116). They are distributed as a separate .zip file per dataset, each holding a .npz file with the spectrograms. For evaluation of the test set, download `gtzan.zip`; for training and evaluation of the validation set, download all of them (except `beat_this_annotations.zip`). Extract all .zip files into `data/audio/spectrograms`, so that you have, for example, `data/audio/spectrograms/gtzan.npz`. As an alternative, the code also supports directories of .npy files such as `data/audio/spectrograms/gtzan/gtzan_blues_00000/track.npy`, which you can obtain by unzipping `gtzan.npz`. ### Recreating spectrograms If you have access to the original audio files, or want to add another dataset, create a text file `data/audio_paths.tsv` that has, on each line, the name of a dataset, a tab character, and the path to the audio directory. The corresponding annotations must also be present under `data/annotations`.
Install pandas and pedalboard: ```bash pip install pandas pedalboard ``` Then run: ```bash python launch_scripts/preprocess_audio.py ``` It will create monophonic 22 kHz wave files in `data/audio/mono_tracks`, convert those to spectrograms in `data/audio/spectrograms`, and create spectrogram bundles. Intermediary files are kept and will not be recreated when rerunning the script. ## Reproducing metrics from the paper ### Requirements In addition to the [inference requirements](#requirements), computing evaluation metrics requires installing PyTorch Lightning, Pandas, and `mir_eval` (the latter from source, as the current version on PyPI only supports numpy<1.20). ```bash pip install pytorch_lightning pandas pip install https://github.com/mir-evaluation/mir_eval/archive/main.zip ``` You must also obtain and set up the annotations and spectrogram datasets [as indicated above](#data). Specifically, the GTZAN dataset suffices for commands that include `--datasplit test`, while all other datasets are required for commands that include `--datasplit val`. ### Command line #### Compute results on the test set (GTZAN) corresponding to Table 2 in the paper. Main results for our system: ```bash python launch_scripts/compute_paper_metrics.py --models final0 final1 final2 --datasplit test ``` Smaller model: ```bash python launch_scripts/compute_paper_metrics.py --models small0 small1 small2 --datasplit test ``` Hung data: ```bash python launch_scripts/compute_paper_metrics.py --models hung0 hung1 hung2 --datasplit test ``` With DBN (this requires installing the madmom package): ```bash python launch_scripts/compute_paper_metrics.py --models final0 final1 final2 --datasplit test --dbn ``` #### Compute 8-fold cross-validation results, corresponding to Table 1 in the paper. 
```bash python launch_scripts/compute_paper_metrics.py --models fold0 fold1 fold2 fold3 fold4 fold5 fold6 fold7 --datasplit val --aggregation-type k-fold ``` #### Compute ablation study results on the validation set of the single split, corresponding to Table 3 in the paper. Our system: ```bash python launch_scripts/compute_paper_metrics.py --models single_final0 single_final1 single_final2 --datasplit val ``` No sum head: ```bash python launch_scripts/compute_paper_metrics.py --models single_nosumhead0 single_nosumhead1 single_nosumhead2 --datasplit val ``` No tempo augmentation: ```bash python launch_scripts/compute_paper_metrics.py --models single_notempoaug0 single_notempoaug1 single_notempoaug2 --datasplit val ``` No mask augmentation: ```bash python launch_scripts/compute_paper_metrics.py --models single_nomaskaug0 single_nomaskaug1 single_nomaskaug2 --datasplit val ``` No partial transformers: ```bash python launch_scripts/compute_paper_metrics.py --models single_nopartialt0 single_nopartialt1 single_nopartialt2 --datasplit val ``` No shift tolerance: ```bash python launch_scripts/compute_paper_metrics.py --models single_noshifttol0 single_noshifttol1 single_noshifttol2 --datasplit val ``` No pitch augmentation: ```bash python launch_scripts/compute_paper_metrics.py --models single_nopitchaug0 single_nopitchaug1 single_nopitchaug2 --datasplit val ``` No shift tolerance and no weights: ```bash python launch_scripts/compute_paper_metrics.py --models single_noshifttolnoweights0 single_noshifttolnoweights1 single_noshifttolnoweights2 --datasplit val ``` ## Training ### Requirements The training requirements match the [evaluation requirements](#requirements-1) for the validation set. All 16 datasets and annotations must be [correctly set up](#data). ### Command line #### Train models listed in Table 2 in the paper.
Main results for our system (final0, final1, final2): ```bash for seed in 0 1 2; do python launch_scripts/train.py --seed=$seed --no-val done ``` Smaller model (small0, small1, small2): ```bash for seed in 0 1 2; do python launch_scripts/train.py --seed=$seed --no-val --transformer-dim=128 done ``` Hung data (hung0, hung1, hung2): ```bash for seed in 0 1 2; do python launch_scripts/train.py --seed=$seed --no-val --hung-data done ``` #### Train models with 8-fold cross-validation, corresponding to Table 1 in the paper. ```bash for fold in {0..7}; do python launch_scripts/train.py --fold=$fold done ``` #### Train models for the ablation studies, corresponding to Table 3 in the paper. Our system (single_final0, single_final1, single_final2): ```bash for seed in 0 1 2; do python launch_scripts/train.py --seed=$seed done ``` No sum head (single_nosumhead0, single_nosumhead1, single_nosumhead2): ```bash for seed in 0 1 2; do python launch_scripts/train.py --seed=$seed --no-sum-head done ``` No tempo augmentation (single_notempoaug0, single_notempoaug1, single_notempoaug2): ```bash for seed in 0 1 2; do python launch_scripts/train.py --seed=$seed --no-tempo-augmentation done ``` No mask augmentation (single_nomaskaug0, single_nomaskaug1, single_nomaskaug2): ```bash for seed in 0 1 2; do python launch_scripts/train.py --seed=$seed --no-mask-augmentation done ``` No partial transformers (single_nopartialt0, single_nopartialt1, single_nopartialt2): ```bash for seed in 0 1 2; do python launch_scripts/train.py --seed=$seed --no-partial-transformers done ``` No shift tolerance (single_noshifttol0, single_noshifttol1, single_noshifttol2): ```bash for seed in 0 1 2; do python launch_scripts/train.py --seed=$seed --loss weighted_bce done ``` No pitch augmentation (single_nopitchaug0, single_nopitchaug1, single_nopitchaug2): ```bash for seed in 0 1 2; do python launch_scripts/train.py --seed=$seed --no-pitch-augmentation done ``` No shift tolerance and no weights 
(single_noshifttolnoweights0, single_noshifttolnoweights1, single_noshifttolnoweights2): ```bash for seed in 0 1 2; do python launch_scripts/train.py --seed=$seed --loss bce done ``` ## Reusing the loss To reuse our shift-tolerant binary cross-entropy loss, just copy out the `ShiftTolerantBCELoss` class from [`loss.py`](beat_this/model/loss.py); it does not have any dependencies. ## Reusing the model To reuse the BeatThis model, you have multiple options: ### From the package After installing the `beat_this` package, you can directly import the model class: ``` from beat_this.model.beat_tracker import BeatThis ``` Instantiating this class will give you an untrained model from spectrograms to frame-wise beat and downbeat logits. For a pretrained model, use `load_model`: ``` from beat_this.inference import load_model beat_this = load_model('final0', device='cuda') ``` ### From torch.hub To quickly try the model without installing the package, just install the [requirements for inference](#requirements) and do: ``` import torch beat_this = torch.hub.load('CPJKU/beat_this', 'beat_this', 'final0', device='cuda') ``` ### Copy and paste To copy the BeatThis model into your own project, you will need the [`beat_tracker.py`](beat_this/model/beat_tracker.py) and [`roformer.py`](beat_this/model/roformer.py) files. If you remove the `BeatThis.state_dict()` and `BeatThis._load_from_state_dict()` methods that serve as a workaround for compiled models, then there are no other internal dependencies, only external dependencies (`einops`, `rotary-embedding-torch`). ## Citation ```bibtex @inproceedings{foscarin2024beatthis, author = {Francesco Foscarin and Jan Schl{\"u}ter and Gerhard Widmer}, title = {Beat this!
Accurate beat tracking without {DBN} postprocessing}, year = 2024, month = nov, booktitle = {Proceedings of the 25th International Society for Music Information Retrieval Conference (ISMIR)}, address = {San Francisco, CA, United States}, } ``` ================================================ FILE: beat_this/__init__.py ================================================ ================================================ FILE: beat_this/cli.py ================================================ #!/usr/bin/env python3 """ Beat This! command line inference tool. """ import argparse import sys from pathlib import Path import numpy as np import torch try: import tqdm except ImportError: tqdm = None from beat_this.inference import File2File, load_audio from beat_this.utils import save_beat_tsv def get_parser(): parser = argparse.ArgumentParser( description="Detects beats in given audio files with a Beat This! model." ) parser.add_argument( "inputs", type=str, nargs="+", help="An audio file to process, or a directory of such files. Can be given multiple times.", ) parser.add_argument( "--model", type=str, help="Name, path or URL of checkpoint to use, will be downloaded if needed (default:%(default)s).", default="final0", ) parser.add_argument( "--output", "-o", type=str, default=None, help="Output file name for a single input file, or output directory for multiple input files. If omitted, outputs are saved next to each input file by replacing or appending a suffix (see --suffix and --append).", ) parser.add_argument( "--suffix", "-s", type=str, default=".beats", help="Suffix for output file names (default: %(default)s). Also see --append. Ignored if an explicit output file name is given.", ) parser.add_argument( "--append", action="store_true", help="If given, append suffix to output file names instead of replacing the existing suffix. 
Ignored if an explicit output file name is given.", ) parser.add_argument( "--skip-existing", action="store_true", help="If given, do not overwrite existing output files, but skip them.", ) parser.add_argument( "--touch-first", action="store_true", help="If given, create empty output file before processing. Combined with --skip-existing, allows to run multiple processes in parallel on the same set of files.", ) parser.add_argument( "--dbn", default=False, action=argparse.BooleanOptionalAction, help="Override the option to use madmom's postprocessing DBN.", ) parser.add_argument( "--gpu", type=int, default=0, help="Which GPU to use (not the number of GPUs), or -1 for CPU. Ignored if CUDA is not available. (default: %(default)s)", ) parser.add_argument( "--float16", action="store_true", help="If given, uses half precision floating point arithmetics. Required for flash attention on GPU. (default: %(default)s)", ) parser.add_argument( "--activations", action="store_true", help="If given, saves the raw activations with a .npy suffix.", ) return parser def derive_output_path(input_path, suffix, append, output=None, parent=None): """ Determine the output file name for `input_path` using the given suffix. If given, `output` is the base directory for outputs, and `parent` is the directory that was given on the command line. 
""" # output directory if output is None: output_path = input_path else: if parent is not None: input_path = input_path.relative_to(parent) else: input_path = input_path.name output_path = output / input_path # suffix if append: return output_path.parent / (output_path.name + suffix) else: return output_path.with_suffix(suffix) def run( inputs, model, output, suffix, append, skip_existing, touch_first, dbn, gpu, float16, activations, ): # determine device if torch.cuda.is_available() and gpu >= 0: device = torch.device(f"cuda:{gpu}") else: device = torch.device("cpu") # prepare model file2file = File2File(model, device, float16, dbn) if activations: def process(audiofile, outfile): wav, sr = load_audio(audiofile) spect = file2file.signal2spect(wav, sr) beat_logits, downbeat_logits = file2file.spect2frames(spect) np.save( outfile.with_suffix(".npy"), np.vstack([beat_logits.cpu().numpy(), downbeat_logits.cpu().numpy()]), ) beats, downbeats = file2file.frames2beats(beat_logits, downbeat_logits) save_beat_tsv(beats, downbeats, outfile) else: process = file2file # process inputs inputs = [Path(item) for item in inputs] if output is not None: output = Path(output) if len(inputs) == 1 and not inputs[0].is_dir(): # special case: single input file if output is None or output.is_dir(): output = derive_output_path(inputs[0], suffix, append, output) process(inputs[0], output) else: # multiple inputs: first collect tasks so we can have a progress bar tasks = [] for item in inputs: if item.is_dir(): for fn in item.rglob("*"): if not fn.name.endswith(suffix) and not fn.is_dir(): output_path = derive_output_path( fn, suffix, append, output, parent=item ) if not skip_existing or not output_path.exists(): tasks.append((fn, output_path)) else: tasks.append((item, derive_output_path(item, suffix, append, output))) # then process all of them if tqdm is not None: tasks = tqdm.tqdm(tasks) for item, output in tasks: if touch_first: try: output.touch(exist_ok=not skip_existing) except 
FileExistsError: continue elif skip_existing and output.exists(): continue try: process(item, output) except Exception: print( f'Could not process "{item}". Rerun with this file alone for details.', file=sys.stderr, ) def main(): run(**vars(get_parser().parse_args())) if __name__ == "__main__": sys.exit(main()) ================================================ FILE: beat_this/dataset/__init__.py ================================================ from beat_this.dataset.dataset import BeatDataModule ================================================ FILE: beat_this/dataset/augment.py ================================================ import numpy as np import torch def augment_pitchtempo(item, augmentations): """ Apply a randomly chosen pitch or tempo augmentation to the item. Parameters: item: dict A dictionary representing the item to be augmented. It should contain the following keys: - 'spect_path': The path to the the unaugmented spectrogram file. If pitch or tempo augmentation is applied, the 'spect_path' key will be updated. augmentations: dict A dictionary containing the augmentations to be applied. It can contain either or both of the following keys: - 'pitch': A dictionary with 'min' and 'max' keys specifying the range of pitch shifting in semitones. - 'tempo': A dictionary with 'min' and 'max' keys specifying the range of time stretching factors. Returns: item: dict The item after applying the augmentation. If a pitch or tempo augmentation was applied, the 'spect_path' key and the annotations will be updated. 
""" # Handle pitch and tempo augmentations if "pitch" in augmentations and "tempo" in augmentations: # if both pitch and tempo are enabled, pick one of them if np.random.randint(2) == 0: # pitch item = augment_pitch(item, augmentations["pitch"]) else: # tempo item = augment_tempo(item, augmentations["tempo"]) elif "pitch" in augmentations: item = augment_pitch(item, augmentations["pitch"]) elif "tempo" in augmentations: item = augment_tempo(item, augmentations["tempo"]) return item def augment_pitch(item, pitch_params): """Apply pitch shifting to the item.""" semitones = np.random.randint(pitch_params["min"], pitch_params["max"] + 1) item = shift_filename(item, semitones) item = shift_annotations(item, semitones) return item def augment_tempo(item, tempo_params): """Apply time stretching to the item.""" percentage = np.random.choice( np.arange(tempo_params["min"], tempo_params["max"] + 1, tempo_params["stride"]) ) item = stretch_filename(item, percentage) item = stretch_annotations(item, percentage) return item def stretch_annotations(item, percentage): """Apply time stretching to the item's annotations.""" if not percentage: return item # percentage is the amount by which the *tempo* changes factor = 1.0 + percentage / 100 item = dict(item) item["beat_time"] = item["beat_time"] / factor return item def shift_annotations(item, semitones): """Apply pitch shifting to the item's annotations.""" return item def stretch_filename(item, percentage): """Derive filename of precomputed time stretched version.""" spect_path = item["spect_path"] if percentage: stem = spect_path.stem + f"_ts{percentage}" spect_path = spect_path.with_stem(stem) return {**item, "spect_path": spect_path} def shift_filename(item, semitones): """Derive filename of precomputed pitch shifted version.""" spect_path = item["spect_path"] if semitones: stem = spect_path.stem + f"_ps{semitones}" spect_path = spect_path.with_stem(stem) return {**item, "spect_path": spect_path} def 
number_of_precomputed_augmentations(augmentations): """Return the number of augmentations that correspond to a precomputed file.""" counter = 1 for method, params in augmentations.values(): if method in ("pitch"): counter += params["max"] - params["min"] elif method in ("tempo"): counter += (params["max"] - params["min"]) // params["stride"] return counter def precomputed_augmentation_filenames(augmentations, ext="npy"): """Return the filenames of the precomputed augmentations. Parameters: augmentations: dict A dictionary containing the augmentations to be applied. It can contain either or both of the following keys: - 'pitch': A dictionary with 'min' and 'max' keys specifying the range (including boundaries) of pitch shifting in semitones. - 'tempo': A dictionary with 'min' and 'max' keys specifying the range (including boundaries) of time stretching factors; and a 'stride' key specifying the step size. """ filenames = [f"track.{ext}"] for method, params in augmentations.items(): if method == "pitch": for semitones in range(params["min"], params["max"] + 1): if semitones == 0: continue filenames.append(f"track_ps{semitones}.{ext}") elif method == "tempo": for percentage in range(params["min"], params["max"] + 1, params["stride"]): if percentage == 0: continue filenames.append(f"track_ts{percentage}.{ext}") return filenames def augment_mask_(spect, augmentations: dict, fps: int): """ Apply the given masking operations to the spectrogram. The spectrogram is modified in place. Parameters: spect: ndarray The input spectrogram to which the mask will be applied. It is a 2D array where the first dimension represents time frames and the second dimension represents frequency bins. augmentations: dict A dictionary containing all the augmentations. If there is no "mask" key, this function returns the unmodified spectrogram. If "mask" key is present, the value is another dictionary which must include the following keys: - 'kind': The type of mask to apply. 
Choices: 'permute' and 'zero'. - 'min_count' and 'max_count': The minimum and maximum number of times the mask should be applied. - 'min_len' and 'max_len': The minimum and maximum length of the mask, expressed in seconds. - 'min_parts' and 'max_parts': The minimum and maximum number of parts in which each masked section is segmented. These are then randomly reordered. If 'kind'='permute' this parameter is not used. fps: int The frames per second of the audio. This is used to convert 'min_len' and 'max_len' from seconds to frames. Returns: spect: ndarray The spectrogram after applying the mask. """ if "mask" in augmentations: mask_params = augmentations["mask"] count = np.random.randint( mask_params["min_count"], mask_params["max_count"] + 1 ) # convert min_len and max_len in frames min_len = int(mask_params["min_len"] * fps) max_len = int(mask_params["max_len"] * fps) # apply the masking a number of time specified by count for _ in range(count): length = np.random.randint(min_len, max_len + 1) start = np.random.randint(0, len(spect) - length) apply_mask_excerpt( spect[start : start + length], mask_params["kind"], mask_params["min_parts"], mask_params["max_parts"], ) return spect def apply_mask_excerpt(excerpt, kind, min_parts, max_parts): """Apply a mask operation of the given kind in-place to the given tensor.""" if kind == "permute": num_parts = np.random.randint(min_parts, max_parts + 1) choices = len(excerpt) num_parts = min(num_parts, choices + 1) positions = np.random.choice(choices, num_parts - 1, replace=False) positions.sort() if isinstance(excerpt, np.ndarray): parts = np.split(excerpt, positions) else: parts = ( [excerpt[: positions[0]]] + [excerpt[a:b] for a, b in zip(positions[:-1], positions[1:])] + [excerpt[positions[-1] :]] ) parts = [parts[idx] for idx in np.random.permutation(num_parts)] if isinstance(excerpt, np.ndarray): excerpt[:] = np.concatenate(parts) else: excerpt[:] = torch.cat(parts) elif kind == "zero": excerpt[:] = 0 else: raise 
ValueError(f"Unsupported mask operation: {kind}") ================================================ FILE: beat_this/dataset/dataset.py ================================================ import concurrent.futures import itertools import json import re from pathlib import Path import numpy as np import pandas as pd import pytorch_lightning as pl import torch from torch.utils.data import DataLoader, Dataset from beat_this.dataset.augment import ( augment_mask_, augment_pitchtempo, precomputed_augmentation_filenames, ) from beat_this.utils import index_to_framewise from .mmnpz import MemmappedNpzFile class BeatTrackingDataset(Dataset): """ A PyTorch Dataset for beat tracking. This dataset loads preprocessed spectrograms and beat annotations from a given data folder and provides them for training or evaluation. Args: item_names (list of str): A list of dataset items such as "gtzan/gtzan_rock_00099". data_folder (Path or str): The base folder where the data is stored. spect_fps (int, optional): The frames per second of the spectrograms. Defaults to 50. train_length (int, optional): The length of the training sequences in frames. If None the entire piece is used. Defaults to 1500. deterministic (bool, optional): If True, the dataset always returns the same sequence for a given index. Defaults to False. augmentations (dict, optional): A dictionary of data augmentations to apply. Possible keys are "tempo", "pitch", and "mask". Defaults to an empty dictionary. 
""" def __init__( self, item_names: list[str], data_folder, spect_fps=50, train_length=1500, deterministic=False, augmentations={}, length_based_oversampling_factor=0, ): self.spect_basepath = data_folder / "audio" / "spectrograms" self.annotation_basepath = data_folder / "annotations" self.fps = spect_fps self.train_length = train_length self.deterministic = deterministic self.augmentations = augmentations self.length_based_oversampling_factor = length_based_oversampling_factor datasets = sorted(set(name.split("/", 1)[0] for name in item_names)) # load dataset info self.dataset_info = self._load_dataset_infos(datasets) # load .npz spectrogram bundles, if any self.spects = self._load_spect_bundles(datasets) # load the annotations in parallel with concurrent.futures.ThreadPoolExecutor() as executor: items = executor.map(self._load_dataset_item, item_names) items = [item for item in items if item is not None] if self.length_based_oversampling_factor and self.train_length is not None: # oversample the dataset according to the audio lengths, so that long pieces are sampled more often oversampled_items = [] for item in items: oversampling_factor = np.round( self.length_based_oversampling_factor * len(self._get_spect(item)) / self.train_length ).astype(int) oversampling_factor = max(oversampling_factor, 1) oversampled_items.extend(itertools.repeat(item, oversampling_factor)) print( f"Training set oversampled from {len(items)} to {len(oversampled_items)} excerpts." 
) items = oversampled_items self.items = items def _load_dataset_infos(self, datasets): dataset_info = {} for dataset in datasets: with open(self.annotation_basepath / dataset / "info.json") as f: dataset_info[dataset] = json.load(f) return dataset_info def _load_spect_bundles(self, datasets): spects = {} for dataset in datasets: npz_file = (self.spect_basepath / dataset).with_suffix(".npz") if npz_file.exists(): spects[dataset] = MemmappedNpzFile(npz_file) return spects def _load_dataset_item(self, item_name): # stop if not all the augmented audio files are there dataset, remainder = item_name.split("/", 1) for aug_filename in precomputed_augmentation_filenames(self.augmentations): if (f"{remainder}/{aug_filename[:-4]}") not in self.spects.get( dataset, () ) and not (self.spect_basepath / item_name / aug_filename).exists(): print( f"Skipping {item_name} because not all necessary spectrograms are there." ) return # load beat and produce a default if beat values are not found dataset, stem = item_name.split("/", 1) annotation_path = ( self.annotation_basepath / dataset / "annotations" / "beats" / (stem + ".beats") ) beat_annotation = np.loadtxt(annotation_path) if beat_annotation.ndim == 2: beat_time = beat_annotation[:, 0] beat_value = beat_annotation[:, 1].astype(int) else: beat_time = beat_annotation beat_value = np.zeros_like(beat_time, dtype=np.int32) # stop if the annotations that are supposed to be there are not there if self.dataset_info[dataset]["has_downbeats"]: if beat_annotation.ndim != 2: print( f"Skipping {item_name} because it has {beat_annotation.ndim} columns but downbeat is supposed to be there." 
) return # create a downbeat mask to handle the case where the downbeat is not annotated downbeat_mask = self.dataset_info[dataset]["has_downbeats"] # take care of different subsections of rwc for the dataset name if dataset == "rwc": dataset = "rwc_" + stem.split("_", 2)[1] return { "spect_path": Path(item_name) / "track.npy", "beat_time": beat_time, "beat_value": beat_value, "downbeat_mask": downbeat_mask, "dataset": dataset, } def _get_spect(self, item): try: dataset, filename = str(item["spect_path"]).split("/", 1) spect = self.spects[dataset][filename[:-4]] except KeyError: spect = np.load(self.spect_basepath / item["spect_path"], mmap_mode="r") return spect def get_frame_count(self, index): """Return number of frames of given item.""" return len(self._get_spect(self.items[index])) def get_beat_count(self, index): """Return number of beats (including downbeats) of given item.""" return len(self.items[index]["beat_time"]) def get_downbeat_count(self, index): """Return number of downbeats of given item.""" return (self.items[index]["beat_value"] == 1).sum() def __len__(self): return len(self.items) def __getitem__(self, index): if isinstance(index, (int, np.int64)): # when index is a single int item = self.items[index] # select a pitch shift and time stretch item = augment_pitchtempo(item, self.augmentations) # load spectrogram spect = self._get_spect(item) # define the excerpt to use original_length = len(spect) if self.train_length is not None: longer = original_length - self.train_length else: longer = 0 if longer > 0: # if the piece is longer than the desired length if self.deterministic: # select the middle of the excerpt start_frame = longer // 2 else: start_frame = np.random.randint(0, longer) end_frame = start_frame + self.train_length else: start_frame = 0 end_frame = original_length # obtain a view of the excerpt spect = spect[start_frame:end_frame] if "mask" in self.augmentations: # copy the spectrogram and apply mask augmentation spect = 
np.copy(spect) spect = augment_mask_(spect, self.augmentations, self.fps) else: # only ensure we have a writeable array (so PyTorch is happy) spect = np.require(spect, requirements="W") # prepare annotations ( framewise_truth_beat, framewise_truth_downbeat, truth_orig_beat, truth_orig_downbeat, ) = prepare_annotations(item, start_frame, end_frame, self.fps) # restructure the item dict with the correct training information item = { "spect": spect, "spect_path": str(item["spect_path"]), "dataset": item["dataset"], "start_frame": start_frame, "truth_beat": framewise_truth_beat, "truth_downbeat": framewise_truth_downbeat, "downbeat_mask": torch.as_tensor(item["downbeat_mask"]), "padding_mask": ( np.ones(self.train_length, dtype=bool) if self.train_length is not None else np.ones(original_length, dtype=bool) ), "truth_orig_beat": truth_orig_beat, "truth_orig_downbeat": truth_orig_downbeat, } # pad all framewise tensors if needed if longer < 0: item["spect"] = np.pad( item["spect"], [(0, -longer), (0, 0)], constant_values=0 ) for k in "truth_beat", "truth_downbeat": item[k] = np.pad(item[k], [(0, -longer)], constant_values=0) item["padding_mask"][longer:] = 0 return item else: # when index is a list of ints return [self[i] for i in index] class BeatDataModule(pl.LightningDataModule): """ A PyTorch Lightning DataModule for beat tracking. This DataModule handles the loading and preprocessing of the BeatTrackingDataset and prepares it for use with a PyTorch Lightning model. It can produce cross-validation or single train/val/test splits. Args: data_dir (Path or str): The parent directory where the data (spectrograms and beat labels) is stored. batch_size (int, optional): The size of the batches to be generated by the DataLoader. Defaults to 8. train_length (int, optional): The length of the subsequences in frames. If None, the entire pieces are returner. Defaults to 1500. num_workers (int, optional): The number of worker processes to use for data loading. Defaults to 20. 
augmentations (dict, optional): A dictionary of data augmentations to apply. Defaults to {"pitch": {"min": -5, "max": 6}, "time": {"min": -20, "max": 20, "stride": 4}}. test_dataset (str, optional): The name of the dataset to use for testing. Defaults to "gtzan". hung_data (bool, optional): If True, only use the datasets from the Hung et al. paper for training; validation is still on all datasets. Defaults to False. no_val (bool, optional): If True, train on all train+val data and do not use a validation set; for compatibility reason, the validation metrics are still computed, but are not meaningful. Defaults to False. spect_fps (int, optional): The frames per second of the spectrograms. Defaults to 50. length_based_oversampling_factor (int, optional): The factor by which to oversample the train dataset based on sequence length. Defaults to 0. fold (int, optional): The fold number for cross-validation. If None, the single split is used. Defaults to None. predict_datasplit (str, optional): The split to use for prediction. Prediction dataset is always full pieces. Defaults to "test". 
""" def __init__( self, data_dir, batch_size=8, train_length=1500, num_workers=20, augmentations={ "pitch": {"min": -5, "max": 6}, "tempo": {"min": -20, "max": 20, "stride": 4}, }, test_dataset="gtzan", hung_data=False, no_val=False, spect_fps=50, length_based_oversampling_factor=0, fold=None, predict_datasplit="test", ): super().__init__() self.save_hyperparameters() self.initialized = {} # remember all arguments self.data_dir = Path(data_dir) self.batch_size = batch_size self.train_length = train_length self.num_workers = num_workers if not set(augmentations.keys()).issubset({"mask", "pitch", "tempo"}): raise ValueError(f"Unsupported augmentations: {augmentations.keys()}") self.augmentations = augmentations self.test_set_name = test_dataset self.hung_data = hung_data self.no_val = no_val self.spect_fps = spect_fps self.length_based_oversampling_factor = length_based_oversampling_factor self.fold = fold self.predict_datasplit = predict_datasplit def setup(self, stage): if self.initialized.get(stage, False): return # set up the paths annotation_dir = self.data_dir / "annotations" # load train/val splits if stage in ("fit", "validate"): self.val_items = [] self.train_items = [] split_file = "8-folds.split" if self.fold is not None else "single.split" for dataset_dir in annotation_dir.iterdir(): if not dataset_dir.is_dir() or not (dataset_dir / split_file).exists(): continue dataset = dataset_dir.name if dataset == self.test_set_name: continue split = pd.read_csv( dataset_dir / split_file, header=None, names=["piece", "part"], sep="\t", ) if self.fold is not None: # CV: use given fold for validation, rest for training self.val_items.extend( f"{dataset}/{stem}" for stem in split.piece[split.part == self.fold] ) self.train_items.extend( f"{dataset}/{stem}" for stem in split.piece[split.part != self.fold] ) else: # single split: marked as val and train self.val_items.extend( f"{dataset}/{stem}" for stem in split.piece[split.part == "val"] ) self.train_items.extend( 
f"{dataset}/{stem}" for stem in split.piece[split.part == "train"] ) if self.no_val: # Train on all available data (excluding the test set). # For compatibility, validation metrics are still computed # on the original validation set now included in training. self.train_items.extend(self.val_items) if self.hung_data: # Use the training datasets from MODELING BEATS AND DOWNBEATS # WITH A TIME-FREQUENCY TRANSFORMER (for comparability, the # validation set stays the same, with all datasets). regexp = re.compile( "^(hainsworth/|ballroom/|hjdb/|beatles/|rwc/rwc_popular|simac/|smc/|harmonix/|).*$" ) self.train_items = [ item for item in self.train_items if regexp.match(item) ] self.val_items.sort() self.train_items.sort() # load validation set if stage in ("fit", "validate"): self.val_dataset = BeatTrackingDataset( self.val_items, deterministic=True, augmentations={}, train_length=self.train_length, data_folder=self.data_dir, spect_fps=self.spect_fps, ) print( "Validation set:", len(self.val_dataset), "items from:", *sorted(set(item.split("/", 1)[0] for item in self.val_items)), ) self.initialized["validate"] = True # load training set if stage == "fit": self.train_dataset = BeatTrackingDataset( self.train_items, deterministic=False, augmentations=self.augmentations, train_length=self.train_length, data_folder=self.data_dir, spect_fps=self.spect_fps, length_based_oversampling_factor=self.length_based_oversampling_factor, ) print( "Training set:", len(self.train_dataset), "items from:", *sorted(set(item.split("/", 1)[0] for item in self.train_items)), ) self.initialized["fit"] = True # load test set if stage == "test": test_annotations_dir = ( annotation_dir / self.test_set_name / "annotations" / "beats" ) self.test_items = sorted( f"{self.test_set_name}/{item.stem}" for item in test_annotations_dir.glob("*.beats") ) self.test_dataset = BeatTrackingDataset( self.test_items, deterministic=True, augmentations={}, train_length=None, data_folder=self.data_dir, 
spect_fps=self.spect_fps, ) print( "Test set:", len(self.test_dataset), "items from:", self.test_set_name ) self.initialized["test"] = True # load prediction set if stage == "predict": if self.predict_datasplit == "test": self.setup("test") # we can directly use the test dataset for predictions self.predict_dataset = self.test_dataset else: if self.predict_datasplit == "train": self.setup("fit") items = self.train_items elif self.predict_datasplit == "val": self.setup("validate") items = self.val_items # for prediction, we want to use full items (train_length=None) self.predict_dataset = BeatTrackingDataset( items, deterministic=True, augmentations={}, train_length=None, data_folder=self.data_dir, spect_fps=self.spect_fps, ) def train_dataloader(self): return DataLoader( self.train_dataset, num_workers=self.num_workers, batch_size=self.batch_size, shuffle=True, drop_last=True, pin_memory=True, ) def val_dataloader(self): # Warning: for performances, this only runs on the middle excerpt of the long pieces # The paper results are computed after training in the predict script return DataLoader( self.val_dataset, batch_size=self.batch_size, num_workers=self.num_workers ) def test_dataloader(self): return DataLoader(self.test_dataset, batch_size=1, num_workers=self.num_workers) def predict_dataloader(self): return DataLoader( self.predict_dataset, batch_size=1, num_workers=self.num_workers ) def get_train_positive_weights(self, widen_target_mask=3): """ Computes the relation of negative targets to positive targets. `widen_target_mask` reduces the number of negative targets by the given factor times the number of positive targets (for ignoring a number of frames around each positive label). For example a `widen_target_mask` of 3 will ignore 7 frames, 3 for each side plus the central. 
""" # find the positive weight for the loss as a ratio between (down)beat and non-(down)beat annotation dataset = self.train_dataset all_frames = all_frames_db = 0 for item in dataset.items: frames = len(dataset._get_spect(item)) all_frames += frames if item["downbeat_mask"]: all_frames_db += frames beat_frames = sum(len(item["beat_value"]) for item in dataset.items) downbeat_frames = sum( (item["beat_value"] == 1).sum() for item in dataset.items if item["downbeat_mask"] ) return { "beat": int( np.round( (all_frames - beat_frames * (widen_target_mask * 2 + 1)) / beat_frames ) ), "downbeat": int( np.round( (all_frames_db - downbeat_frames * (widen_target_mask * 2 + 1)) / downbeat_frames ) ), } def prepare_annotations(item, start_frame, end_frame, fps): truth_bdb_time = item["beat_time"] truth_bdb_value = item["beat_value"] # convert beat time from seconds to frame truth_bdb_frame = (truth_bdb_time * fps).round().astype(int) # form annotations excerpt # filter out the annotations that are earlier than the start and shift left truth_bdb_frame -= start_frame idx = np.searchsorted(truth_bdb_frame, 0) truth_bdb_frame = truth_bdb_frame[idx:] truth_bdb_value = truth_bdb_value[idx:] # filter out the annotations that are later than the end idx = np.searchsorted(truth_bdb_frame, end_frame - start_frame) truth_bdb_frame = truth_bdb_frame[:idx] truth_bdb_value = truth_bdb_value[:idx] # create beat and downbeat separated annotations truth_beat = truth_bdb_frame truth_downbeat = truth_bdb_frame[truth_bdb_value == 1] # transform beat downbeat to frame-wise annotations framewise_truth_beat = index_to_framewise(truth_beat, end_frame - start_frame) framewise_truth_downbeat = index_to_framewise( truth_downbeat, end_frame - start_frame ) # create orig beat, downbeat annotations for unquantized evaluation truth_orig_beat = item["beat_time"] truth_orig_downbeat = truth_bdb_time[ item["beat_value"] == 1 ] # (use the full beat_value) # filter out the annotations that are outside the 
excerpt, and shift them left to the excerpt time truth_orig_beat = truth_orig_beat[ (truth_orig_beat >= start_frame / fps) & (truth_orig_beat < end_frame / fps) ] - (start_frame / fps) truth_orig_downbeat = truth_orig_downbeat[ (truth_orig_downbeat >= start_frame / fps) & (truth_orig_downbeat < end_frame / fps) ] - (start_frame / fps) # convert to strings (trick to collate sequences of different lengths) truth_orig_beat = truth_orig_beat.tobytes() truth_orig_downbeat = truth_orig_downbeat.tobytes() return ( framewise_truth_beat, framewise_truth_downbeat, truth_orig_beat, truth_orig_downbeat, ) ================================================ FILE: beat_this/dataset/mmnpz.py ================================================ """ Support for memory-mapping uncompressed .npz files. """ import struct from collections.abc import Mapping from zipfile import ZipFile import numpy as np class MemmappedNpzFile(Mapping): """ A dictionary-like object with lazy-loading of numpy arrays in the given uncompressed .npz file. Upon construction, creates a memory map of the full .npz file, returning views for the arrays within on request. Attributes ---------- files : list of str List of all uncompressed files in the archive with a ``.npy`` extension (listed without the extension). These are supported as dictionary keys. mmap : np.memmap The memory map of the full .npz file. arrays : dict Preloaded or cached arrays. Parameters ---------- fn : str or Path The zipped archive to open. cache : bool, optional Whether to cache array objects in case they are requested again. preload : bool, optional Whether to precreate all array objects upon opening. Enforces caching. 
""" def __init__(self, fn: str, cache: bool = True, preload: bool = False): with ZipFile(fn, mode="r") as f: self._offsets = { zinfo.filename[:-4]: (zinfo.header_offset, zinfo.file_size) for zinfo in f.infolist() if zinfo.filename.endswith(".npy") and zinfo.compress_type == 0 } self.files = list(self._offsets.keys()) self.mmap = np.memmap(fn, mode="r") self.cache = cache or preload self.preload = preload if self.preload: self.arrays = {name: self.load(name) for name in self.files} else: self.arrays = {} def load(self, name: str): header_offset, file_size = self._offsets[name] # parse lengths of local header file name and extra fields # (ZipInfo is based on the global directory, not local header) fn_len, extra_len = struct.unpack( "<2H", self.mmap[header_offset + 26 : header_offset + 30] ) # compute offset of start and end of data npy_start = header_offset + 30 + fn_len + extra_len npy_end = npy_start + file_size # read NPY header fp = MemoryviewIO(self.mmap) fp.seek(npy_start) version = np.lib.format.read_magic(fp) np.lib.format._check_version(version) shape, fortran, dtype = np.lib.format._read_array_header(fp, version) # produce slice of memmap data_start = fp.tell() return ( self.mmap[data_start:npy_end] .view(dtype=dtype) .reshape(shape, order="F" if fortran else "C") ) def close(self): if hasattr(self, "mmap"): del self.mmap self.arrays = {} def __enter__(self): return self def __exit__(self, exc_type, exc_value, traceback): self.close() def __iter__(self): return iter(self.files) def __len__(self): return len(self.files) def __getitem__(self, key: str): if self.cache: try: return self.arrays[key] except KeyError: pass array = self.load(key) if self.cache: self.arrays[key] = array return array def __contains__(self, key: str): # Mapping.__contains__ calls __getitem__, which could be expensive return key in self._offsets class MemoryviewIO(object): """ Wraps an object supporting the buffer protocol to be a readonly file-like. 
""" def __init__(self, buffer): self._buffer = memoryview(buffer).cast("B") self._pos = 0 self.seekable = lambda: True self.readable = lambda: True self.writable = lambda: False def seek(self, offset, whence=0): if whence == 0: self._pos = offset elif whence == 1: self._pos += offset elif whence == 2: self._pos = self._buffer.nbytes + offset def read(self, size=-1): data = self._buffer[ self._pos : self._pos + size if size >= 0 else None ].tobytes() self._pos += len(data) return data def tell(self): return self._pos ================================================ FILE: beat_this/inference.py ================================================ import inspect import numpy as np import soxr import torch import torch.nn.functional as F from beat_this.model.beat_tracker import BeatThis from beat_this.model.postprocessor import Postprocessor from beat_this.preprocessing import LogMelSpect, load_audio from beat_this.utils import replace_state_dict_key, save_beat_tsv CHECKPOINT_URL = "https://cloud.cp.jku.at/public.php/dav/files/7ik4RrBKTS273gp" def load_checkpoint(checkpoint_path: str, device: str | torch.device = "cpu") -> dict: """ Load a BeatThis checkpoint as a dictionary. Args: checkpoint_path (str, optional): The path to the checkpoint. Can be a local path, a URL, or a shortname. device (torch.device or str): The device to load the model on. Returns: dict: The loaded checkpoint dictionary. 
""" try: # try interpreting as local file name weights_only = {"weights_only": True} if torch.__version__ >= "2" else {} return torch.load(checkpoint_path, map_location=device, **weights_only) except FileNotFoundError: try: if not ( str(checkpoint_path).startswith("https://") or str(checkpoint_path).startswith("http://") ): # interpret it as a name of one of our checkpoints checkpoint_url = f"{CHECKPOINT_URL}/{checkpoint_path}.ckpt" file_name = f"beat_this-{checkpoint_path}.ckpt" else: # try interpreting as a URL checkpoint_url = checkpoint_path file_name = None return torch.hub.load_state_dict_from_url( checkpoint_url, file_name=file_name, map_location=device, ) except Exception: raise ValueError( "Could not load the checkpoint given the provided name", checkpoint_path, ) def load_model( checkpoint_path: str | None = "final0", device: str | torch.device = "cpu" ) -> BeatThis: """ Load a BeatThis model from a checkpoint. Args: checkpoint_path (str, optional): The path to the checkpoint. Can be a local path, a URL, or a shortname. device (torch.device or str): The device to load the model on. Returns: BeatThis: The loaded model. """ if checkpoint_path is not None: checkpoint = load_checkpoint(checkpoint_path, device) # Retrieve the model hyperparameters as it could be the small model hparams = checkpoint["hyper_parameters"] # Filter only those hyperparameters that apply to the model itself hparams = { k: v for k, v in hparams.items() if k in set(inspect.signature(BeatThis).parameters) } # Create the uninitialized model model = BeatThis(**hparams) # The PLBeatThis (LightningModule) state_dict contains the BeatThis # state_dict under the "model." 
prefix; remove the prefix to load it state_dict = replace_state_dict_key(checkpoint["state_dict"], "model.", "") model.load_state_dict(state_dict) else: model = BeatThis() return model.to(device).eval() def zeropad(spect: torch.Tensor, left: int = 0, right: int = 0): """ Pads a tensor spectrogram matrix of shape (time x bins) with `left` frames in the beginning and `right` frames in the end. """ if left == 0 and right == 0: return spect else: return F.pad(spect, (0, 0, left, right), "constant", 0) def split_piece( spect: torch.Tensor, chunk_size: int, border_size: int = 6, avoid_short_end: bool = True, ): """ Split a tensor spectrogram matrix of shape (time x bins) into time chunks of `chunk_size` and return the chunks and starting positions. The `border_size` is the number of frames assumed to be discarded in the predictions on either side (since the model was not trained on the input edges due to the max-pool in the loss). To cater for this, the first and last chunk are padded by `border_size` on the beginning and end, respectively, and consecutive chunks overlap by `border_size`. If `avoid_short_end` is true, the last chunk start is shifted left to ends at the end of the piece, therefore the last chunk can potentially overlap with previous chunks more than border_size, otherwise it will be a shorter segment. If the piece is shorter than `chunk_size`, avoid_short_end is ignored and the piece is returned as a single shorter chunk. Args: spect (torch.Tensor): The input spectrogram tensor of shape (time x bins). chunk_size (int): The size of the chunks to produce. border_size (int, optional): The size of the border to overlap between chunks. Defaults to 6. avoid_short_end (bool, optional): If True, the last chunk is shifted left to end at the end of the piece. Defaults to True. 
""" # generate the start and end indices starts = np.arange( -border_size, len(spect) - border_size, chunk_size - 2 * border_size ) if avoid_short_end and len(spect) > chunk_size - 2 * border_size: # if we avoid short ends, move the last index to the end of the piece - (chunk_size - border_size) starts[-1] = len(spect) - (chunk_size - border_size) # generate the chunks chunks = [ zeropad( spect[max(start, 0) : min(start + chunk_size, len(spect))], left=max(0, -start), right=max(0, min(border_size, start + chunk_size - len(spect))), ) for start in starts ] return chunks, starts def aggregate_prediction( pred_chunks: list, starts: list, full_size: int, chunk_size: int, border_size: int, overlap_mode: str, device: str | torch.device, ) -> tuple[torch.Tensor, torch.Tensor]: """ Aggregates the predictions for the whole piece based on the given prediction chunks. Args: pred_chunks (list): List of prediction chunks, where each chunk is a dictionary containing 'beat' and 'downbeat' predictions. starts (list): List of start positions for each prediction chunk. full_size (int): Size of the full piece. chunk_size (int): Size of each prediction chunk. border_size (int): Size of the border to be discarded from each prediction chunk. overlap_mode (str): Mode for handling overlapping predictions. Can be 'keep_first' or 'keep_last'. device (torch.device): Device to be used for the predictions. Returns: tuple: A tuple containing the aggregated beat predictions and downbeat predictions as torch tensors for the whole piece. 
""" if border_size > 0: # cut the predictions to discard the border pred_chunks = [ { "beat": pchunk["beat"][border_size:-border_size], "downbeat": pchunk["downbeat"][border_size:-border_size], } for pchunk in pred_chunks ] # aggregate the predictions for the whole piece piece_prediction_beat = torch.full((full_size,), -1000.0, device=device) piece_prediction_downbeat = torch.full((full_size,), -1000.0, device=device) if overlap_mode == "keep_first": # process in reverse order, so predictions of earlier excerpts overwrite later ones pred_chunks = reversed(list(pred_chunks)) starts = reversed(list(starts)) for start, pchunk in zip(starts, pred_chunks): piece_prediction_beat[ start + border_size : start + chunk_size - border_size ] = pchunk["beat"] piece_prediction_downbeat[ start + border_size : start + chunk_size - border_size ] = pchunk["downbeat"] return piece_prediction_beat, piece_prediction_downbeat def split_predict_aggregate( spect: torch.Tensor, chunk_size: int, border_size: int, overlap_mode: str, model: torch.nn.Module, ) -> dict: """ Function for pieces that are longer than the training length of the model. Split the input piece into chunks, run the model on them, and aggregate the predictions. The spect is supposed to be a torch tensor of shape (time x bins), i.e., unbatched, and the output is also unbatched. Args: spect (torch.Tensor): the input piece chunk_size (int): the length of the chunks border_size (int): the size of the border that is discarded from the predictions overlap_mode (str): how to handle overlaps between chunks model (torch.nn.Module): the model to run Returns: dict: the model framewise predictions for the hole piece as a dictionary containing 'beat' and 'downbeat' predictions. 
""" # split the piece into chunks chunks, starts = split_piece( spect, chunk_size, border_size=border_size, avoid_short_end=True ) # run the model pred_chunks = [model(chunk.unsqueeze(0)) for chunk in chunks] # remove the extra dimension in beat and downbeat prediction due to batch size 1 pred_chunks = [ {"beat": p["beat"][0], "downbeat": p["downbeat"][0]} for p in pred_chunks ] piece_prediction_beat, piece_prediction_downbeat = aggregate_prediction( pred_chunks, starts, spect.shape[0], chunk_size, border_size, overlap_mode, spect.device, ) # save it to model_prediction return {"beat": piece_prediction_beat, "downbeat": piece_prediction_downbeat} class Spect2Frames: """ Class for extracting framewise beat and downbeat predictions (logits) from a spectrogram. """ def __init__(self, checkpoint_path="final0", device="cpu", float16=False): super().__init__() self.device = torch.device(device) self.float16 = float16 self.model = load_model(checkpoint_path, self.device) def spect2frames(self, spect): with torch.inference_mode(): with torch.autocast(enabled=self.float16, device_type=self.device.type): model_prediction = split_predict_aggregate( spect=spect, chunk_size=1500, overlap_mode="keep_first", border_size=6, model=self.model, ) return model_prediction["beat"].float(), model_prediction["downbeat"].float() def __call__(self, spect): return self.spect2frames(spect) class Audio2Frames(Spect2Frames): """ Class for extracting framewise beat and downbeat predictions (logits) from an audio tensor. 
""" def __init__(self, checkpoint_path="final0", device="cpu", float16=False): super().__init__(checkpoint_path, device, float16) self.spect = LogMelSpect(device=self.device) def signal2spect(self, signal, sr): if signal.ndim == 2: signal = signal.mean(1) elif signal.ndim != 1: raise ValueError(f"Expected 1D or 2D signal, got shape {signal.shape}") if sr != 22050: signal = soxr.resample(signal, in_rate=sr, out_rate=22050) signal = torch.tensor(signal, dtype=torch.float32, device=self.device) return self.spect(signal) def __call__(self, signal, sr): spect = self.signal2spect(signal, sr) return self.spect2frames(spect) class Audio2Beats(Audio2Frames): """ Class for extracting beat and downbeat positions (in seconds) from an audio tensor. Args: checkpoint_path (str): Path to the model checkpoint file. It can be a local path, a URL, or a key from the CHECKPOINT_URL dictionary. Default is "final0", which will load the model trained on all data except GTZAN with seed 0. device (str): Device to use for inference. Default is "cpu". float16 (bool): Whether to use half precision floating point arithmetic. Default is False. dbn (bool): Whether to use the madmom DBN for post-processing. Default is False. 
""" def __init__( self, checkpoint_path="final0", device="cpu", float16=False, dbn=False ): super().__init__(checkpoint_path, device, float16) self.frames2beats = Postprocessor(type="dbn" if dbn else "minimal") def __call__(self, signal, sr): beat_logits, downbeat_logits = super().__call__(signal, sr) return self.frames2beats(beat_logits, downbeat_logits) class File2Beats(Audio2Beats): def __call__(self, audio_path): signal, sr = load_audio(audio_path) return super().__call__(signal, sr) class File2File(File2Beats): def __call__(self, audio_path, output_path): downbeats, beats = super().__call__(audio_path) save_beat_tsv(downbeats, beats, output_path) ================================================ FILE: beat_this/model/__init__.py ================================================ ================================================ FILE: beat_this/model/beat_tracker.py ================================================ """ Model definitions for the Beat This! beat tracker. """ import contextlib from collections import OrderedDict import torch from einops import rearrange from einops.layers.torch import Rearrange from rotary_embedding_torch import RotaryEmbedding from torch import nn from beat_this.model import roformer from beat_this.utils import replace_state_dict_key class BeatThis(nn.Module): """ A neural network model for beat tracking. It is composed of three main components: - a frontend that processes the input spectrogram, - a series of transformer blocks that process the output of the frontend, - a head that produces the final beat and downbeat predictions. Args: spect_dim (int): The dimension of the input spectrogram (default: 128). transformer_dim (int): The dimension of the main transformer blocks (default: 512). ff_mult (int): The multiplier for the feed-forward dimension in the transformer blocks (default: 4). n_layers (int): The number of transformer blocks (default: 6). 
head_dim (int): The dimension of each attention head for the partial transformers in the frontend and the transformer blocks (default: 32).
        stem_dim (int): The out dimension of the stem convolutional layer (default: 32).
        dropout (dict): A dictionary specifying the dropout rates for different parts of the model (default: {"frontend": 0.1, "transformer": 0.2}).
        sum_head (bool): Whether to use a SumHead for the final predictions (default: True) or plain independent projections.
        partial_transformers (bool): Whether to include partial frequency- and time-transformers in the frontend (default: True)
    """

    def __init__(
        self,
        spect_dim: int = 128,
        transformer_dim: int = 512,
        ff_mult: int = 4,
        n_layers: int = 6,
        head_dim: int = 32,
        stem_dim: int = 32,
        # NOTE(review): mutable default; it is only read below (keys "frontend"
        # and "transformer"), never mutated, so sharing it across calls is harmless
        dropout: dict = {"frontend": 0.1, "transformer": 0.2},
        sum_head: bool = True,
        partial_transformers: bool = True,
    ):
        super().__init__()
        # shared rotary embedding for frontend blocks and transformer blocks
        rotary_embed = RotaryEmbedding(head_dim)

        # create the frontend
        # - stem
        stem = self.make_stem(spect_dim, stem_dim)
        spect_dim //= 4  # frequencies were convolved with stride 4
        # - three frontend blocks, each doubling the channels and halving the frequencies
        frontend_blocks = []
        dim = stem_dim
        for _ in range(3):
            frontend_blocks.append(
                self.make_frontend_block(
                    dim,
                    dim * 2,
                    partial_transformers,
                    head_dim,
                    rotary_embed,
                    dropout["frontend"],
                )
            )
            dim *= 2
            spect_dim //= 2  # frequencies were convolved with stride 2
        frontend_blocks = nn.Sequential(*frontend_blocks)
        # - linear projection to transformer dimensionality
        #   (channels and remaining frequency bins are flattened per time step)
        concat = Rearrange("b c f t -> b t (c f)")
        linear = nn.Linear(dim * spect_dim, transformer_dim)
        self.frontend = nn.Sequential(
            OrderedDict(stem=stem, blocks=frontend_blocks, concat=concat, linear=linear)
        )

        # create the transformer blocks
        assert (
            transformer_dim % head_dim == 0
        ), "transformer_dim must be divisible by head_dim"
        n_heads = transformer_dim // head_dim
        self.transformer_blocks = roformer.Transformer(
            dim=transformer_dim,
            depth=n_layers,
            heads=n_heads,
            attn_dropout=dropout["transformer"],
            ff_dropout=dropout["transformer"],
            rotary_embed=rotary_embed,
            ff_mult=ff_mult,
            dim_head=head_dim,
            norm_output=True,
        )

        # create the output heads
        if sum_head:
            self.task_heads = SumHead(transformer_dim)
        else:
            self.task_heads = Head(transformer_dim)

        # init all weights
        self.apply(self._init_weights)

    @staticmethod
    def make_stem(spect_dim: int, stem_dim: int) -> nn.Module:
        """
        Build the stem: move frequencies before time, apply BatchNorm1d over the
        `spect_dim` frequency bins, add a channel axis, then a Conv2d with
        stride 4 along frequency producing `stem_dim` channels, BatchNorm2d and GELU.
        """
        return nn.Sequential(
            OrderedDict(
                rearrange_tf=Rearrange("b t f -> b f t"),
                bn1d=nn.BatchNorm1d(spect_dim),
                add_channel=Rearrange("b f t -> b 1 f t"),
                conv2d=nn.Conv2d(
                    in_channels=1,
                    out_channels=stem_dim,
                    kernel_size=(4, 3),
                    stride=(4, 1),
                    padding=(0, 1),
                    bias=False,
                ),
                bn2d=nn.BatchNorm2d(stem_dim),
                activation=nn.GELU(),
            )
        )

    @staticmethod
    def make_frontend_block(
        in_dim: int,
        out_dim: int,
        partial_transformers: bool = True,
        head_dim: int | None = 32,
        rotary_embed: RotaryEmbedding | None = None,
        dropout: float = 0.1,
    ) -> nn.Module:
        """
        Build one frontend block: an optional PartialFTTransformer (nn.Identity
        when disabled), then a Conv2d with stride 2 along frequency mapping
        `in_dim` to `out_dim` channels, BatchNorm2d and GELU.

        Raises:
            ValueError: if `partial_transformers` is set but `head_dim` or
                `rotary_embed` is None.
        """
        if partial_transformers and (head_dim is None or rotary_embed is None):
            raise ValueError(
                "Must specify head_dim and rotary_embed for using partial_transformers"
            )
        return nn.Sequential(
            OrderedDict(
                partial=(
                    PartialFTTransformer(
                        dim=in_dim,
                        dim_head=head_dim,
                        n_head=in_dim // head_dim,
                        rotary_embed=rotary_embed,
                        dropout=dropout,
                    )
                    if partial_transformers
                    else nn.Identity()
                ),
                # conv block
                conv2d=nn.Conv2d(
                    in_channels=in_dim,
                    out_channels=out_dim,
                    kernel_size=(2, 3),
                    stride=(2, 1),
                    padding=(0, 1),
                    bias=False,
                ),
                # out_channels : 64, 128, 256
                # freqs : 16, 8, 4 (due to the stride=2)
                norm=nn.BatchNorm2d(out_dim),
                activation=nn.GELU(),
            )
        )

    @staticmethod
    def _init_weights(module: nn.Module):
        """
        Initialize weights per module type: Linear/Conv1d get normal(0, 0.02),
        Conv2d gets Kaiming-normal (fan_out, relu), Embedding gets normal(0, 0.02)
        with its padding row zeroed; biases are zeroed.
        """
        if isinstance(module, (nn.Linear, nn.Conv1d)):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
            if module.bias is not None:
                torch.nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Conv2d):
            torch.nn.init.kaiming_normal_(
                module.weight, mode="fan_out", nonlinearity="relu"
            )
            if module.bias is not None:
                torch.nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
            if module.padding_idx is not None:
                with torch.no_grad():
                    module.weight[module.padding_idx].fill_(0)

    def forward(self, x):
        """
        Run frontend, transformer blocks and head.

        `x` is a spectrogram laid out as (batch, time, freq) -- the stem starts
        with Rearrange("b t f -> b f t"). Returns the head's dict with "beat" and
        "downbeat" logits.
        """
        x = self.frontend(x)
        x = self.transformer_blocks(x)
        x = self.task_heads(x)
        return x

    def _load_from_state_dict(self, state_dict, prefix, *args, **kwargs):
        # remove _orig_mod prefixes for compiled models
        state_dict = replace_state_dict_key(state_dict, "_orig_mod.", "")
        super()._load_from_state_dict(state_dict, prefix, *args, **kwargs)

    def state_dict(self, *args, **kwargs):
        state_dict = super().state_dict(*args, **kwargs)
        # remove _orig_mod prefixes for compiled models
        state_dict = replace_state_dict_key(state_dict, "_orig_mod.", "")
        return state_dict


class PartialRoformer(nn.Module):
    """
    Takes a (batch, channels, freqs, time) input, applies self-attention and
    a feed-forward block either only across frequencies or only across time.
    Returns a tensor of the same shape as the input.
    """

    def __init__(
        self,
        dim: int,
        dim_head: int,
        n_head: int,
        direction: str,
        rotary_embed: RotaryEmbedding,
        dropout: float,
    ):
        super().__init__()
        assert dim % dim_head == 0, "dim must be divisible by dim_head"
        assert dim // dim_head == n_head, "n_head must be equal to dim // dim_head"
        # only the first letter of `direction` is used, case-insensitively
        self.direction = direction[0].lower()
        if self.direction not in "ft":
            raise ValueError(f"direction must be F or T, got {direction}")
        self.attn = roformer.Attention(
            dim,
            heads=n_head,
            dim_head=dim_head,
            dropout=dropout,
            rotary_embed=rotary_embed,
        )
        self.ff = roformer.FeedForward(dim, dropout=dropout)

    def forward(self, x):
        b = len(x)
        # fold the other axis into the batch so attention runs along the chosen one
        if self.direction == "f":
            pattern = "(b t) f c"
        elif self.direction == "t":
            pattern = "(b f) t c"
        x = rearrange(x, f"b c f t -> {pattern}")
        x = x + self.attn(x)
        x = x + self.ff(x)
        x = rearrange(x, f"{pattern} -> b c f t", b=b)
        return x


class PartialFTTransformer(nn.Module):
    """
    Takes a (batch, channels, freqs, time) input, applies self-attention and
    a feed-forward block once
across frequencies and once across time.
    Same as applying two PartialRoformer() in sequence, but encapsulated in a single module.
    Returns a tensor of the same shape as the input.
    """

    def __init__(
        self,
        dim: int,
        dim_head: int,
        n_head: int,
        rotary_embed: RotaryEmbedding,
        dropout: float,
    ):
        super().__init__()
        assert dim % dim_head == 0, "dim must be divisible by dim_head"
        assert dim // dim_head == n_head, "n_head must be equal to dim // dim_head"
        # frequency directed partial transformer
        self.attnF = roformer.Attention(
            dim,
            heads=n_head,
            dim_head=dim_head,
            dropout=dropout,
            rotary_embed=rotary_embed,
        )
        self.ffF = roformer.FeedForward(dim, dropout=dropout)
        # time directed partial transformer
        self.attnT = roformer.Attention(
            dim,
            heads=n_head,
            dim_head=dim_head,
            dropout=dropout,
            rotary_embed=rotary_embed,
        )
        self.ffT = roformer.FeedForward(dim, dropout=dropout)

    def forward(self, x):
        b = len(x)
        # frequency directed partial transformer: time is folded into the batch
        x = rearrange(x, "b c f t -> (b t) f c")
        x = x + self.attnF(x)
        x = x + self.ffF(x)
        # time directed partial transformer: frequency is folded into the batch
        x = rearrange(x, "(b t) f c ->(b f) t c", b=b)
        x = x + self.attnT(x)
        x = x + self.ffT(x)
        x = rearrange(x, "(b f) t c -> b c f t", b=b)
        return x


class SumHead(nn.Module):
    """
    A PyTorch module that produces the final beat and downbeat prediction logits.
    The beats are a sum of all beats and all downbeats predictions, to reduce the prediction
    of downbeats which are not beats.
    """

    def __init__(self, input_dim):
        super().__init__()
        self.beat_downbeat_lin = nn.Linear(input_dim, 2)

    def forward(self, x):
        # x: (batch, time, input_dim) -> (batch, time, 2)
        beat_downbeat = self.beat_downbeat_lin(x)
        # separate beat from downbeat
        beat, downbeat = rearrange(beat_downbeat, "b t c -> c b t", c=2)
        # aggregate beats and downbeats prediction
        # autocast to float16 disabled to avoid numerical issues causing NaNs
        if hasattr(
            torch.amp, "is_autocast_available"
        ) and not torch.amp.is_autocast_available(beat.device.type):
            # but do not try disabling if the device does not support autocast
            disable_autocast = contextlib.nullcontext()
        else:
            disable_autocast = torch.autocast(beat.device.type, enabled=False)
        with disable_autocast:
            # the sum is computed in float32
            beat = beat.float() + downbeat.float()
        return {"beat": beat, "downbeat": downbeat}


class Head(nn.Module):
    """
    A PyTorch module that produces the final beat and downbeat prediction logits
    with independent linear layers outputs.
    """

    def __init__(self, input_dim):
        super().__init__()
        self.beat_downbeat_lin = nn.Linear(input_dim, 2)

    def forward(self, x):
        # x: (batch, time, input_dim) -> (batch, time, 2)
        beat_downbeat = self.beat_downbeat_lin(x)
        # separate beat from downbeat
        beat, downbeat = rearrange(beat_downbeat, "b t c -> c b t", c=2)
        return {"beat": beat, "downbeat": downbeat}


================================================
FILE: beat_this/model/loss.py
================================================
"""
Loss definitions for the Beat This! beat tracker.
"""

import torch
import torch.nn.functional as F


class MaskedBCELoss(torch.nn.Module):
    """
    Plain binary cross-entropy loss. Expects predictions to be given as logits,
    and accepts an optional mask with zeros indicating the entries to ignore.
Args: pos_weight (float): Weight for positive examples compared to negative examples (default: 1) """ def __init__(self, pos_weight: float = 1): super().__init__() self.register_buffer( "pos_weight", torch.tensor(pos_weight, dtype=torch.get_default_dtype()), persistent=False, ) def forward( self, preds: torch.Tensor, targets: torch.Tensor, mask: torch.Tensor | None = None, ): return F.binary_cross_entropy_with_logits( preds, targets, weight=mask, pos_weight=self.pos_weight ) class ShiftTolerantBCELoss(torch.nn.Module): """ BCE loss variant for sequence labeling that tolerates small shifts between predictions and targets. This is accomplished by max-pooling the predictions with a given tolerance and a stride of 1, so the gradient for a positive label affects the largest prediction in a window around it. Expects predictions to be given as logits, and accepts an optional mask with zeros indicating the entries to ignore. Note that the edges of the sequence will not receive a gradient, as it is assumed to be unknown whether there is a nearby positive annotation. 
Args: pos_weight (float): Weight for positive examples compared to negative examples (default: 1) tolerance (int): Tolerated shift in time steps in each direction (default: 3) """ def __init__(self, pos_weight: float = 1, tolerance: int = 3): super().__init__() self.register_buffer( "pos_weight", torch.tensor(pos_weight, dtype=torch.get_default_dtype()), persistent=False, ) self.tolerance = tolerance def spread(self, x: torch.Tensor, factor: int = 1): if self.tolerance == 0: return x return F.max_pool1d(x, 1 + 2 * factor * self.tolerance, 1) def crop(self, x: torch.Tensor, factor: int = 1): return x[..., factor * self.tolerance : -factor * self.tolerance or None] def forward( self, preds: torch.Tensor, targets: torch.Tensor, mask: torch.Tensor | None = None, ): # spread preds and crop targets to match spreaded_preds = self.crop(self.spread(preds)) cropped_targets = self.crop(targets, factor=2) # ignore around the positive targets look_at = cropped_targets + (1 - self.spread(targets, factor=2)) if mask is not None: # consider padding and no-downbeat mask look_at = look_at * self.crop(mask, factor=2) # compute loss return F.binary_cross_entropy_with_logits( spreaded_preds, cropped_targets, weight=look_at, pos_weight=self.pos_weight, ) class SplittedShiftTolerantBCELoss(torch.nn.Module): """ Alternative implementation of ShiftTolerantBCELoss that splits the loss for positive and negative targets. This is mainly provided as it may be a bit easier to understand and compare with the Beat This! paper. Note that for non-binary targets (e.g., with label smoothing), this implementation matches the equation in the paper (Section 3.3), while ShiftTolerantBCELoss deviates from it. For binary targets, the results are identical. 
Args: pos_weight (int): weight of positive targets spread_preds (int): amount of temporal max-pooling applied to predictions """ def __init__(self, pos_weight: float = 1, tolerance: int = 3): super().__init__() self.tolerance = 3 self.spread_preds = tolerance self.spread_targets = 2 * tolerance # targets are always spreaded twice as much self.register_buffer( "pos_weight", torch.tensor(pos_weight, dtype=torch.get_default_dtype()), persistent=False, ) def spread(self, x: torch.Tensor, amount: int): if amount: return F.max_pool1d(x, 1 + 2 * amount, 1) else: return x def crop(self, x: torch.Tensor, desired_length: int): amount = (x.shape[-1] - desired_length) // 2 if amount > 0: return x[..., amount:-amount] elif amount == 0: return x else: raise ValueError("Desired length must be smaller than input length") def forward(self, preds: torch.Tensor, targets: torch.Tensor, mask: torch.Tensor): output_length = targets.size(-1) - 2 * self.spread_targets # compute loss for positive targets, we spread preds preds = self.spread(preds, self.spread_preds) # we crop preds and targets (and mask) to ignore problems at the edges due to the maxpool operation cropped_preds = self.crop(preds, output_length) cropped_targets = self.crop(targets, output_length) cropped_mask = self.crop(mask, output_length) loss_positive = F.binary_cross_entropy_with_logits( cropped_preds, cropped_targets, weight=cropped_targets * cropped_mask, pos_weight=self.pos_weight, ) # compute loss for negative targets, we spread targets and preds (already spreaded above) targets = self.spread(targets, self.spread_targets) cropped_targets = self.crop(targets, output_length) loss_negative = F.binary_cross_entropy_with_logits( cropped_preds, cropped_targets, weight=(1 - cropped_targets) * cropped_mask, pos_weight=self.pos_weight, # ensures identical results to the other implementation ) # sum the two losses return loss_positive + loss_negative ================================================ FILE: 
beat_this/model/pl_module.py ================================================ """ Pytorch Lightning module, wraps a BeatThis model along with losses, metrics and optimizers for training. """ from concurrent.futures import ThreadPoolExecutor from typing import Any import mir_eval import numpy as np import torch from pytorch_lightning import LightningModule import beat_this.model.loss from beat_this.inference import split_predict_aggregate from beat_this.model.beat_tracker import BeatThis from beat_this.model.postprocessor import Postprocessor from beat_this.utils import replace_state_dict_key class PLBeatThis(LightningModule): def __init__( self, spect_dim=128, fps=50, transformer_dim=512, ff_mult=4, n_layers=6, stem_dim=32, dropout={"frontend": 0.1, "transformer": 0.2}, lr=0.0008, weight_decay=0.01, pos_weights={"beat": 1, "downbeat": 1}, head_dim=32, loss_type="shift_tolerant_weighted_bce", warmup_steps=1000, max_epochs=100, use_dbn=False, eval_trim_beats=5, sum_head=True, partial_transformers=True, ): super().__init__() self.save_hyperparameters() self.lr = lr self.weight_decay = weight_decay self.fps = fps # create model self.model = BeatThis( spect_dim=spect_dim, transformer_dim=transformer_dim, ff_mult=ff_mult, stem_dim=stem_dim, n_layers=n_layers, head_dim=head_dim, dropout=dropout, sum_head=sum_head, partial_transformers=partial_transformers, ) self.warmup_steps = warmup_steps self.max_epochs = max_epochs # set up the losses self.pos_weights = pos_weights if loss_type == "shift_tolerant_weighted_bce": self.beat_loss = beat_this.model.loss.ShiftTolerantBCELoss( pos_weight=pos_weights["beat"] ) self.downbeat_loss = beat_this.model.loss.ShiftTolerantBCELoss( pos_weight=pos_weights["downbeat"] ) elif loss_type == "weighted_bce": self.beat_loss = beat_this.model.loss.MaskedBCELoss( pos_weight=pos_weights["beat"] ) self.downbeat_loss = beat_this.model.loss.MaskedBCELoss( pos_weight=pos_weights["downbeat"] ) elif loss_type == "bce": self.beat_loss = 
beat_this.model.loss.MaskedBCELoss() self.downbeat_loss = beat_this.model.loss.MaskedBCELoss() elif loss_type == "splitted_shift_tolerant_weighted_bce": self.beat_loss = beat_this.model.loss.SplittedShiftTolerantBCELoss( pos_weight=pos_weights["beat"] ) self.downbeat_loss = beat_this.model.loss.SplittedShiftTolerantBCELoss( pos_weight=pos_weights["downbeat"] ) else: raise ValueError( "loss_type must be one of 'shift_tolerant_weighted_bce', 'weighted_bce', 'bce'" ) self.postprocessor = Postprocessor( type="dbn" if use_dbn else "minimal", fps=fps ) self.eval_trim_beats = eval_trim_beats self.metrics = Metrics(eval_trim_beats=eval_trim_beats) def _compute_loss(self, batch, model_prediction): beat_mask = batch["padding_mask"] beat_loss = self.beat_loss( model_prediction["beat"], batch["truth_beat"].float(), beat_mask ) # downbeat mask considers padding and also pieces which don't have downbeat annotations downbeat_mask = beat_mask * batch["downbeat_mask"][:, None] downbeat_loss = self.downbeat_loss( model_prediction["downbeat"], batch["truth_downbeat"].float(), downbeat_mask ) # sum the losses and return them in a dictionary for logging return { "beat": beat_loss, "downbeat": downbeat_loss, "total": beat_loss + downbeat_loss, } def _compute_metrics(self, batch, postp_beat, postp_downbeat, step="val"): """ """ # compute for beat metrics_beat = self._compute_metrics_target( batch, postp_beat, target="beat", step=step ) # compute for downbeat metrics_downbeat = self._compute_metrics_target( batch, postp_downbeat, target="downbeat", step=step ) # concatenate dictionaries metrics = {**metrics_beat, **metrics_downbeat} return metrics def _compute_metrics_target(self, batch, postp_target, target, step): def compute_item(pospt_pred, truth_orig_target): # take the ground truth from the original version, so there are no quantization errors piece_truth_time = np.frombuffer(truth_orig_target) # run evaluation metrics = self.metrics(piece_truth_time, pospt_pred, step=step) return 
metrics # if the input was not batched, postp_target is an array instead of a tuple of arrays # make it a tuple for consistency if not isinstance(postp_target, tuple): postp_target = (postp_target,) with ThreadPoolExecutor() as executor: piecewise_metrics = list( executor.map( compute_item, postp_target, batch[f"truth_orig_{target}"], ) ) # average the beat metrics across the dictionary batch_metric = { key + f"_{target}": np.mean([x[key] for x in piecewise_metrics]) for key in piecewise_metrics[0].keys() } return batch_metric def log_losses(self, losses, batch_size, step="train"): # log for separate targets for target in "beat", "downbeat": self.log( f"{step}_loss_{target}", losses[target].item(), prog_bar=False, on_step=False, on_epoch=True, batch_size=batch_size, sync_dist=True, ) # log total loss self.log( f"{step}_loss", losses["total"].item(), prog_bar=True, on_step=False, on_epoch=True, batch_size=batch_size, sync_dist=True, ) def log_metrics(self, metrics, batch_size, step="val"): for key, value in metrics.items(): self.log( f"{step}_{key}", value, prog_bar=key.startswith("F-measure"), on_step=False, on_epoch=True, batch_size=batch_size, sync_dist=True, ) def training_step(self, batch, batch_idx): # run the model model_prediction = self.model(batch["spect"]) # compute loss losses = self._compute_loss(batch, model_prediction) self.log_losses(losses, len(batch["spect"]), "train") return losses["total"] def validation_step(self, batch, batch_idx): # run the model model_prediction = self.model(batch["spect"]) # compute loss losses = self._compute_loss(batch, model_prediction) # postprocess the predictions postp_beat, postp_downbeat = self.postprocessor( model_prediction["beat"], model_prediction["downbeat"], batch["padding_mask"], ) # compute the metrics metrics = self._compute_metrics(batch, postp_beat, postp_downbeat, step="val") # log self.log_losses(losses, len(batch["spect"]), "val") self.log_metrics(metrics, batch["spect"].shape[0], "val") def 
test_step(self, batch, batch_idx): metrics, model_prediction, _, _ = self.predict_step(batch, batch_idx) losses = self._compute_loss(batch, model_prediction) # log self.log_losses(losses, len(batch["spect"]), "test") self.log_metrics(metrics, batch["spect"].shape[0], "test") def predict_step( self, batch: Any, batch_idx: int, dataloader_idx: int = 0, chunk_size: int = 1500, overlap_mode: str = "keep_first", ) -> Any: """ Compute predictions and metrics for a batch (a dictionary with an "spect" key). It splits up the audio into multiple chunks of chunk size, which should correspond to the length of the sequence the model was trained with. Potential overlaps between chunks can be handled in two ways: by keeping the predictions of the excerpt coming first (overlap_mode='keep_first'), or by keeping the predictions of the excerpt coming last (overlap_mode='keep_last'). Note that overlaps appear as the last excerpt is moved backwards when it would extend over the end of the piece. """ if batch["spect"].shape[0] != 1: raise ValueError( "When predicting full pieces, only `batch_size=1` is supported" ) if torch.any(~batch["padding_mask"]): raise ValueError( "When predicting full pieces, the Dataset must not pad inputs" ) # compute border size according to the loss type if hasattr( self.beat_loss, "tolerance" ): # discard the edges that are affected by the max-pooling in the loss border_size = 2 * self.beat_loss.tolerance else: border_size = 0 model_prediction = split_predict_aggregate( batch["spect"][0], chunk_size, border_size, overlap_mode, self.model ) # add the batch dimension back in the prediction for consistency model_prediction = { key: value.unsqueeze(0) for key, value in model_prediction.items() } # postprocess the predictions postp_beat, postp_downbeat = self.postprocessor( model_prediction["beat"], model_prediction["downbeat"], None ) # compute the metrics metrics = self._compute_metrics(batch, postp_beat, postp_downbeat, step="test") return metrics, 
model_prediction, batch["dataset"], batch["spect_path"] def configure_optimizers(self): optimizer = torch.optim.AdamW # only decay 2+-dimensional tensors, to exclude biases and norms # (filtering on dimensionality idea taken from Kaparthy's nano-GPT) params = [ { "params": ( p for p in self.parameters() if p.requires_grad and p.ndim >= 2 ), "weight_decay": self.weight_decay, }, { "params": ( p for p in self.parameters() if p.requires_grad and p.ndim <= 1 ), "weight_decay": 0, }, ] optimizer = optimizer(params, lr=self.lr) self.lr_scheduler = CosineWarmupScheduler( optimizer, self.warmup_steps, self.trainer.estimated_stepping_batches ) result = dict(optimizer=optimizer) result["lr_scheduler"] = {"scheduler": self.lr_scheduler, "interval": "step"} return result def _load_from_state_dict(self, state_dict, prefix, *args, **kwargs): # remove _orig_mod prefixes for compiled models state_dict = replace_state_dict_key(state_dict, "_orig_mod.", "") super()._load_from_state_dict(state_dict, prefix, *args, **kwargs) def state_dict(self, *args, **kwargs): state_dict = super().state_dict(*args, **kwargs) # remove _orig_mod prefixes for compiled models state_dict = replace_state_dict_key(state_dict, "_orig_mod.", "") return state_dict class Metrics: def __init__(self, eval_trim_beats: int) -> None: self.min_beat_time = eval_trim_beats def __call__(self, truth, preds, step) -> Any: truth = mir_eval.beat.trim_beats(truth, min_beat_time=self.min_beat_time) preds = mir_eval.beat.trim_beats(preds, min_beat_time=self.min_beat_time) if ( step == "val" ): # limit the metrics that are computed during validation to speed up training fmeasure = mir_eval.beat.f_measure(truth, preds) cemgil = mir_eval.beat.cemgil(truth, preds) return {"F-measure": fmeasure, "Cemgil": cemgil} elif step == "test": # compute all metrics during testing CMLc, CMLt, AMLc, AMLt = mir_eval.beat.continuity(truth, preds) fmeasure = mir_eval.beat.f_measure(truth, preds) cemgil = mir_eval.beat.cemgil(truth, preds) 
return {"F-measure": fmeasure, "Cemgil": cemgil, "CMLt": CMLt, "AMLt": AMLt} else: raise ValueError("step must be either val or test") class CosineWarmupScheduler(torch.optim.lr_scheduler._LRScheduler): """ Cosine annealing over `max_iters` steps with `warmup` linear warmup steps. Optionally re-raises the learning rate for the final `raise_last` fraction of total training time to `raise_to` of the full learning rate, again with a linear warmup (useful for stochastic weight averaging). """ def __init__(self, optimizer, warmup, max_iters, raise_last=0, raise_to=0.5): self.warmup = warmup self.max_num_iters = int((1 - raise_last) * max_iters) self.raise_to = raise_to super().__init__(optimizer) def get_lr(self): lr_factor = self.get_lr_factor(step=self.last_epoch) return [base_lr * lr_factor for base_lr in self.base_lrs] def get_lr_factor(self, step): if step < self.max_num_iters: progress = step / self.max_num_iters lr_factor = 0.5 * (1 + np.cos(np.pi * progress)) if step <= self.warmup: lr_factor *= step / self.warmup else: progress = (step - self.max_num_iters) / self.warmup lr_factor = self.raise_to * min(progress, 1) return lr_factor ================================================ FILE: beat_this/model/postprocessor.py ================================================ from concurrent.futures import ThreadPoolExecutor import numpy as np import torch import torch.nn.functional as F from einops import rearrange class Postprocessor: """Postprocessor for the beat and downbeat predictions of the model. The postprocessor takes the (framewise) model predictions (beat and downbeats) and the padding mask, and returns the postprocessed beat and downbeat as list of times in seconds. The beats and downbeats can be 1D arrays (for only 1 piece) or 2D arrays, if a batch of pieces is considered. The output dimensionality is the same as the input dimensionality. 
Two types of postprocessing are implemented:
        - minimal: a simple postprocessing that takes the maximum of the framewise predictions,
        and removes adjacent peaks.
        - dbn: a postprocessing based on the Dynamic Bayesian Network proposed by Böck et al.

    Args:
        type (str): the type of postprocessing to apply. Either "minimal" or "dbn". Default is "minimal".
        fps (int): the frames per second of the model framewise predictions. Default is 50.
    """

    def __init__(self, type: str = "minimal", fps: int = 50):
        assert type in ["minimal", "dbn"]
        self.type = type
        self.fps = fps
        if type == "dbn":
            # madmom is imported lazily, only when the DBN is requested
            from madmom.features.downbeats import DBNDownBeatTrackingProcessor

            self.dbn = DBNDownBeatTrackingProcessor(
                beats_per_bar=[3, 4],
                min_bpm=55.0,
                max_bpm=215.0,
                fps=self.fps,
                transition_lambda=100,
            )

    def __call__(
        self,
        beat: torch.Tensor,
        downbeat: torch.Tensor,
        padding_mask: torch.Tensor | None = None,
    ) -> tuple[np.ndarray, np.ndarray]:
        """
        Apply postprocessing to the input beat and downbeat tensors. Works with batched and unbatched inputs.
        The output is a list of times in seconds, or a list of lists of times in seconds, if the input is batched.

        Args:
            beat (torch.Tensor): The input beat logits, (time,) or (batch, time).
            downbeat (torch.Tensor): The input downbeat logits, same shape as `beat`.
            padding_mask (torch.Tensor, optional): Boolean mask, True for valid frames.
                Defaults to None, meaning every frame is valid.

        Returns:
            postp_beat: beat times in seconds. For 1D input this is a single array;
                for batched input, a tuple with one array per piece.
            postp_downbeat: downbeat times in seconds, in the same layout.
            (For "minimal" the arrays are np.ndarray; for "dbn" they are slices of
            the madmom DBN output, presumably np.ndarray as well -- confirm.)
        """
        batched = False if beat.ndim == 1 else True
        if padding_mask is None:
            padding_mask = torch.ones_like(beat, dtype=torch.bool)

        # if beat and downbeat are 1D tensors, add a batch dimension
        if not batched:
            beat = beat.unsqueeze(0)
            downbeat = downbeat.unsqueeze(0)
            padding_mask = padding_mask.unsqueeze(0)

        if self.type == "minimal":
            postp_beat, postp_downbeat = self.postp_minimal(
                beat, downbeat, padding_mask
            )
        elif self.type == "dbn":
            postp_beat, postp_downbeat = self.postp_dbn(beat, downbeat, padding_mask)
        else:
            raise ValueError("Invalid postprocessing type")

        # remove the batch dimension if it was added
        if not batched:
            postp_beat = postp_beat[0]
            postp_downbeat = postp_downbeat[0]

        # return the beat and downbeat times
        return postp_beat, postp_downbeat

    def postp_minimal(self, beat, downbeat, padding_mask):
        """Peak-picking postprocessing on (batch, time) logits; returns tuples of per-piece time arrays."""
        # concatenate beat and downbeat in the same tensor of shape (B, T, 2)
        packed_pred = rearrange(
            [beat, downbeat], "c b t -> b t c", b=beat.shape[0], t=beat.shape[1], c=2
        )
        # set padded elements to -1000 (= probability zero even in float64) so they don't influence the maxpool
        pred_logits = packed_pred.masked_fill(~padding_mask.unsqueeze(-1), -1000)
        # reshape to (2*B, T) to apply max pooling
        pred_logits = rearrange(pred_logits, "b t c -> (c b) t")
        # pick local maxima within a 7-frame window (+/- 3 frames around each frame)
        pred_peaks = pred_logits.masked_fill(
            pred_logits != F.max_pool1d(pred_logits, 7, 1, 3), -1000
        )
        # keep maxima with over 0.5 probability (logit > 0)
        pred_peaks = pred_peaks > 0
        #  rearrange back to two tensors of shape (B, T)
        beat_peaks, downbeat_peaks = rearrange(
            pred_peaks, "(c b) t -> c b t", b=beat.shape[0], t=beat.shape[1], c=2
        )
        # run the piecewise operations
        with ThreadPoolExecutor() as executor:
            postp_beat, postp_downbeat = zip(
                *executor.map(
                    self._postp_minimal_item, beat_peaks, downbeat_peaks, padding_mask
                )
            )
        return postp_beat, postp_downbeat

    def _postp_minimal_item(self, padded_beat_peaks, padded_downbeat_peaks, mask):
        """Function to compute the operations that must be computed piece by
piece, and cannot be done in batch.""" # unpad the predictions by truncating the padding positions beat_peaks = padded_beat_peaks[mask] downbeat_peaks = padded_downbeat_peaks[mask] # pass from a boolean array to a list of times in frames. beat_frame = torch.nonzero(beat_peaks).cpu().numpy()[:, 0] downbeat_frame = torch.nonzero(downbeat_peaks).cpu().numpy()[:, 0] # remove adjacent peaks beat_frame = deduplicate_peaks(beat_frame, width=1) downbeat_frame = deduplicate_peaks(downbeat_frame, width=1) # convert from frame to seconds beat_time = beat_frame / self.fps downbeat_time = downbeat_frame / self.fps # move the downbeat to the nearest beat if ( len(beat_time) > 0 ): # skip if there are no beats, like in the first training steps for i, d_time in enumerate(downbeat_time): beat_idx = np.argmin(np.abs(beat_time - d_time)) downbeat_time[i] = beat_time[beat_idx] # remove duplicate downbeat times (if some db were moved to the same position) downbeat_time = np.unique(downbeat_time) return beat_time, downbeat_time def postp_dbn(self, beat, downbeat, padding_mask): beat_prob = beat.double().sigmoid() downbeat_prob = downbeat.double().sigmoid() # limit lower and upper bound, since 0 and 1 create problems in the DBN epsilon = 1e-5 beat_prob = beat_prob * (1 - epsilon) + epsilon / 2 downbeat_prob = downbeat_prob * (1 - epsilon) + epsilon / 2 with ThreadPoolExecutor() as executor: postp_beat, postp_downbeat = zip( *executor.map( self._postp_dbn_item, beat_prob, downbeat_prob, padding_mask ) ) return postp_beat, postp_downbeat def _postp_dbn_item(self, padded_beat_prob, padded_downbeat_prob, mask): """Function to compute the operations that must be computed piece by piece, and cannot be done in batch.""" # unpad the predictions by truncating the padding positions beat_prob = padded_beat_prob[mask] downbeat_prob = padded_downbeat_prob[mask] # build an artificial multiclass prediction, as suggested by Böck et al. 
# again we limit the lower bound to avoid problems with the DBN epsilon = 1e-5 combined_act = np.vstack( ( np.maximum( beat_prob.cpu().numpy() - downbeat_prob.cpu().numpy(), epsilon / 2 ), downbeat_prob.cpu().numpy(), ) ).T # run the DBN dbn_out = self.dbn(combined_act) postp_beat = dbn_out[:, 0] postp_downbeat = dbn_out[dbn_out[:, 1] == 1][:, 0] return postp_beat, postp_downbeat def deduplicate_peaks(peaks, width=1) -> np.ndarray: """ Replaces groups of adjacent peak frame indices that are each not more than `width` frames apart by the average of the frame indices. """ result = [] peaks = map(int, peaks) # ensure we get ordinary Python int objects try: p = next(peaks) except StopIteration: return np.array(result) c = 1 for p2 in peaks: if p2 - p <= width: c += 1 p += (p2 - p) / c # update mean else: result.append(p) p = p2 c = 1 result.append(p) return np.array(result) ================================================ FILE: beat_this/model/roformer.py ================================================ """ Transformer with rotary position embedding, adapted from Phil Wang's repository at https://github.com/lucidrains/BS-RoFormer (under MIT License). 
"""

import torch
import torch.nn.functional as F
from einops import rearrange
from torch import nn
from torch.nn import Module, ModuleList

# NOTE: submodule attribute names and nn.Sequential / ModuleList positions in
# this file determine the state_dict keys of saved checkpoints; renaming or
# reordering them breaks loading of existing weights.

# helper functions


def exists(val):
    """Return True if `val` is not None."""
    return val is not None


# norm


class RMSNorm(Module):
    """RMS normalization along a negative dimension `dim`.

    The input is L2-normalized along `dim`, then multiplied by sqrt(size) and
    by a learnable gain `gamma` initialized to ones. Up to the epsilon used
    by F.normalize, this divides the input by its root mean square over `dim`.
    """

    def __init__(self, size, dim=-1):
        super().__init__()
        self.scale = size**0.5
        # only negative dims are accepted, so gamma's shape can be built to
        # broadcast from the right regardless of how many leading axes x has
        if dim >= 0:
            raise ValueError(f"dim must be negative, got {dim}")
        # shape (size, 1, ..., 1) with abs(dim) - 1 trailing singleton axes
        self.gamma = nn.Parameter(torch.ones((size,) + (1,) * (abs(dim) - 1)))
        self.dim = dim

    def forward(self, x):
        return F.normalize(x, dim=self.dim) * self.scale * self.gamma


# feedforward


class FeedForward(Module):
    """Pre-norm MLP block.

    The layers are RMSNorm, Linear(dim, dim * mult), GELU, Dropout,
    Linear(dim * mult, dim_out) and Dropout. `dim_out` defaults to `dim`.
    """

    def __init__(
        self,
        dim,
        mult=4,
        dropout=0.0,
        dim_out=None,
    ):
        super().__init__()
        if dim_out is None:
            dim_out = dim
        dim_inner = int(dim * mult)
        self.activation = nn.GELU()
        self.net = nn.Sequential(
            RMSNorm(dim),
            nn.Linear(dim, dim_inner),
            self.activation,
            nn.Dropout(dropout),
            nn.Linear(dim_inner, dim_out),
            nn.Dropout(dropout),
        )

    def forward(self, x):
        return self.net(x)


# attention


class Attend(nn.Module):
    """Thin wrapper around `F.scaled_dot_product_attention`.

    If `scale` is given, queries are pre-multiplied so that the effective
    attention scale is `scale` rather than SDPA's default 1 / sqrt(head_dim).
    Attention dropout is applied only in training mode.
    """

    def __init__(self, dropout=0.0, scale=None):
        super().__init__()
        self.dropout = dropout
        self.scale = scale

    def forward(self, q, k, v):
        if exists(self.scale):
            # SDPA scales by q.shape[-1] ** -0.5 by default; rescale q so the
            # product equals self.scale instead
            default_scale = q.shape[-1] ** -0.5
            q = q * (self.scale / default_scale)

        return F.scaled_dot_product_attention(
            q, k, v, dropout_p=self.dropout if self.training else 0.0
        )


class Attention(Module):
    """Multi-head self-attention block.

    Applies pre-norm, optionally rotates queries and keys with a rotary
    embedding, and optionally gates each head with a sigmoid. Input is
    (batch, seq, dim) and output has the same shape.
    """

    def __init__(
        self,
        dim,
        heads=8,
        dim_head=64,
        dropout=0.0,
        rotary_embed=None,
        gating=True,
    ):
        super().__init__()
        self.heads = heads
        # NOTE(review): not passed to Attend below, so attention uses the SDPA
        # default scale; this attribute is not read within this class
        self.scale = dim_head**-0.5
        dim_inner = heads * dim_head

        self.rotary_embed = rotary_embed

        self.attend = Attend(dropout=dropout)

        self.norm = RMSNorm(dim)
        # one projection producing queries, keys and values for all heads
        self.to_qkv = nn.Linear(dim, dim_inner * 3, bias=False)

        if gating:
            # one gate logit per head (and position)
            self.to_gates = nn.Linear(dim, heads)
        else:
            self.to_gates = None

        self.to_out = nn.Sequential(
            nn.Linear(dim_inner, dim, bias=False), nn.Dropout(dropout)
        )

    def forward(self, x):
        x = self.norm(x)

        # split into q, k, v, each of shape (batch, heads, seq, dim_head)
        q, k, v = rearrange(
            self.to_qkv(x), "b n (qkv h d) -> qkv b h n d", qkv=3, h=self.heads
        )

        if exists(self.rotary_embed):
            q = self.rotary_embed.rotate_queries_or_keys(q)
            k = self.rotary_embed.rotate_queries_or_keys(k)

        out = self.attend(q, k, v)

        if exists(self.to_gates):
            # gates come from the normalized input and are squashed to (0, 1)
            gates = self.to_gates(x)
            out = out * rearrange(gates, "b n h -> b h n 1").sigmoid()

        # merge heads: (batch, heads, seq, dim_head) -> (batch, seq, heads * dim_head)
        out = rearrange(out, "b h n d -> b n (h d)")
        return self.to_out(out)


# Roformer


class Transformer(Module):
    """Stack of `depth` residual blocks with an optional final RMSNorm.

    Each block applies Attention and then FeedForward; both normalize their
    input internally. The same `rotary_embed` object is shared by all layers.
    """

    def __init__(
        self,
        *,
        dim,
        depth,
        dim_head=32,
        heads=16,
        attn_dropout=0.1,
        ff_dropout=0.1,
        ff_mult=4,
        norm_output=True,
        rotary_embed=None,
        gating=True,
    ):
        super().__init__()
        self.layers = ModuleList([])

        for _ in range(depth):
            ff = FeedForward(dim=dim, mult=ff_mult, dropout=ff_dropout)

            self.layers.append(
                ModuleList(
                    [
                        Attention(
                            dim=dim,
                            dim_head=dim_head,
                            heads=heads,
                            dropout=attn_dropout,
                            rotary_embed=rotary_embed,
                            gating=gating,
                        ),
                        ff,
                    ]
                )
            )

        self.norm = RMSNorm(dim) if norm_output else nn.Identity()

    def forward(self, x):
        for attn, ff in self.layers:
            x = attn(x) + x
            x = ff(x) + x

        x = self.norm(x)
        return x


================================================
FILE: beat_this/preprocessing.py
================================================
import numpy as np
import torch
import torchaudio


def load_audio(path, dtype="float64"):
    """Load an audio file and return ``(waveform, samplerate)``.

    torchaudio is tried first, then soundfile, then madmom; the first backend
    that succeeds wins. With torchaudio, the waveform is loaded as
    (time, channels) and size-1 axes are squeezed, so mono files give a 1D
    array. The soundfile and madmom results are returned as those libraries
    produce them.

    Raises:
        RuntimeError: if none of the backends can read the file.
    """
    try:
        waveform, samplerate = torchaudio.load(path, channels_first=False)
        waveform = np.asanyarray(waveform.squeeze().numpy(), dtype=dtype)
        return waveform, samplerate
    except Exception:
        # in case torchaudio fails, try soundfile
        try:
            import soundfile as sf

            return sf.read(path, dtype=dtype)
        except Exception:
            # some files are not readable by soundfile, try madmom
            try:
                import madmom

                return madmom.io.load_audio_file(str(path), dtype=dtype)
            except Exception:
                raise RuntimeError(f'Could not load audio from "{path}".')


class LogMelSpect(torch.nn.Module):
    """Log-compressed mel spectrogram, computed as log(1 + log_multiplier * mel).

    The mel transform is a torchaudio MelSpectrogram built from the constructor
    arguments and moved to `device`.
    """

    def __init__(
        self,
        sample_rate=22050,
        n_fft=1024,
        hop_length=441,
        f_min=30,
        f_max=11000,
        n_mels=128,
        mel_scale="slaney",
        normalized="frame_length",
        power=1,
        log_multiplier=1000,
        device="cpu",
    ):
        super().__init__()
        self.spect_class = torchaudio.transforms.MelSpectrogram(
            sample_rate=sample_rate,
            n_fft=n_fft,
            hop_length=hop_length,
            f_min=f_min,
            f_max=f_max,
            n_mels=n_mels,
            mel_scale=mel_scale,
            normalized=normalized,
            power=power,
        ).to(device)
        self.log_multiplier = log_multiplier

    def forward(self, x):
        """Input is a waveform as a monodimensional array of shape T. Output is
        a 2D log mel spectrogram of shape (F, n_mels), where F is the number of
        frames and n_mels defaults to 128."""
        # MelSpectrogram yields (n_mels, frames); transpose to frames-first
        return torch.log1p(self.log_multiplier * self.spect_class(x).T)


================================================
FILE: beat_this/utils.py
================================================
from itertools import chain
from pathlib import Path

import numpy as np


def index_to_framewise(index, length):
    """Convert an index to a framewise sequence.

    Returns a boolean array of `length` frames. It is True exactly at the
    position(s) selected by `index`, which may be anything numpy accepts as
    an index.
    """
    sequence = np.zeros(length, dtype=bool)
    sequence[index] = True
    return sequence


def filename_to_augmentation(filename):
    """Convert a filename to an augmentation factor.

    The file stem is split on "_" and every part after the first is
    inspected. "ps<int>" sets the "shift" entry and "ts<int>" sets the
    "stretch" entry of the returned dict. For example, "track_ps-2_ts10.wav"
    gives {"shift": -2, "stretch": 10}, and "track.wav" gives {}.
    """
    parts = Path(filename).stem.split("_")
    augmentations = {}
    for part in parts[1:]:
        if part.startswith("ps"):
            augmentations["shift"] = int(part[2:])
        elif part.startswith("ts"):
            augmentations["stretch"] = int(part[2:])
    return augmentations


def infer_beat_numbers(beats: np.ndarray, downbeats: np.ndarray) -> np.ndarray:
    """
    From beat and downbeat times, infer a number for each beat such that each
    downbeat is associated with a 1 and beats in between are counted upwards.
    The function requires that all downbeats are also listed as beats.

    Args:
        beats (numpy.ndarray): Array of beat positions in seconds (including downbeats).
        downbeats (numpy.ndarray): Array of downbeat positions in seconds.

    Returns:
        numbers (numpy.ndarray): Array of integer beat numbers.
""" # check if all downbeats are beats if not np.all(np.isin(downbeats, beats)): raise ValueError("Not all downbeats are beats.") # handle pickup measure, by considering the beat count of the first full measure if len(downbeats) >= 2: # find the number of beats between the first two downbeats first_downbeat, second_downbeat = np.searchsorted(beats, downbeats[:2]) beats_in_first_measure = second_downbeat - first_downbeat # find the number of beats before the first downbeat pickup_beats = first_downbeat # derive where to start counting if pickup_beats < beats_in_first_measure: start_counter = beats_in_first_measure - pickup_beats else: print( "WARNING: There are more beats in the pickup measure than in the first measure. The beat count will start from 2 without trying to estimate the length of the pickup measure." ) start_counter = 1 else: print( "WARNING: There are less than two downbeats in the predictions. Something may be wrong. The beat count will start from 2 without trying to estimate the length of the pickup measure." ) start_counter = 1 # assemble the beat numbers numbers = [] counter = start_counter downbeats = chain(downbeats, [-1]) next_downbeat = next(downbeats) for beat in beats: if beat == next_downbeat: counter = 1 next_downbeat = next(downbeats) else: counter += 1 numbers.append(counter) return np.asarray(numbers) def save_beat_tsv(beats: np.ndarray, downbeats: np.ndarray, outpath: str) -> None: """ Save beat information to a tab-separated file in the standard .beats format: each line has a time in seconds, a tab, and a beat number (1 = downbeat). The function requires that all downbeats are also listed as beats. Args: beats (numpy.ndarray): Array of beat positions in seconds (including downbeats). downbeats (numpy.ndarray): Array of downbeat positions in seconds. outpath (str): Path to the output TSV file. 
Returns: None """ # infer beat numbers numbers = infer_beat_numbers(beats, downbeats) # write the beat file Path(outpath).parent.mkdir(parents=True, exist_ok=True) try: with open(outpath, "w") as f: f.writelines(f"{beat}\t{number}\n" for beat, number in zip(beats, numbers)) except KeyboardInterrupt: outpath.unlink() # avoid half-written files def replace_state_dict_key(state_dict: dict, old: str, new: str): """Replaces `old` in all keys of `state_dict` with `new`.""" keys = list(state_dict.keys()) # take snapshot of the keys for key in keys: if old in key: state_dict[key.replace(old, new)] = state_dict.pop(key) return state_dict ================================================ FILE: beat_this_example.ipynb ================================================ { "nbformat": 4, "nbformat_minor": 0, "metadata": { "colab": { "provenance": [], "authorship_tag": "ABX9TyOW4OkTmphTrvw2IQLr+kxP", "include_colab_link": true }, "kernelspec": { "name": "python3", "display_name": "Python 3" }, "language_info": { "name": "python" } }, "cells": [ { "cell_type": "markdown", "metadata": { "id": "view-in-github", "colab_type": "text" }, "source": [ "\"Open" ] }, { "cell_type": "markdown", "source": [ "# Beat This! inference example\n", "\n", "We first need to install and load the package." 
], "metadata": { "id": "87X_GXfoGwmj" } }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "sxhsMCKdLOLO", "collapsed": true }, "outputs": [], "source": [ "# install the beat_this package\n", "!pip install https://github.com/CPJKU/beat_this/archive/main.zip\n", "# on Google Colab, this one is faster:\n", "#!pip install --no-deps rotary-embedding-torch https://github.com/CPJKU/beat_this/archive/main.zip\n", "\n", "# load the Python class for beat tracking\n", "from beat_this.inference import File2Beats\n", "from beat_this.inference import File2File" ] }, { "cell_type": "markdown", "source": [ "## Run on demo file\n", "\n", "Now that all the dependencies have been installed and imported, let's run our system.\n", "\n", "In the next cell we:\n", "- define the audio file we want to use as input. For now we use the example provided in the beat_this repo, but this can be changed (see instructions later);\n", "- load the File2Beats class that produce a list of beats and downbeats given an audio file;\n", "- apply the class to the audio file\n", "- print the position in seconds of the first 20 beats and first 20 downbeats.\n" ], "metadata": { "id": "_0oYbH6P6Ji7" } }, { "cell_type": "code", "source": [ "!wget -c \"https://github.com/CPJKU/beat_this/raw/main/tests/It%20Don't%20Mean%20A%20Thing%20-%20Kings%20of%20Swing.mp3\"\n", "audio_path = \"/content/It Don't Mean A Thing - Kings of Swing.mp3\"\n", "\n", "file2beats = File2Beats(checkpoint_path=\"final0\", dbn=False)\n", "beats, downbeats = file2beats(audio_path)\n", "\n", "print(\"First 20 beats\", beats[:20])\n", "print(\"First 20 downbeats\", downbeats[:20])" ], "metadata": { "id": "DHT6v-a-TbZx" }, "execution_count": null, "outputs": [] }, { "cell_type": "markdown", "source": [ "We can sonify the beats and downbeats as click on top of the audio file." 
], "metadata": { "id": "lRjJFiexDGdn" } }, { "cell_type": "code", "source": [ "import IPython.display as ipd\n", "import librosa\n", "import numpy as np\n", "import soundfile as sf\n", "\n", "audio, sr = sf.read(audio_path)\n", "# make it mono if stereo\n", "if len(audio.shape) > 1:\n", " audio = np.mean(audio, axis=1)\n", "\n", "# sonify the beats and downbeats\n", "# remove the beats that are also downbeats for a nicer sonification\n", "beats = [b for b in beats if b not in downbeats]\n", "audio_beat = librosa.clicks(times = beats, sr=sr, click_freq=1000, length=len(audio))\n", "audio_downbeat = librosa.clicks(times = downbeats, sr=sr, click_freq=1500, length=len(audio))\n", "\n", "ipd.display(ipd.Audio(audio + audio_beat + audio_downbeat, rate=sr))" ], "metadata": { "id": "otG0NS_uCXSo" }, "execution_count": null, "outputs": [] }, { "cell_type": "markdown", "source": [ "## Run on your own file\n", "\n", "If you want to run on your own audio files follow the following instructions:\n", "1. Click on the folder icon in the left vertical menu.\n", "2. Click on the \"Upload to session storage\" icon with the upward pointing arrow.\n", "\n", " This will add an audio file to the current colab runtime (it could take some time, and you may need to refresh the file manager using the dedicated button to see the new file). You can copy the audio path by clicking on the three dots next to the file, then \"copy path\".\n", "\n", " For example, if you upload a file called `my_song.mp3`, the path will be `/content/my_song.mp3`.\n", "\n", "3. 
Replace the `audio_path` in the cell above with the path of your uploaded audio file" ], "metadata": { "id": "hn83Sn1pWmt5" } }, { "cell_type": "markdown", "source": [ "You can also produce a list of beats and downbeats as a TSV file that you can download and import into Sonic Visualiser.\n", "\n", "To do this, use the File2File class as below:" ], "metadata": { "id": "kP2gyplIEcWT" } }, { "cell_type": "code", "source": [ "file2file = File2File(checkpoint_path=\"final0\", dbn=False)\n", "file2file(audio_path,output_path=\"output.beats\")" ], "metadata": { "id": "kTQK-d4JEbL7" }, "execution_count": null, "outputs": [] }, { "cell_type": "markdown", "source": [ "As you can see, the system is fast enough to work in a reasonable time even on CPU.\n", "\n", "For even faster inference, you can start a GPU session in Colab!" ], "metadata": { "id": "1Y1d-DvXFtVz" } }, { "cell_type": "markdown", "source": [ "## Batch processing multiple files\n", "\n", "To process several of your own audio files, upload them as described above, then run the `beat_this` command line tool:" ], "metadata": { "id": "vpoM0RvQdAMF" } }, { "cell_type": "code", "source": [ "!beat_this --model final0 /content/" ], "metadata": { "id": "qNOLbBplc_Nq" }, "execution_count": null, "outputs": [] }, { "cell_type": "markdown", "source": [ "It will produce a `.beats` file for every audio file, which you can download again."
], "metadata": { "id": "_xNY_9DEdSEt" } } ] } ================================================ FILE: hubconf.py ================================================ dependencies = [ "torch", "torchaudio", "numpy", "rotary_embedding_torch", "einops", "soxr", ] from beat_this.inference import ( load_model as beat_this, BeatThis, Spect2Frames, Audio2Frames, Audio2Beats, File2Beats, File2File, ) ================================================ FILE: launch_scripts/clean_checkpoints.py ================================================ import argparse from pathlib import Path import torch def main(args): # check if output path exists if Path(args.output_path).exists(): print(f"Output path {args.output_path} already exists. Exiting.") return # load the lightning checkpoit checkpoint = torch.load(args.input_path, map_location="cpu") # clean and keep only the keys "state_dict" and "datamodule" to save space checkpoint = { k: v for k, v in checkpoint.items() if k in [ "state_dict", "datamodule_hyper_parameters", "hyper_parameters", "pytorch-lightning_version", ] } # remove the "data_dir" key from "datamodule_hyper_parameters" because it is a # Posix path and creates problems when loading in Windows. 
if "data_dir" in checkpoint["datamodule_hyper_parameters"]: del checkpoint["datamodule_hyper_parameters"]["data_dir"] # save the cleaned checkpoint torch.save(checkpoint, args.output_path) if __name__ == "__main__": parser = argparse.ArgumentParser() parser.add_argument("--input-path", type=str, required=True) parser.add_argument("--output-path", type=str, required=True) args = parser.parse_args() main(args) ================================================ FILE: launch_scripts/compute_paper_metrics.py ================================================ #!/usr/bin/env python3 import argparse from pathlib import Path import numpy as np from pytorch_lightning import Trainer, seed_everything from beat_this.dataset import BeatDataModule from beat_this.inference import load_checkpoint from beat_this.model.pl_module import PLBeatThis from beat_this.utils import infer_beat_numbers # for repeatability seed_everything(0, workers=True) def main(args): if len(args.models) == 1: print("Single model prediction for", args.models[0]) # single model prediction checkpoint_path = args.models[0] checkpoint = load_checkpoint(checkpoint_path) # create datamodule datamodule = datamodule_setup(checkpoint, args.num_workers, args.datasplit) # create model and trainer model, trainer = plmodel_setup( checkpoint, args.eval_trim_beats, args.dbn, args.gpu ) # predict metrics, dataset, preds, piece = compute_predictions( model, trainer, datamodule.predict_dataloader(), return_preds=args.dump_predictions, ) # compute averaged metrics averaged_metrics = {k: np.mean(v) for k, v in metrics.items()} # compute metrics averaged by dataset dataset_metrics = { k: {d: np.mean(v[dataset == d]) for d in np.unique(dataset)} for k, v in metrics.items() } # print for dataset print("Metrics") for k, v in averaged_metrics.items(): print(f"{k}: {v}") print("Dataset metrics") for k, v in dataset_metrics.items(): print(k) for d, value in v.items(): print(f"{d}: {value}") print("------") # dump predictions if 
args.dump_predictions: write_predictions(args.dump_predictions, preds, piece) else: # multiple models if args.aggregation_type == "mean-std": if args.dump_predictions: print( "cannot dump predictions when doing inference for multiple models" ) return # computing result variability for the same dataset and different model seeds # create datamodule only once, as we assume it is the same for all models checkpoint = load_checkpoint(args.models[0]) datamodule = datamodule_setup(checkpoint, args.num_workers, args.datasplit) # create model and trainer all_metrics = [] for checkpoint_path in args.models: checkpoint = load_checkpoint(checkpoint_path) model, trainer = plmodel_setup( checkpoint, args.eval_trim_beats, args.dbn, args.gpu ) metrics, dataset, preds, piece = compute_predictions( model, trainer, datamodule.predict_dataloader() ) # compute averaged metrics for one model averaged_metrics = {k: np.mean(v) for k, v in metrics.items()} all_metrics.append(averaged_metrics) # compute mean and standard deviations for all model averages all_metrics_mean = { k: np.mean([m[k] for m in all_metrics]) for k in all_metrics[0] } all_metrics_std = { k: np.std([m[k] for m in all_metrics]) for k in all_metrics[0] } all_metrics_stats = { k: (all_metrics_mean[k], all_metrics_std[k]) for k, v in all_metrics[0].items() } # print all metrics print("Metrics") for k, v in all_metrics_stats.items(): # round to 3 decimal places print(f"{k}: {round(v[0],3)} +- {round(v[1],3)}") elif args.aggregation_type == "k-fold": # computing results in the K-fold setting. 
Every fold has a different dataset all_piece_metrics = [] all_piece_dataset = [] all_piece_preds = [] all_piece = [] # create datamodule for each model for i_model, checkpoint_path in enumerate(args.models): print(f"Model {i_model+1}/{len(args.models)}") checkpoint = load_checkpoint(checkpoint_path) datamodule = datamodule_setup( checkpoint, args.num_workers, args.datasplit ) # create model and trainer model, trainer = plmodel_setup( checkpoint, args.eval_trim_beats, args.dbn, args.gpu ) # predict metrics, dataset, preds, piece = compute_predictions( model, trainer, datamodule.predict_dataloader(), return_preds=args.dump_predictions, ) all_piece_metrics.append(metrics) all_piece_dataset.append(dataset) all_piece_preds.extend(preds) all_piece.append(piece) # aggregate across folds all_piece_metrics = { k: np.concatenate([m[k] for m in all_piece_metrics]) for k in all_piece_metrics[0] } all_piece_dataset = np.concatenate(all_piece_dataset) all_piece = np.concatenate(all_piece) # double check that there are no errors in the fold and there are not repeated pieces assert len(all_piece) == len( np.unique(all_piece) ), "There are repeated pieces in the folds" dataset_metrics = { k: { d: np.mean(v[all_piece_dataset == d]) for d in np.unique(all_piece_dataset) } for k, v in all_piece_metrics.items() } # print for dataset print("Dataset metrics") for k, v in dataset_metrics.items(): print(k) for d, value in v.items(): print(f"{d}: {round(value,3)}") print("------") # dump predictions if args.dump_predictions: write_predictions(args.dump_predictions, all_piece_preds, all_piece) else: raise ValueError(f"Unknown aggregation type {args.aggregation_type}") def datamodule_setup(checkpoint, num_workers, datasplit): # Load the datamodule print("Creating datamodule") data_dir = Path(__file__).parent.parent.relative_to(Path.cwd()) / "data" datamodule_hparams = checkpoint["datamodule_hyper_parameters"] # update the hparams with the ones from the arguments if num_workers is not None: 
datamodule_hparams["num_workers"] = num_workers datamodule_hparams["predict_datasplit"] = datasplit datamodule_hparams["data_dir"] = data_dir datamodule = BeatDataModule(**datamodule_hparams) datamodule.setup(stage="predict") return datamodule def plmodel_setup(checkpoint, eval_trim_beats, dbn, gpu): """ Set up the pytorch lightning model and trainer for evaluation. Args: checkpoint_path (dict): The dict containing the checkpoint to load. eval_trim_beats (int or None): The number of beats to trim during evaluation. If None, the setting is taken from the pretrained model. dbn (bool or None): Whether to use the Dynamic Bayesian Network (DBN) module during evaluation. If None, the default behavior from the pretrained model is used. gpu (int): The index of the GPU device to use for training. Returns: tuple: A tuple containing the initialized pytorch lightning model and trainer. """ if eval_trim_beats is not None: checkpoint["hyper_parameters"]["eval_trim_beats"] = eval_trim_beats if dbn is not None: checkpoint["hyper_parameters"]["use_dbn"] = dbn model = PLBeatThis(**checkpoint["hyper_parameters"]) model.load_state_dict(checkpoint["state_dict"]) # set correct device and accelerator if gpu >= 0: devices = [gpu] accelerator = "gpu" else: devices = 1 accelerator = "cpu" # create trainer trainer = Trainer( accelerator=accelerator, devices=devices, logger=None, deterministic=True, precision="16-mixed", ) return model, trainer def compute_predictions(model, trainer, predict_dataloader, return_preds=False): print("Computing predictions ...") out = trainer.predict(model, predict_dataloader) metrics = [o[0] for o in out] if return_preds: preds = [model.postprocessor(o[1]["beat"][0], o[1]["downbeat"][0]) for o in out] else: preds = None dataset = np.asarray([o[2][0] for o in out]) piece = np.asarray([o[3][0] for o in out]) # convert metrics from list of per-batch dictionaries to a single dictionary with np arrays as values metrics = {k: np.asarray([m[k] for m in metrics]) for k 
in metrics[0]} return metrics, dataset, preds, piece def write_predictions(fn, preds, piece): np.savez( fn, **{ name: np.vstack([beats, infer_beat_numbers(beats, downbeats)]).T for name, (beats, downbeats) in zip(piece, preds) }, ) if __name__ == "__main__": parser = argparse.ArgumentParser( description="Computes predictions for a given model and dataset, " "prints metrics, and optionally dumps predictions to a given file." ) parser.add_argument( "--models", type=str, nargs="+", required=True, help="Local checkpoint files to use", ) parser.add_argument( "--datasplit", type=str, choices=("train", "val", "test"), default="val", help="data split to use: train, val or test " "(default: %(default)s)", ) parser.add_argument("--gpu", type=int, default=0) parser.add_argument( "--num_workers", type=int, default=8, help="number of data loading workers " ) parser.add_argument( "--eval_trim_beats", metavar="SECONDS", type=float, default=None, help="Override whether to skip the first given seconds " "per piece in evaluating (default: as stored in model)", ) parser.add_argument( "--dbn", default=None, action=argparse.BooleanOptionalAction, help="override the option to use madmom postprocessing dbn", ) parser.add_argument( "--aggregation-type", type=str, choices=("mean-std", "k-fold"), default="mean-std", help="Type of aggregation to use for multiple models; ignored if only one model is given", ) parser.add_argument( "--dump-predictions", metavar="FILENAME", type=str, default=None, help="File to write predictions to, in .npz format (optional)", ) args = parser.parse_args() main(args) ================================================ FILE: launch_scripts/preprocess_audio.py ================================================ #!/usr/bin/env python3 import argparse import concurrent.futures import os from pathlib import Path from zipfile import ZipFile import numpy as np import pandas as pd import soxr import torch import torchaudio from pedalboard import Pedalboard, PitchShift, 
time_stretch from tqdm import tqdm from beat_this.dataset.augment import precomputed_augmentation_filenames from beat_this.preprocessing import LogMelSpect, load_audio os.environ["CUDA_VISIBLE_DEVICES"] = "-1" BASEPATH = Path(__file__).parent.parent.relative_to(Path.cwd()) def save_audio(path, waveform, samplerate, resample_from=None): if resample_from and resample_from != samplerate: waveform = soxr.resample(waveform, in_rate=resample_from, out_rate=samplerate) try: waveform = torch.as_tensor(np.asarray(waveform, dtype=np.float64)) torchaudio.save( path, torch.atleast_2d(waveform), samplerate, bits_per_sample=16 ) except KeyboardInterrupt: path.unlink() # avoid half-written files raise def save_spectrogram(path, spectrogram, dtype=np.float16): try: np.save(path, np.asarray(spectrogram, dtype=dtype)) except KeyboardInterrupt: path.unlink() # avoid half-written files raise class SpectCreation: def __init__(self, pitch_shift, time_stretch, audio_sr, mel_args, verbose=False): """ Initialize the SpectCreation class. This assume that the audio files have been preprocessed with all the requested augmentations and are stored in the `mono_tracks` directory with the proper naming defined in AudioPreprocessing. Args: pitch_shift (tuple or None): A tuple specifying the minimum and maximum (inclusive) pitch shift values considered from the available audio files. If None, pitch shifting augmentation files will not be considered. time_stretch (tuple or None): A tuple specifying the min/max and stride percentage to consider from the available audio files. If None, time stretching augmentation files will not be considered. audio_sr (int): The sample rate of the audio. mel_args (dict): A dictionary of arguments to be passed to the MelSpectrogram class. verbose (bool, optional): Whether to print verbose information. Defaults to False. 
""" super(SpectCreation, self).__init__() # define the directories self.audio_dir = BASEPATH / "data" / "audio" self.mono_tracks_dir = self.audio_dir / "mono_tracks" self.spectrograms_dir = self.audio_dir / "spectrograms" self.annotations_dir = BASEPATH / "data" / "annotations" if verbose: print("Audio dir: ", self.audio_dir.absolute()) print("Mono tracks dir: ", self.mono_tracks_dir.absolute()) print("Spectrograms dir: ", self.spectrograms_dir.absolute()) print("Annotations dir: ", self.annotations_dir.absolute()) self.verbose = verbose # remember the audio metadata self.audio_sr = audio_sr # create the mel spectrogram class self.logspect_class = LogMelSpect(audio_sr, **mel_args) # define the augmentations self.augmentations = {} if pitch_shift is not None: self.augmentations["pitch"] = {"min": pitch_shift[0], "max": pitch_shift[1]} if time_stretch is not None: self.augmentations["tempo"] = { "min": -time_stretch[0], "max": time_stretch[0], "stride": time_stretch[1], } # compute the names to consider according to the augmentations self.filenames = precomputed_augmentation_filenames(self.augmentations, "wav") def create_spects(self): print("Creating spectrograms ...") processed = 0 with concurrent.futures.ThreadPoolExecutor() as executor: futures = [] for dataset_dir in self.mono_tracks_dir.iterdir(): for piece_dir in dataset_dir.iterdir(): futures.append( executor.submit( self.create_spect_piece, piece_dir, Path(dataset_dir.name) / "annotations" / "beats" / f"{piece_dir.name}.beats", dataset_dir.name, ) ) for future in tqdm( concurrent.futures.as_completed(futures), total=len(futures) ): if future.result(): processed += 1 print(f"Created {processed} spectrograms in {self.spectrograms_dir}") def create_spect_piece(self, preprocessed_audio_folder, beat_path, dataset_name): """ Create spectrogram for a single audio piece. This method creates a spectrogram for a single audio piece located in the `preprocessed_audio_folder`. 
The beat annotations for the audio piece are loaded from the `beat_path` file. The created spectrogram is saved in the `spectrograms_dir` directory. Args: preprocessed_audio_folder (Path): The path to the preprocessed audio folder. beat_path (Path): The path to the beat annotations file. dataset_name (str): The name of the dataset. Returns: metadata (list): A list containing the metadata of the created spectrogram. """ for filename in self.filenames: if not (self.annotations_dir / beat_path).exists(): print( f"beat annotation {beat_path} not found for {preprocessed_audio_folder}" ) return audio_path = preprocessed_audio_folder / filename spect_path = ( self.spectrograms_dir / dataset_name / preprocessed_audio_folder.name / f"{Path(filename).stem}.npy" ) if spect_path.exists(): if self.verbose: print(f"Skipping {spect_path} because it exists") else: if self.verbose: print(f"Computing {spect_path}") waveform, sr = load_audio(audio_path) assert ( sr == self.audio_sr ), f"Sample rate mismatch: {sr} != {self.audio_sr}" # compute the mel spectrogram and scale the values with log(1 + 1000 * x) spect = self.logspect_class(torch.tensor(waveform, dtype=torch.float32)) # save the spectrogram as numpy array spect_path.parent.mkdir(parents=True, exist_ok=True) save_spectrogram(spect_path, spect.numpy()) return True class AudioPreprocessing(object): def __init__( self, orig_audio_paths, out_sr=22050, aug_sr=44100, ext="wav", pitch_shift=(-5, 6), time_stretch=(20, 4), verbose=False, ): """ Class for converting audio files to mono, resampling, and applying augmentations. Only use this if you want to start from new audio files, otherwise use the spectrograms provided in the repo. Args: orig_audio_paths (Path): The path to the file with the original audio paths for each dataset. out_sr (int, optional): The output sample rate. Defaults to 22050. aug_sr (int, optional): The sample rate for the augmentations. Defaults to 44100. ext (str, optional): The extension of the audio files. 
Defaults to 'wav'. pitch_shift (tuple, optional): A tuple specifying the minimum and maximum (inclusive) pitch shift values considered. Defaults to (-5, 6). time_stretch (tuple, optional): A tuple specifying the min/max (inclusive) time stretch and stride in percentage considered. Defaults to (20, 4). verbose (bool, optional): Whether to print verbose information. Defaults to False. """ super(AudioPreprocessing, self).__init__() self.audio_dir = BASEPATH / "data" / "audio" self.annotation_dir = BASEPATH / "data" / "annotations" # load data_dir from audio_path.csv which has the format: dataset_name, audio_path self.audio_dirs = { row[0]: row[1] for row in pd.read_csv(orig_audio_paths, header=None).values } # check if annotations exists, otherwise tell how to obtain them if not self.annotation_dir.exists(): raise RuntimeError( f"{self.annotation_dir} missing, check instructions " "in README.md how to obtain the annotations." ) print(f"Annotations ready in {self.annotation_dir}") self.out_sr = out_sr self.aug_sr = aug_sr self.ext = ext self.pitch_shift = pitch_shift if time_stretch: # interpret tuple as (maximum percentage, stride) time_stretch = range( -time_stretch[0], time_stretch[0] + 1, time_stretch[1] if len(time_stretch) > 1 else 1, ) self.time_stretch = time_stretch self.verbose = verbose def preprocess_audio(self): print("Preprocessing audio files ...") processed = 0 with concurrent.futures.ThreadPoolExecutor() as executor: futures = [] for dataset_name, audio_dir in self.audio_dirs.items(): for audio_path in Path(audio_dir).iterdir(): if audio_path.stem[:12] in ("gtzan_speech", "gtzan_music_"): continue futures.append( executor.submit( self.process_audio_file, dataset_name, audio_path ) ) for future in tqdm( concurrent.futures.as_completed(futures), total=len(futures) ): if future.result(): processed += 1 print("Processed", processed, "audio files") def process_audio_file(self, dataset_name, audio_path): annotation_dir = Path(self.annotation_dir, 
dataset_name, "annotations") # load annotations beat_path = Path(annotation_dir, "beats", audio_path.stem + ".beats") if not beat_path.exists(): print( f"beat annotation {beat_path} not found for {audio_path}", ) return False # create a folder with the name of the track folder_path = Path(self.audio_dir, "mono_tracks", dataset_name, audio_path.stem) # derive the name of the unaugmented file mono_path = folder_path / f"track.{self.ext}" # derive the name of all augmented files augmentations = { "pitch": {"min": self.pitch_shift[0], "max": self.pitch_shift[1]}, "tempo": { "min": -self.time_stretch[0], "max": self.time_stretch[0], "stride": self.time_stretch[1], }, } augmentations_path = precomputed_augmentation_filenames(augmentations, self.ext) # stop here if all files exists if mono_path.exists() and all( (folder_path / aug).exists() for aug in augmentations_path ): if self.verbose: print(f"All files in {folder_path} exists, skipping") return True # load audio try: waveform, sr = load_audio(audio_path) except Exception as e: print("Problem with loading waveform", audio_path, e) return folder_path.mkdir(parents=True, exist_ok=True) if ( waveform.ndim == 1 and sr == self.out_sr and audio_path.suffix == f".{self.ext}" ): # shortcut: copy original file to mono path location os.system("cp '{}' '{}'".format(audio_path, mono_path)) else: # we need to do some conversions for the unaugmented file if waveform.ndim != 1: waveform = np.mean(waveform, axis=1) if not mono_path.exists(): if sr != self.out_sr: waveform_out = soxr.resample( waveform, in_rate=sr, out_rate=self.out_sr ) else: waveform_out = waveform # save mono file save_audio(mono_path, waveform_out, self.out_sr) if (self.pitch_shift or self.time_stretch) and (sr != self.aug_sr): waveform = soxr.resample(waveform, in_rate=sr, out_rate=self.aug_sr) # handle the requested augmentations # pedalboard requires float32, convert waveform = np.asarray(waveform, dtype=np.float32) shifts = ( range(self.pitch_shift[0], 
self.pitch_shift[1] + 1) if self.pitch_shift else [0] ) stretches = self.time_stretch if self.time_stretch else [0] for shift in shifts: # pitch augmentation augment_audio_file( folder_path, waveform, aug_type="shift", amount=shift, aug_sr=self.aug_sr, out_sr=self.out_sr, ext=self.ext, verbose=self.verbose, ) for stretch in stretches: # tempo augmentation augment_audio_file( folder_path, waveform, aug_type="stretch", amount=stretch, aug_sr=self.aug_sr, out_sr=self.out_sr, ext=self.ext, verbose=self.verbose, ) return True def augment_audio_file( folder_path, waveform, aug_type, amount, aug_sr, out_sr, ext, verbose ): # figure out the file name if aug_type == "stretch": stretch = amount shift = 0 elif aug_type == "shift": shift = amount stretch = 0 else: raise ValueError(f"Unknown augmentation mode {aug_type}") suffix = "" if shift != 0: suffix = suffix + f"_ps{shift}" if stretch != 0: suffix = suffix + f"_ts{stretch}" out_path = Path(folder_path, f"track{suffix}.{ext}") # skip if it exists if out_path.exists(): if verbose: print(f"{out_path} exists, skipping") return # otherwise compute it and write it out # time stretch or pitch shift alone if aug_type == "shift": if verbose: print(f"computing {out_path} with {shift=}") # pitch shift alone board = Pedalboard( [ PitchShift(semitones=shift), ] ) # apply pedalboard augmented = board(waveform, aug_sr) else: # type == stretch if verbose: print(f"computing {out_path} with {stretch=}") augmented = time_stretch( waveform, aug_sr, stretch_factor=1 + stretch / 100, pitch_shift_in_semitones=0.0, ).squeeze() # save to file if verbose: print(f"writing {out_path}") save_audio(out_path, augmented, out_sr, resample_from=aug_sr) def create_npz(spect_dir, npz_file, augmentations, verbose): """Assemble spectrograms from a directory into an .npz file.""" if npz_file.exists(): if verbose: print(f"{npz_file} already exists, skipping") return with ZipFile(npz_file, "w") as z: for subdir in tqdm(sorted(spect_dir.iterdir()), leave=False): 
if subdir.is_dir(): for fn in precomputed_augmentation_filenames(augmentations): z.write(subdir / fn, subdir.name + "/" + fn) def ints(value): """Parse a string containing a colon-separated tuple of integers.""" return value and tuple(map(int, value.split(":"))) def main(orig_audio_paths, pitch_shift, time_stretch, verbose): # preprocess audio dp = AudioPreprocessing( orig_audio_paths=orig_audio_paths, out_sr=22050, aug_sr=44100, pitch_shift=pitch_shift, time_stretch=time_stretch, verbose=verbose, ) dp.preprocess_audio() # compute spectrograms mel_args = dict( n_fft=1024, hop_length=441, f_min=30, f_max=11000, n_mels=128, mel_scale="slaney", normalized="frame_length", power=1, ) sc = SpectCreation( pitch_shift=pitch_shift, time_stretch=time_stretch, audio_sr=22050, mel_args=mel_args, verbose=verbose, ) sc.create_spects() # assemble into NPZ files print("Creating .npz spectrogram bundles...") spect_dirs = [child for child in sc.spectrograms_dir.iterdir() if child.is_dir()] for spect_dir in tqdm(spect_dirs): create_npz( spect_dir, spect_dir.with_suffix(".npz"), {} if spect_dir.name == "gtzan" else sc.augmentations, verbose, ) if __name__ == "__main__": parser = argparse.ArgumentParser() parser.add_argument( "--orig_audio_paths", type=str, help="path to the file with the original audio paths for each dataset (default: %(default)s)", default="data/audio_paths.csv", ) parser.add_argument( "--pitch_shift", metavar="LOW:HIGH", type=str, default="-5:6", help="pitch shift in semitones (default: %(default)s)", ) parser.add_argument( "--time_stretch", metavar="MAX:STRIDE", type=str, default="20:4", help="time stretch in percentage and stride (default: %(default)s)", ) parser.add_argument("--verbose", action="store_true", help="verbose output") args = parser.parse_args() main( args.orig_audio_paths, ints(args.pitch_shift), ints(args.time_stretch), args.verbose, ) ================================================ FILE: launch_scripts/train.py 
================================================ import argparse from pathlib import Path import torch from pytorch_lightning import Trainer, seed_everything from pytorch_lightning.callbacks import LearningRateMonitor, ModelCheckpoint from pytorch_lightning.loggers import WandbLogger from beat_this.dataset import BeatDataModule from beat_this.model.pl_module import PLBeatThis def main(args): # for repeatability seed_everything(args.seed, workers=True) print("Starting a new run with the following parameters:") print(args) params_str = f"{'noval ' if not args.val else ''}{'hung ' if args.hung_data else ''}{'fold' + str(args.fold) + ' ' if args.fold is not None else ''}{args.loss}-h{args.transformer_dim}-aug{args.tempo_augmentation}{args.pitch_augmentation}{args.mask_augmentation}{' nosumH ' if not args.sum_head else ''}{' nopartialT ' if not args.partial_transformers else ''}" if args.logger == "wandb": if args.resume_checkpoint and args.resume_id: wandb_args = dict(id=args.resume_id, resume="must") else: wandb_args = {} logger = WandbLogger( project="beat_this", name=f"{args.name} {params_str}".strip(), **wandb_args ) else: logger = None if args.force_flash_attention: print("Forcing the use of the flash attention.") torch.backends.cuda.enable_flash_sdp(True) torch.backends.cuda.enable_mem_efficient_sdp(False) torch.backends.cuda.enable_math_sdp(False) data_dir = Path(__file__).parent.parent.relative_to(Path.cwd()) / "data" checkpoint_dir = ( Path(__file__).parent.parent.relative_to(Path.cwd()) / "checkpoints" ) augmentations = {} if args.tempo_augmentation: augmentations["tempo"] = {"min": -20, "max": 20, "stride": 4} if args.pitch_augmentation: augmentations["pitch"] = {"min": -5, "max": 6} if args.mask_augmentation: # kind, min_count, max_count, min_len, max_len, min_parts, max_parts augmentations["mask"] = { "kind": "permute", "min_count": 1, "max_count": 6, "min_len": 0.1, "max_len": 2, "min_parts": 5, "max_parts": 9, } datamodule = BeatDataModule( data_dir, 
batch_size=args.batch_size, train_length=args.train_length, spect_fps=args.fps, num_workers=args.num_workers, test_dataset="gtzan", length_based_oversampling_factor=args.length_based_oversampling_factor, augmentations=augmentations, hung_data=args.hung_data, no_val=not args.val, fold=args.fold, ) datamodule.setup(stage="fit") # compute positive weights pos_weights = datamodule.get_train_positive_weights(widen_target_mask=3) print("Using positive weights: ", pos_weights) dropout = { "frontend": args.frontend_dropout, "transformer": args.transformer_dropout, } pl_model = PLBeatThis( spect_dim=128, fps=50, transformer_dim=args.transformer_dim, ff_mult=4, n_layers=args.n_layers, stem_dim=32, dropout=dropout, lr=args.lr, weight_decay=args.weight_decay, pos_weights=pos_weights, head_dim=32, loss_type=args.loss, warmup_steps=args.warmup_steps, max_epochs=args.max_epochs, use_dbn=args.dbn, eval_trim_beats=args.eval_trim_beats, sum_head=args.sum_head, partial_transformers=args.partial_transformers, ) for part in args.compile: if hasattr(pl_model.model, part): setattr(pl_model.model, part, torch.compile(getattr(pl_model.model, part))) print("Will compile model", part) else: raise ValueError("The model is missing the part", part, "to compile") callbacks = [LearningRateMonitor(logging_interval="step")] # save only the last model callbacks.append( ModelCheckpoint( every_n_epochs=1, dirpath=str(checkpoint_dir), filename=f"{args.name} S{args.seed} {params_str}".strip(), ) ) trainer = Trainer( max_epochs=args.max_epochs, accelerator="auto", devices=[args.gpu], num_sanity_val_steps=1, logger=logger, callbacks=callbacks, log_every_n_steps=1, precision="16-mixed", accumulate_grad_batches=args.accumulate_grad_batches, check_val_every_n_epoch=args.val_frequency, ) trainer.fit(pl_model, datamodule, ckpt_path=args.resume_checkpoint) trainer.test(pl_model, datamodule) if __name__ == "__main__": parser = argparse.ArgumentParser() parser.add_argument("--name", type=str, default="") 
parser.add_argument("--gpu", type=int, default=0) parser.add_argument( "--force-flash-attention", default=False, action=argparse.BooleanOptionalAction ) parser.add_argument( "--compile", action="store", nargs="*", type=str, default=["frontend", "transformer_blocks", "task_heads"], help="Which model parts to compile, among frontend, transformer_encoder, task_heads", ) parser.add_argument("--n-layers", type=int, default=6) parser.add_argument("--transformer-dim", type=int, default=512) parser.add_argument( "--frontend-dropout", type=float, default=0.1, help="dropout rate to apply in the frontend", ) parser.add_argument( "--transformer-dropout", type=float, default=0.2, help="dropout rate to apply in the main transformer blocks", ) parser.add_argument("--lr", type=float, default=0.0008) parser.add_argument("--weight-decay", type=float, default=0.01) parser.add_argument("--logger", type=str, choices=["wandb", "none"], default="none") parser.add_argument("--num-workers", type=int, default=8) parser.add_argument("--n-heads", type=int, default=16) parser.add_argument("--fps", type=int, default=50, help="The spectrograms fps.") parser.add_argument( "--loss", type=str, default="shift_tolerant_weighted_bce", choices=[ "shift_tolerant_weighted_bce", "fast_shift_tolerant_weighted_bce", "weighted_bce", "bce", ], help="The loss to use", ) parser.add_argument( "--warmup-steps", type=int, default=1000, help="warmup steps for optimizer" ) parser.add_argument( "--max-epochs", type=int, default=100, help="max epochs for training" ) parser.add_argument( "--batch-size", type=int, default=8, help="batch size for training" ) parser.add_argument("--accumulate-grad-batches", type=int, default=8) parser.add_argument( "--train-length", type=int, default=1500, help="maximum seq length for training in frames", ) parser.add_argument( "--dbn", default=False, action=argparse.BooleanOptionalAction, help="use madmom postprocessing DBN", ) parser.add_argument( "--eval-trim-beats", metavar="SECONDS", 
type=float, default=5, help="Skip the first given seconds per piece in evaluating (default: %(default)s)", ) parser.add_argument( "--val-frequency", metavar="N", type=int, default=5, help="validate every N epochs (default: %(default)s)", ) parser.add_argument( "--tempo-augmentation", default=True, action=argparse.BooleanOptionalAction, help="Use precomputed tempo aumentation", ) parser.add_argument( "--pitch-augmentation", default=True, action=argparse.BooleanOptionalAction, help="Use precomputed pitch aumentation", ) parser.add_argument( "--mask-augmentation", default=True, action=argparse.BooleanOptionalAction, help="Use online mask aumentation", ) parser.add_argument( "--sum-head", default=True, action=argparse.BooleanOptionalAction, help="Use SumHead instead of two separate Linear heads", ) parser.add_argument( "--partial-transformers", default=True, action=argparse.BooleanOptionalAction, help="Use Partial transformers in the frontend", ) parser.add_argument( "--length-based-oversampling-factor", type=float, default=0.65, help="The factor to oversample the long pieces in the dataset. Set to 0 to only take one excerpt for each piece.", ) parser.add_argument( "--val", default=True, action=argparse.BooleanOptionalAction, help="Train on all data, including validation data, escluding test data. The validation metrics will still be computed, but they won't carry any meaning.", ) parser.add_argument( "--hung-data", default=False, action=argparse.BooleanOptionalAction, help="Limit the training to Hung et al. data. 
The validation will still be computed on all datasets.", ) parser.add_argument( "--fold", type=int, default=None, help="If given, the CV fold number to *not* train on (0-based).", ) parser.add_argument( "--seed", type=int, default=0, help="Seed for the random number generators.", ) parser.add_argument( "--resume-checkpoint", type=str, default=None, help="Resume training from a local checkpoint.", ) parser.add_argument( "--resume-id", type=str, default=None, help="When resuming with --resume-checkpoint, optionally provide the wandb id to continue logging to.", ) args = parser.parse_args() main(args) ================================================ FILE: pyproject.toml ================================================ [build-system] requires = ["setuptools"] build-backend = "setuptools.build_meta" [project] name = "beat-this" version = "1.1.0" description = "Beat This! beat tracker" readme = "README.md" classifiers = [ "Intended Audience :: Science/Research", "Topic :: Multimedia :: Sound/Audio :: Analysis", "Development Status :: 5 - Production/Stable", "Programming Language :: Python :: 3", ] authors = [ {name = "Francesco Foscarin", email = "francesco.foscarin@jku.at"}, {name = "Jan Schlüter", email = "jan.schlueter@jku.at"}, ] requires-python = ">=3" dependencies = [ "numpy>=1.20", "torch>=2", "torchaudio", "einops", "rotary-embedding-torch", "soxr", ] license = "MIT" license-files = ["LICENSE"] [tool.setuptools.package-dir] beat_this = "beat_this" [project.urls] Repository = "https://github.com/CPJKU/beat_this" Issues = "https://github.com/CPJKU/beat_this/issues" Changelog = "https://github.com/CPJKU/beat_this/blob/main/CHANGELOG.md" [project.scripts] beat_this = "beat_this.cli:main" ================================================ FILE: requirements.txt ================================================ # This is a set of known working versions for inference, documented for a # distant future. 
# We recommend following the requirements section in our README.md instead. einops==0.8.0 numpy==1.26.4 rotary_embedding_torch==0.6.4 soxr==0.3.7 torch==2.3.1 torchaudio==2.3.1 tqdm==4.66.4 ================================================ FILE: tests/test_inference.py ================================================ from pathlib import Path import numpy as np import soundfile as sf import torch from beat_this.inference import Audio2Frames, File2Beats def test_File2Beat(): f2b = File2Beats() audio_path = Path("tests/It Don't Mean A Thing - Kings of Swing.mp3") beat, downbeat = f2b(audio_path) assert isinstance(beat, np.ndarray) assert isinstance(downbeat, np.ndarray) def test_Audio2Frames(): a2f = Audio2Frames() audio_path = Path("tests/It Don't Mean A Thing - Kings of Swing.mp3") # load audio audio, sr = sf.read(audio_path) beat, downbeat = a2f(audio, sr) assert isinstance(beat, torch.Tensor) assert isinstance(downbeat, torch.Tensor)