Repository: CPJKU/beat_this
Branch: main
Commit: 9d787b9797ea
Files: 29
Total size: 172.9 KB
Directory structure:
gitextract_r0d2sw0_/
├── .github/
│ └── workflows/
│ └── pypi.yml
├── .gitignore
├── CHANGELOG.md
├── LICENSE
├── README.md
├── beat_this/
│ ├── __init__.py
│ ├── cli.py
│ ├── dataset/
│ │ ├── __init__.py
│ │ ├── augment.py
│ │ ├── dataset.py
│ │ └── mmnpz.py
│ ├── inference.py
│ ├── model/
│ │ ├── __init__.py
│ │ ├── beat_tracker.py
│ │ ├── loss.py
│ │ ├── pl_module.py
│ │ ├── postprocessor.py
│ │ └── roformer.py
│ ├── preprocessing.py
│ └── utils.py
├── beat_this_example.ipynb
├── hubconf.py
├── launch_scripts/
│ ├── clean_checkpoints.py
│ ├── compute_paper_metrics.py
│ ├── preprocess_audio.py
│ └── train.py
├── pyproject.toml
├── requirements.txt
└── tests/
└── test_inference.py
================================================
FILE CONTENTS
================================================
================================================
FILE: .github/workflows/pypi.yml
================================================
# This workflow will upload a Python Package to PyPI when a release is created
# For more information see: https://docs.github.com/en/actions/automating-builds-and-tests/building-and-testing-python#publishing-to-package-registries
# This workflow uses actions that are not certified by GitHub.
# They are provided by a third-party and are governed by
# separate terms of service, privacy policy, and support
# documentation.
name: Upload Python Package
on:
release:
types: [published]
permissions:
contents: read
jobs:
release-build:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v5
- uses: actions/setup-python@v5
with:
python-version: "3.x"
- name: Build release distributions
run: |
python -m pip install build
python -m build
- name: Upload distributions
uses: actions/upload-artifact@v4
with:
name: release-dists
path: dist/
pypi-publish:
runs-on: ubuntu-latest
needs:
- release-build
permissions:
# IMPORTANT: this permission is mandatory for trusted publishing
id-token: write
# Dedicated environments with protections for publishing are strongly recommended.
# For more information, see: https://docs.github.com/en/actions/deployment/targeting-different-environments/using-environments-for-deployment#deployment-protection-rules
environment:
name: pypi
url: https://pypi.org/project/beat-this/${{ github.event.release.name }}
steps:
- name: Retrieve release distributions
uses: actions/download-artifact@v5
with:
name: release-dists
path: dist/
- name: Publish release distributions to PyPI
uses: pypa/gh-action-pypi-publish@release/v1
with:
packages-dir: dist/
================================================
FILE: .gitignore
================================================
__pycache__/
*.py[cod]
*$py.class
data/
checkpoints/
lightning_logs/
wandb/
.vscode/
beat_this.egg-info/
build/
================================================
FILE: CHANGELOG.md
================================================
# Changelog
All notable changes to this project are documented below.
The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.1.0/),
and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).
## [1.1.0] - 2026-04-14
- Clarified installation instructions for madmom and mir_eval
- Load checkpoints with `weights_only=True` when supported
- Fix checkpoint downloads after server-side update
- Provide separate `infer_beat_numbers()` function
- Command-line tool: Support saving raw activations / logits
- Training script: Support resuming from previous checkpoint
- Migrate to pyproject.toml (thanks to @JacobLinCool)
- Support non-CUDA accelerator chips (thanks to @tillt)
- Published on PyPI (thanks to @MarvinSchenkel)
## [1.0] - 2024-10-18
- Initial release
================================================
FILE: LICENSE
================================================
MIT License
Copyright (c) 2024 Institute of Computational Perception, JKU Linz, Austria
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
================================================
FILE: README.md
================================================
# Beat This!
Official implementation of the beat tracker from the ISMIR 2024 paper "[Beat This! Accurate Beat Tracking Without DBN Postprocessing](https://arxiv.org/abs/2407.21658)" by Francesco Foscarin, Jan Schlüter and Gerhard Widmer.
* [Inference](#inference)
* [Available models](#available-models)
* [Data](#data)
* [Reproducing metrics from the paper](#reproducing-metrics-from-the-paper)
* [Training](#training)
* [Reusing the loss](#reusing-the-loss)
* [Reusing the model](#reusing-the-model)
* [Citation](#citation)
## Inference
To predict beats for audio files, you can either use our command line tool or call the beat tracker from Python. Both have the same requirements unless you go for the online demo.
### Online demo
To process a small set of audio files without installing anything, [open our example notebook in Google Colab](https://colab.research.google.com/github/CPJKU/beat_this/blob/main/beat_this_example.ipynb) and follow the instructions.
### Requirements
The beat tracker requires Python with a set of packages installed:
1. [Install PyTorch](https://pytorch.org/get-started/locally/) 2.0 or later following the instructions for your platform.
2. Install further modules with `pip install tqdm einops soxr rotary-embedding-torch`. (If using conda, we still recommend pip. You may try installing `soxr-python` and `einops` from conda-forge, but `rotary-embedding-torch` is only on PyPI.)
3. To read other audio formats than `.wav`, install `ffmpeg` or another supported backend for `torchaudio`. (`ffmpeg` can be installed via conda or via your operating system.)
Finally, install our beat tracker with:
```bash
pip install beat-this
```
For the development version, use:
```bash
pip install https://github.com/CPJKU/beat_this/archive/main.zip
```
### Command line
Along with the python package, a command line application called `beat_this` is installed. For a full documentation of the command line options, run:
```bash
beat_this --help
```
The basic usage is:
```bash
beat_this path/to/audio.file -o path/to/output.beats
```
To process multiple files, specify multiple input files or directories, and give an output directory instead:
```bash
beat_this path/to/*.mp3 path/to/whole_directory/ -o path/to/output_directory
```
The beat tracker will use the first GPU in your system by default, and fall back to CPU if PyTorch does not have CUDA access. With `--gpu=2`, it will use the third GPU, and with `--gpu=-1` it will force the CPU. For recent GPUs, passing `--float16` may improve speed.
If you have a lot of files to process, you can distribute the load over multiple processes by running the same command multiple times with `--touch-first`, `--skip-existing` and potentially different options for `--gpu`:
```bash
for gpu in {0..3}; do beat_this input_dir -o output_dir --touch-first --skip-existing --gpu=$gpu & done
```
If you want to use the DBN for postprocessing, add `--dbn`. The DBN parameters are the default ones from madmom. This requires installing the `madmom` package (with `pip install git+https://github.com/CPJKU/madmom.git`, as the current version on PyPI only supports Python<3.10 and numpy<1.20).
### Python class
If you are a Python user, you can directly use the `beat_this.inference` module.
First, instantiate an instance of the `File2Beats` class that encapsulates the model along with pre- and postprocessing:
```python
from beat_this.inference import File2Beats
file2beats = File2Beats(checkpoint_path="final0", device="cuda", dbn=False)
```
To obtain a list of beats and downbeats for an audio file, run:
```python
audio_path = "path/to/audio.file"
beats, downbeats = file2beats(audio_path)
```
Optionally, you can produce a `.beats` file (e.g., for importing into [Sonic Visualizer](https://www.sonicvisualiser.org/)):
```python
from beat_this.utils import save_beat_tsv
outpath = "path/to/output.beats"
save_beat_tsv(beats, downbeats, outpath)
```
If you already have an audio tensor loaded, instead of `File2Beats`, use `Audio2Beats` and pass the tensor and its sample rate. We also provide `Audio2Frames` for framewise logits and `Spect2Frames` for spectrogram inputs.
## Available models
Models are available for manual download at [our cloud space](https://cloud.cp.jku.at/index.php/s/7ik4RrBKTS273gp), but will also be downloaded automatically by the above inference code. By default, the inference will use `final0`, but it is possible to select another model via a command line option (`--model`) or Python parameter (`checkpoint_path`).
Main models:
* `final0`, `final1`, `final2`: Our main model, trained on all data except the GTZAN dataset, with three different seeds. This corresponds to "Our system" in Table 2 of the paper. About 78 MB per model.
* `small0`, `small1`, `small2`: A smaller model, again trained on all data except GTZAN, with three different seeds. This corresponds to "smaller model" in Table 2 of the paper. About 8.1 MB per model.
* `single_final0`, `single_final1`, `single_final2`: Our main model, trained on the single split described in Section 4.1 of the paper, with three different seeds. This corresponds to "Our system" in Table 3 of the paper. About 78 MB per model.
* `fold0`, `fold1`, `fold2`, `fold3`, `fold4`, `fold5`, `fold6`, `fold7`: Our main model, trained in the 8-fold cross-validation setting with a single seed per fold. This corresponds to "Our" in Table 1 of the paper. About 78 MB per model.
Other models, available mainly for result reproducibility:
* `hung0`, `hung1`, `hung2`: A model trained on all the data used by the "Modeling Beats and Downbeats with a Time-Frequency Transformer" system by Hung et al. (except GTZAN dataset), with three different seeds. This corresponds to "limited to data of [10]" in Table 2 of the paper.
* the other models used for the ablation studies in Table 3, all trained with 3 seeds on the single split described in Section 4.1 of the paper:
* `single_notempoaug0`, `single_notempoaug1`, `single_notempoaug2`
* `single_nosumhead0`, `single_nosumhead1`, `single_nosumhead2`
* `single_nomaskaug0`, `single_nomaskaug1`, `single_nomaskaug2`
* `single_nopartialt0`, `single_nopartialt1`, `single_nopartialt2`
* `single_noshifttol0`, `single_noshifttol1`, `single_noshifttol2`
* `single_nopitchaug0`, `single_nopitchaug1`, `single_nopitchaug2`
* `single_noshifttolnoweights0`, `single_noshifttolnoweights1`, `single_noshifttolnoweights0`
Please be aware that the results may be unfairly good if you run inference on any file from the training datasets. For example, an evaluation with `final*` or `small*` can only be performed fairly on GTZAN or other datasets we didn't consider in our paper.
If you need to run an evaluation on some datasets we used other than GTZAN, consider targeting the validation part of the single split (with `single_final*`), or of the 8-fold cross-validation (with `fold*`).
All the models are provided as PyTorch Lightning checkpoints, stripped of the optimizer state to reduce their size. This is useful for reproducing the paper results or verifying the hyperparameters (stored in the checkpoint under `hyper_parameters` and `datamodule_hyper_parameters`).
During inference, PyTorch Lighting is not used, and the checkpoints are converted and loaded into vanilla PyTorch modules.
## Data
### Annotations
All annotations we used to train our models are available [in a separate GitHub repo](https://github.com/CPJKU/beat_this_annotations). Note that if you want to obtain the exact paper results, you should use [version 1.0](https://github.com/CPJKU/beat_this_annotations/releases/tag/v1.0). Other releases with corrected annotations may be published in the future.
To use the annotations for training or evaluation, you first need to download and extract or clone the annotations repo to `data/annotations`:
```bash
mkdir -p data
git clone https://github.com/CPJKU/beat_this_annotations data/annotations
# cd data/annotations; git checkout v1.0 # optional
```
### Spectrograms
The spectrograms used for training are released [as a Zenodo dataset](https://zenodo.org/records/13922116). They are distributed as a separate .zip file per dataset, each holding a .npz file with the spectrograms. For evaluation of the test set, download `gtzan.zip`; for training and evaluation of the validation set, download all (except `beat_this_annotations.zip`). Extract all .zip files into `data/audio/spectrograms`, so that you have, for example, `data/audio/spectrograms/gtzan.npz`. As an alternative, the code also supports directories of .npy files such as `data/audio/spectrograms/gtzan/gtzan_blues_00000/track.npy`, which you can obtain by unzipping `gtzan.npz`.
### Recreating spectrograms
If you have access to the original audio files, or want to add another dataset, create a text file `data/audio_paths.tsv` that has, on each line, the name of a dataset, a tab character, and the path to the audio directory. The corresponding annotations must also be present under `data/annotations`. Install pandas and pedalboard:
```bash
pip install pandas pedalboard
```
Then run:
```bash
python launch_scripts/preprocess_audio.py
```
It will create monophonic 22 kHz wave files in `data/audio/mono_tracks`, convert those to spectrograms in `data/audio/spectrograms`, and create spectrogram bundles. Intermediary files are kept and will not be recreated when rerunning the script.
## Reproducing metrics from the paper
### Requirements
In addition to the [inference requirements](#requirements), computing evaluation metrics requires installing PyTorch Lightning, Pandas, and `mir_eval` (the latter from source, as the current version on PyPI only supports numpy<1.20).
```bash
pip install pytorch_lightning pandas
pip install https://github.com/mir-evaluation/mir_eval/archive/main.zip
```
You must also obtain and set up the annotations and spectrogram datasets [as indicated above](#data). Specifically, the GTZAN dataset suffices for commands that include `--datasplit test`, while all other datasets are required for commands that include `--datasplit val`.
### Command line
#### Compute results on the test set (GTZAN) corresponding to Table 2 in the paper.
Main results for our system:
```bash
python launch_scripts/compute_paper_metrics.py --models final0 final1 final2 --datasplit test
```
Smaller model:
```bash
python launch_scripts/compute_paper_metrics.py --models small0 small1 small2 --datasplit test
```
Hung data:
```bash
python launch_scripts/compute_paper_metrics.py --models hung0 hung1 hung2 --datasplit test
```
With DBN (this requires installing the madmom package):
```bash
python launch_scripts/compute_paper_metrics.py --models final0 final1 final2 --datasplit test --dbn
```
#### Compute 8-fold cross-validation results, corresponding to Table 1 in the paper.
```bash
python launch_scripts/compute_paper_metrics.py --models fold0 fold1 fold2 fold3 fold4 fold5 fold6 fold7 --datasplit val --aggregation-type k-fold
```
#### Compute ablation studies on the validation set of the single split, correponding to Table 3 in the paper.
Our system:
```bash
python launch_scripts/compute_paper_metrics.py --models single_final0 single_final1 single_final2 --datasplit val
```
No sum head:
```bash
python launch_scripts/compute_paper_metrics.py --models single_nosumhead0 single_nosumhead1 single_nosumhead2 --datasplit val
```
No tempo augmentation:
```bash
python launch_scripts/compute_paper_metrics.py --models single_notempoaug0 single_notempoaug1 single_notempoaug2 --datasplit val
```
No mask augmentation:
```bash
python launch_scripts/compute_paper_metrics.py --models single_nomaskaug0 single_nomaskaug1 single_nomaskaug2 --datasplit val
```
No partial transformers:
```bash
python launch_scripts/compute_paper_metrics.py --models single_nopartialt0 single_nopartialt1 single_nopartialt2 --datasplit val
```
No shift tolerance:
```bash
python launch_scripts/compute_paper_metrics.py --models single_noshifttol0 single_noshifttol1 single_noshifttol2 --datasplit val
```
No pitch augmentation:
```bash
python launch_scripts/compute_paper_metrics.py --models single_nopitchaug0 single_nopitchaug1 single_nopitchaug2 --datasplit val
```
No shift tolerance and no weights:
```bash
python launch_scripts/compute_paper_metrics.py --models single_noshifttolnoweights0 single_noshifttolnoweights1 single_noshifttolnoweights2 --datasplit val
```
## Training
### Requirements
The training requirements match the [evaluation requirements](#requirements-1) for the validation set. All 16 datasets and annotations must be [correctly set up](#data).
### Command line
#### Train models listed in Table 2 in the paper.
Main results for our system (final0, final1, final2):
```bash
for seed in 0 1 2; do
python launch_scripts/train.py --seed=$seed --no-val
done
```
Smaller model (small0, small1, small2):
```bash
for seed in 0 1 2; do
python launch_scripts/train.py --seed=$seed --no-val --transformer-dim=128
done
```
Hung data (hung0, hung1, hung2):
```bash
for seed in 0 1 2; do
python launch_scripts/train.py --seed=$seed --no-val --hung-data
done
```
#### Train models with 8-fold cross-validation, corresponding to Table 1 in the paper.
```bash
for fold in {0..7}; do
python launch_scripts/train.py --fold=$fold
done
```
#### Train models for the ablation studies, corresponding to Table 3 in the paper.
Our system (single_final0, single_final1, single_final2):
```bash
for seed in 0 1 2; do
python launch_scripts/train.py --seed=$seed
done
```
No sum head (single_nosumhead0, single_nosumhead1, single_nosumhead2):
```bash
for seed in 0 1 2; do
python launch_scripts/train.py --seed=$seed --no-sum-head
done
```
No tempo augmentation (single_notempoaug0, single_notempoaug1, single_notempoaug2):
```bash
for seed in 0 1 2; do
python launch_scripts/train.py --seed=$seed --no-tempo-augmentation
done
```
No mask augmentation (single_nomaskaug0, single_nomaskaug1, single_nomaskaug2):
```bash
for seed in 0 1 2; do
python launch_scripts/train.py --seed=$seed --no-mask-augmentation
done
```
No partial transformers (single_nopartialt0, single_nopartialt1, single_nopartialt2):
```bash
for seed in 0 1 2; do
python launch_scripts/train.py --seed=$seed --no-partial-transformers
done
```
No shift tolerance (single_noshifttol0, single_noshifttol1, single_noshifttol2):
```bash
for seed in 0 1 2; do
python launch_scripts/train.py --seed=$seed --loss weighted_bce
done
```
No pitch augmentation (single_nopitchaug0, single_nopitchaug1, single_nopitchaug2):
```bash
for seed in 0 1 2; do
python launch_scripts/train.py --seed=$seed --no-pitch-augmentation
done
```
No shift tolerance and no weights (single_noshifttolnoweights0, single_noshifttolnoweights1, single_noshifttolnoweights2):
```bash
for seed in 0 1 2; do
python launch_scripts/train.py --seed=$seed --loss bce
done
```
## Reusing the loss
To reuse our shift-invariant binary cross-entropy loss, just copy out the `ShiftTolerantBCELoss` class from [`loss.py`](beat_this/model/loss.py), it does not have any dependencies.
## Reusing the model
To reuse the BeatThis model, you have multiple options:
### From the package
When installing the `beat_this` package, you can directly import the model class:
```
from beat_this.model.beat_tracker import BeatThis
```
Instantiating this class will give you an untrained model from spectrograms to frame-wise beat and downbeat logits. For a pretrained model, use `load_model`:
```
from beat_this.inference import load_model
beat_this = load_model('final0', device='cuda')
```
### From torch.hub
To quickly try the model without installing the package, just install the [requirements for inference](#requirements) and do:
```
import torch
beat_this = torch.hub.load('CPJKU/beat_this', 'beat_this', 'final0', device='cuda')
```
### Copy and paste
To copy the BeatThis model into your own project, you will need the [`beat_tracker.py`](beat_this/model/beat_tracker.py) and [`roformer.py`](beat/this/model/roformer.py) files. If you remove the `BeatThis.state_dict()` and `BeatThis._load_from_state_dict()` methods that serve as a workaround for compiled models, then there are no other internal dependencies, only external dependencies (`einops`, `rotary-embedding-torch`).
## Citation
```bibtex
@inproceedings{foscarin2024beatthis,
author = {Francesco Foscarin and Jan Schl{\"u}ter and Gerhard Widmer},
title = {Beat this! Accurate beat tracking without {DBN} postprocessing},
year = 2024,
month = nov,
booktitle = {Proceedings of the 25th International Society for Music Information Retrieval Conference (ISMIR)},
address = {San Francisco, CA, United States},
}
```
================================================
FILE: beat_this/__init__.py
================================================
================================================
FILE: beat_this/cli.py
================================================
#!/usr/bin/env python3
"""
Beat This! command line inference tool.
"""
import argparse
import sys
from pathlib import Path
import numpy as np
import torch
try:
import tqdm
except ImportError:
tqdm = None
from beat_this.inference import File2File, load_audio
from beat_this.utils import save_beat_tsv
def get_parser():
parser = argparse.ArgumentParser(
description="Detects beats in given audio files with a Beat This! model."
)
parser.add_argument(
"inputs",
type=str,
nargs="+",
help="An audio file to process, or a directory of such files. Can be given multiple times.",
)
parser.add_argument(
"--model",
type=str,
help="Name, path or URL of checkpoint to use, will be downloaded if needed (default:%(default)s).",
default="final0",
)
parser.add_argument(
"--output",
"-o",
type=str,
default=None,
help="Output file name for a single input file, or output directory for multiple input files. If omitted, outputs are saved next to each input file by replacing or appending a suffix (see --suffix and --append).",
)
parser.add_argument(
"--suffix",
"-s",
type=str,
default=".beats",
help="Suffix for output file names (default: %(default)s). Also see --append. Ignored if an explicit output file name is given.",
)
parser.add_argument(
"--append",
action="store_true",
help="If given, append suffix to output file names instead of replacing the existing suffix. Ignored if an explicit output file name is given.",
)
parser.add_argument(
"--skip-existing",
action="store_true",
help="If given, do not overwrite existing output files, but skip them.",
)
parser.add_argument(
"--touch-first",
action="store_true",
help="If given, create empty output file before processing. Combined with --skip-existing, allows to run multiple processes in parallel on the same set of files.",
)
parser.add_argument(
"--dbn",
default=False,
action=argparse.BooleanOptionalAction,
help="Override the option to use madmom's postprocessing DBN.",
)
parser.add_argument(
"--gpu",
type=int,
default=0,
help="Which GPU to use (not the number of GPUs), or -1 for CPU. Ignored if CUDA is not available. (default: %(default)s)",
)
parser.add_argument(
"--float16",
action="store_true",
help="If given, uses half precision floating point arithmetics. Required for flash attention on GPU. (default: %(default)s)",
)
parser.add_argument(
"--activations",
action="store_true",
help="If given, saves the raw activations with a .npy suffix.",
)
return parser
def derive_output_path(input_path, suffix, append, output=None, parent=None):
"""
Determine the output file name for `input_path` using the given
suffix. If given, `output` is the base directory for outputs, and
`parent` is the directory that was given on the command line.
"""
# output directory
if output is None:
output_path = input_path
else:
if parent is not None:
input_path = input_path.relative_to(parent)
else:
input_path = input_path.name
output_path = output / input_path
# suffix
if append:
return output_path.parent / (output_path.name + suffix)
else:
return output_path.with_suffix(suffix)
def run(
inputs,
model,
output,
suffix,
append,
skip_existing,
touch_first,
dbn,
gpu,
float16,
activations,
):
# determine device
if torch.cuda.is_available() and gpu >= 0:
device = torch.device(f"cuda:{gpu}")
else:
device = torch.device("cpu")
# prepare model
file2file = File2File(model, device, float16, dbn)
if activations:
def process(audiofile, outfile):
wav, sr = load_audio(audiofile)
spect = file2file.signal2spect(wav, sr)
beat_logits, downbeat_logits = file2file.spect2frames(spect)
np.save(
outfile.with_suffix(".npy"),
np.vstack([beat_logits.cpu().numpy(), downbeat_logits.cpu().numpy()]),
)
beats, downbeats = file2file.frames2beats(beat_logits, downbeat_logits)
save_beat_tsv(beats, downbeats, outfile)
else:
process = file2file
# process inputs
inputs = [Path(item) for item in inputs]
if output is not None:
output = Path(output)
if len(inputs) == 1 and not inputs[0].is_dir():
# special case: single input file
if output is None or output.is_dir():
output = derive_output_path(inputs[0], suffix, append, output)
process(inputs[0], output)
else:
# multiple inputs: first collect tasks so we can have a progress bar
tasks = []
for item in inputs:
if item.is_dir():
for fn in item.rglob("*"):
if not fn.name.endswith(suffix) and not fn.is_dir():
output_path = derive_output_path(
fn, suffix, append, output, parent=item
)
if not skip_existing or not output_path.exists():
tasks.append((fn, output_path))
else:
tasks.append((item, derive_output_path(item, suffix, append, output)))
# then process all of them
if tqdm is not None:
tasks = tqdm.tqdm(tasks)
for item, output in tasks:
if touch_first:
try:
output.touch(exist_ok=not skip_existing)
except FileExistsError:
continue
elif skip_existing and output.exists():
continue
try:
process(item, output)
except Exception:
print(
f'Could not process "{item}". Rerun with this file alone for details.',
file=sys.stderr,
)
def main():
run(**vars(get_parser().parse_args()))
if __name__ == "__main__":
sys.exit(main())
================================================
FILE: beat_this/dataset/__init__.py
================================================
from beat_this.dataset.dataset import BeatDataModule
================================================
FILE: beat_this/dataset/augment.py
================================================
import numpy as np
import torch
def augment_pitchtempo(item, augmentations):
"""
Apply a randomly chosen pitch or tempo augmentation to the item.
Parameters:
item: dict
A dictionary representing the item to be augmented. It should contain the following keys:
- 'spect_path': The path to the the unaugmented spectrogram file.
If pitch or tempo augmentation is applied, the 'spect_path' key will be updated.
augmentations: dict
A dictionary containing the augmentations to be applied. It can contain either or both of the following keys:
- 'pitch': A dictionary with 'min' and 'max' keys specifying the range of pitch shifting in semitones.
- 'tempo': A dictionary with 'min' and 'max' keys specifying the range of time stretching factors.
Returns:
item: dict
The item after applying the augmentation. If a pitch or tempo augmentation was applied, the 'spect_path' key
and the annotations will be updated.
"""
# Handle pitch and tempo augmentations
if "pitch" in augmentations and "tempo" in augmentations:
# if both pitch and tempo are enabled, pick one of them
if np.random.randint(2) == 0:
# pitch
item = augment_pitch(item, augmentations["pitch"])
else:
# tempo
item = augment_tempo(item, augmentations["tempo"])
elif "pitch" in augmentations:
item = augment_pitch(item, augmentations["pitch"])
elif "tempo" in augmentations:
item = augment_tempo(item, augmentations["tempo"])
return item
def augment_pitch(item, pitch_params):
"""Apply pitch shifting to the item."""
semitones = np.random.randint(pitch_params["min"], pitch_params["max"] + 1)
item = shift_filename(item, semitones)
item = shift_annotations(item, semitones)
return item
def augment_tempo(item, tempo_params):
"""Apply time stretching to the item."""
percentage = np.random.choice(
np.arange(tempo_params["min"], tempo_params["max"] + 1, tempo_params["stride"])
)
item = stretch_filename(item, percentage)
item = stretch_annotations(item, percentage)
return item
def stretch_annotations(item, percentage):
"""Apply time stretching to the item's annotations."""
if not percentage:
return item
# percentage is the amount by which the *tempo* changes
factor = 1.0 + percentage / 100
item = dict(item)
item["beat_time"] = item["beat_time"] / factor
return item
def shift_annotations(item, semitones):
"""Apply pitch shifting to the item's annotations."""
return item
def stretch_filename(item, percentage):
"""Derive filename of precomputed time stretched version."""
spect_path = item["spect_path"]
if percentage:
stem = spect_path.stem + f"_ts{percentage}"
spect_path = spect_path.with_stem(stem)
return {**item, "spect_path": spect_path}
def shift_filename(item, semitones):
"""Derive filename of precomputed pitch shifted version."""
spect_path = item["spect_path"]
if semitones:
stem = spect_path.stem + f"_ps{semitones}"
spect_path = spect_path.with_stem(stem)
return {**item, "spect_path": spect_path}
def number_of_precomputed_augmentations(augmentations):
"""Return the number of augmentations that correspond to a precomputed file."""
counter = 1
for method, params in augmentations.values():
if method in ("pitch"):
counter += params["max"] - params["min"]
elif method in ("tempo"):
counter += (params["max"] - params["min"]) // params["stride"]
return counter
def precomputed_augmentation_filenames(augmentations, ext="npy"):
"""Return the filenames of the precomputed augmentations.
Parameters:
augmentations: dict
A dictionary containing the augmentations to be applied. It can contain either or both of the following keys:
- 'pitch': A dictionary with 'min' and 'max' keys specifying the range (including boundaries) of pitch shifting in semitones.
- 'tempo': A dictionary with 'min' and 'max' keys specifying the range (including boundaries) of time stretching factors; and a 'stride' key specifying the step size.
"""
filenames = [f"track.{ext}"]
for method, params in augmentations.items():
if method == "pitch":
for semitones in range(params["min"], params["max"] + 1):
if semitones == 0:
continue
filenames.append(f"track_ps{semitones}.{ext}")
elif method == "tempo":
for percentage in range(params["min"], params["max"] + 1, params["stride"]):
if percentage == 0:
continue
filenames.append(f"track_ts{percentage}.{ext}")
return filenames
def augment_mask_(spect, augmentations: dict, fps: int):
"""
Apply the given masking operations to the spectrogram. The spectrogram is modified in place.
Parameters:
spect: ndarray
The input spectrogram to which the mask will be applied. It is a 2D array where the first dimension
represents time frames and the second dimension represents frequency bins.
augmentations: dict
A dictionary containing all the augmentations. If there is no "mask" key, this function returns the
unmodified spectrogram. If "mask" key is present, the value is another dictionary which must include
the following keys:
- 'kind': The type of mask to apply. Choices: 'permute' and 'zero'.
- 'min_count' and 'max_count': The minimum and maximum number of times the mask should be applied.
- 'min_len' and 'max_len': The minimum and maximum length of the mask, expressed in seconds.
- 'min_parts' and 'max_parts': The minimum and maximum number of parts in which each masked section is segmented.
These are then randomly reordered. If 'kind'='permute' this parameter is not used.
fps: int
The frames per second of the audio. This is used to convert 'min_len' and 'max_len' from seconds to frames.
Returns:
spect: ndarray
The spectrogram after applying the mask.
"""
if "mask" in augmentations:
mask_params = augmentations["mask"]
count = np.random.randint(
mask_params["min_count"], mask_params["max_count"] + 1
)
# convert min_len and max_len in frames
min_len = int(mask_params["min_len"] * fps)
max_len = int(mask_params["max_len"] * fps)
# apply the masking a number of time specified by count
for _ in range(count):
length = np.random.randint(min_len, max_len + 1)
start = np.random.randint(0, len(spect) - length)
apply_mask_excerpt(
spect[start : start + length],
mask_params["kind"],
mask_params["min_parts"],
mask_params["max_parts"],
)
return spect
def apply_mask_excerpt(excerpt, kind, min_parts, max_parts):
"""Apply a mask operation of the given kind in-place to the given tensor."""
if kind == "permute":
num_parts = np.random.randint(min_parts, max_parts + 1)
choices = len(excerpt)
num_parts = min(num_parts, choices + 1)
positions = np.random.choice(choices, num_parts - 1, replace=False)
positions.sort()
if isinstance(excerpt, np.ndarray):
parts = np.split(excerpt, positions)
else:
parts = (
[excerpt[: positions[0]]]
+ [excerpt[a:b] for a, b in zip(positions[:-1], positions[1:])]
+ [excerpt[positions[-1] :]]
)
parts = [parts[idx] for idx in np.random.permutation(num_parts)]
if isinstance(excerpt, np.ndarray):
excerpt[:] = np.concatenate(parts)
else:
excerpt[:] = torch.cat(parts)
elif kind == "zero":
excerpt[:] = 0
else:
raise ValueError(f"Unsupported mask operation: {kind}")
================================================
FILE: beat_this/dataset/dataset.py
================================================
import concurrent.futures
import itertools
import json
import re
from pathlib import Path
import numpy as np
import pandas as pd
import pytorch_lightning as pl
import torch
from torch.utils.data import DataLoader, Dataset
from beat_this.dataset.augment import (
augment_mask_,
augment_pitchtempo,
precomputed_augmentation_filenames,
)
from beat_this.utils import index_to_framewise
from .mmnpz import MemmappedNpzFile
class BeatTrackingDataset(Dataset):
"""
A PyTorch Dataset for beat tracking. This dataset loads preprocessed spectrograms and beat annotations
from a given data folder and provides them for training or evaluation.
Args:
item_names (list of str): A list of dataset items such as "gtzan/gtzan_rock_00099".
data_folder (Path or str): The base folder where the data is stored.
spect_fps (int, optional): The frames per second of the spectrograms. Defaults to 50.
train_length (int, optional): The length of the training sequences in frames. If None the entire piece is used. Defaults to 1500.
deterministic (bool, optional): If True, the dataset always returns the same sequence for a given index.
Defaults to False.
augmentations (dict, optional): A dictionary of data augmentations to apply. Possible keys are "tempo", "pitch", and "mask". Defaults to an empty dictionary.
"""
def __init__(
self,
item_names: list[str],
data_folder,
spect_fps=50,
train_length=1500,
deterministic=False,
augmentations={},
length_based_oversampling_factor=0,
):
self.spect_basepath = data_folder / "audio" / "spectrograms"
self.annotation_basepath = data_folder / "annotations"
self.fps = spect_fps
self.train_length = train_length
self.deterministic = deterministic
self.augmentations = augmentations
self.length_based_oversampling_factor = length_based_oversampling_factor
datasets = sorted(set(name.split("/", 1)[0] for name in item_names))
# load dataset info
self.dataset_info = self._load_dataset_infos(datasets)
# load .npz spectrogram bundles, if any
self.spects = self._load_spect_bundles(datasets)
# load the annotations in parallel
with concurrent.futures.ThreadPoolExecutor() as executor:
items = executor.map(self._load_dataset_item, item_names)
items = [item for item in items if item is not None]
if self.length_based_oversampling_factor and self.train_length is not None:
# oversample the dataset according to the audio lengths, so that long pieces are sampled more often
oversampled_items = []
for item in items:
oversampling_factor = np.round(
self.length_based_oversampling_factor
* len(self._get_spect(item))
/ self.train_length
).astype(int)
oversampling_factor = max(oversampling_factor, 1)
oversampled_items.extend(itertools.repeat(item, oversampling_factor))
print(
f"Training set oversampled from {len(items)} to {len(oversampled_items)} excerpts."
)
items = oversampled_items
self.items = items
def _load_dataset_infos(self, datasets):
dataset_info = {}
for dataset in datasets:
with open(self.annotation_basepath / dataset / "info.json") as f:
dataset_info[dataset] = json.load(f)
return dataset_info
def _load_spect_bundles(self, datasets):
spects = {}
for dataset in datasets:
npz_file = (self.spect_basepath / dataset).with_suffix(".npz")
if npz_file.exists():
spects[dataset] = MemmappedNpzFile(npz_file)
return spects
def _load_dataset_item(self, item_name):
# stop if not all the augmented audio files are there
dataset, remainder = item_name.split("/", 1)
for aug_filename in precomputed_augmentation_filenames(self.augmentations):
if (f"{remainder}/{aug_filename[:-4]}") not in self.spects.get(
dataset, ()
) and not (self.spect_basepath / item_name / aug_filename).exists():
print(
f"Skipping {item_name} because not all necessary spectrograms are there."
)
return
# load beat and produce a default if beat values are not found
dataset, stem = item_name.split("/", 1)
annotation_path = (
self.annotation_basepath
/ dataset
/ "annotations"
/ "beats"
/ (stem + ".beats")
)
beat_annotation = np.loadtxt(annotation_path)
if beat_annotation.ndim == 2:
beat_time = beat_annotation[:, 0]
beat_value = beat_annotation[:, 1].astype(int)
else:
beat_time = beat_annotation
beat_value = np.zeros_like(beat_time, dtype=np.int32)
# stop if the annotations that are supposed to be there are not there
if self.dataset_info[dataset]["has_downbeats"]:
if beat_annotation.ndim != 2:
print(
f"Skipping {item_name} because it has {beat_annotation.ndim} columns but downbeat is supposed to be there."
)
return
# create a downbeat mask to handle the case where the downbeat is not annotated
downbeat_mask = self.dataset_info[dataset]["has_downbeats"]
# take care of different subsections of rwc for the dataset name
if dataset == "rwc":
dataset = "rwc_" + stem.split("_", 2)[1]
return {
"spect_path": Path(item_name) / "track.npy",
"beat_time": beat_time,
"beat_value": beat_value,
"downbeat_mask": downbeat_mask,
"dataset": dataset,
}
def _get_spect(self, item):
try:
dataset, filename = str(item["spect_path"]).split("/", 1)
spect = self.spects[dataset][filename[:-4]]
except KeyError:
spect = np.load(self.spect_basepath / item["spect_path"], mmap_mode="r")
return spect
def get_frame_count(self, index):
"""Return number of frames of given item."""
return len(self._get_spect(self.items[index]))
def get_beat_count(self, index):
"""Return number of beats (including downbeats) of given item."""
return len(self.items[index]["beat_time"])
def get_downbeat_count(self, index):
"""Return number of downbeats of given item."""
return (self.items[index]["beat_value"] == 1).sum()
def __len__(self):
return len(self.items)
def __getitem__(self, index):
if isinstance(index, (int, np.int64)): # when index is a single int
item = self.items[index]
# select a pitch shift and time stretch
item = augment_pitchtempo(item, self.augmentations)
# load spectrogram
spect = self._get_spect(item)
# define the excerpt to use
original_length = len(spect)
if self.train_length is not None:
longer = original_length - self.train_length
else:
longer = 0
if longer > 0: # if the piece is longer than the desired length
if self.deterministic:
# select the middle of the excerpt
start_frame = longer // 2
else:
start_frame = np.random.randint(0, longer)
end_frame = start_frame + self.train_length
else:
start_frame = 0
end_frame = original_length
# obtain a view of the excerpt
spect = spect[start_frame:end_frame]
if "mask" in self.augmentations:
# copy the spectrogram and apply mask augmentation
spect = np.copy(spect)
spect = augment_mask_(spect, self.augmentations, self.fps)
else:
# only ensure we have a writeable array (so PyTorch is happy)
spect = np.require(spect, requirements="W")
# prepare annotations
(
framewise_truth_beat,
framewise_truth_downbeat,
truth_orig_beat,
truth_orig_downbeat,
) = prepare_annotations(item, start_frame, end_frame, self.fps)
# restructure the item dict with the correct training information
item = {
"spect": spect,
"spect_path": str(item["spect_path"]),
"dataset": item["dataset"],
"start_frame": start_frame,
"truth_beat": framewise_truth_beat,
"truth_downbeat": framewise_truth_downbeat,
"downbeat_mask": torch.as_tensor(item["downbeat_mask"]),
"padding_mask": (
np.ones(self.train_length, dtype=bool)
if self.train_length is not None
else np.ones(original_length, dtype=bool)
),
"truth_orig_beat": truth_orig_beat,
"truth_orig_downbeat": truth_orig_downbeat,
}
# pad all framewise tensors if needed
if longer < 0:
item["spect"] = np.pad(
item["spect"], [(0, -longer), (0, 0)], constant_values=0
)
for k in "truth_beat", "truth_downbeat":
item[k] = np.pad(item[k], [(0, -longer)], constant_values=0)
item["padding_mask"][longer:] = 0
return item
else: # when index is a list of ints
return [self[i] for i in index]
class BeatDataModule(pl.LightningDataModule):
"""
A PyTorch Lightning DataModule for beat tracking. This DataModule handles the loading and preprocessing of the
BeatTrackingDataset and prepares it for use with a PyTorch Lightning model.
It can produce cross-validation or single train/val/test splits.
Args:
data_dir (Path or str): The parent directory where the data (spectrograms and beat labels) is stored.
batch_size (int, optional): The size of the batches to be generated by the DataLoader. Defaults to 8.
train_length (int, optional): The length of the subsequences in frames. If None, the entire pieces are returner. Defaults to 1500.
num_workers (int, optional): The number of worker processes to use for data loading. Defaults to 20.
augmentations (dict, optional): A dictionary of data augmentations to apply. Defaults to {"pitch": {"min": -5, "max": 6}, "time": {"min": -20, "max": 20, "stride": 4}}.
test_dataset (str, optional): The name of the dataset to use for testing. Defaults to "gtzan".
hung_data (bool, optional): If True, only use the datasets from the Hung et al. paper for training; validation is still on all datasets. Defaults to False.
no_val (bool, optional): If True, train on all train+val data and do not use a validation set; for compatibility reason, the validation metrics are still computed, but are not meaningful. Defaults to False.
spect_fps (int, optional): The frames per second of the spectrograms. Defaults to 50.
length_based_oversampling_factor (int, optional): The factor by which to oversample the train dataset based on sequence length. Defaults to 0.
fold (int, optional): The fold number for cross-validation. If None, the single split is used. Defaults to None.
predict_datasplit (str, optional): The split to use for prediction. Prediction dataset is always full pieces. Defaults to "test".
"""
def __init__(
self,
data_dir,
batch_size=8,
train_length=1500,
num_workers=20,
augmentations={
"pitch": {"min": -5, "max": 6},
"tempo": {"min": -20, "max": 20, "stride": 4},
},
test_dataset="gtzan",
hung_data=False,
no_val=False,
spect_fps=50,
length_based_oversampling_factor=0,
fold=None,
predict_datasplit="test",
):
super().__init__()
self.save_hyperparameters()
self.initialized = {}
# remember all arguments
self.data_dir = Path(data_dir)
self.batch_size = batch_size
self.train_length = train_length
self.num_workers = num_workers
if not set(augmentations.keys()).issubset({"mask", "pitch", "tempo"}):
raise ValueError(f"Unsupported augmentations: {augmentations.keys()}")
self.augmentations = augmentations
self.test_set_name = test_dataset
self.hung_data = hung_data
self.no_val = no_val
self.spect_fps = spect_fps
self.length_based_oversampling_factor = length_based_oversampling_factor
self.fold = fold
self.predict_datasplit = predict_datasplit
def setup(self, stage):
if self.initialized.get(stage, False):
return
# set up the paths
annotation_dir = self.data_dir / "annotations"
# load train/val splits
if stage in ("fit", "validate"):
self.val_items = []
self.train_items = []
split_file = "8-folds.split" if self.fold is not None else "single.split"
for dataset_dir in annotation_dir.iterdir():
if not dataset_dir.is_dir() or not (dataset_dir / split_file).exists():
continue
dataset = dataset_dir.name
if dataset == self.test_set_name:
continue
split = pd.read_csv(
dataset_dir / split_file,
header=None,
names=["piece", "part"],
sep="\t",
)
if self.fold is not None:
# CV: use given fold for validation, rest for training
self.val_items.extend(
f"{dataset}/{stem}"
for stem in split.piece[split.part == self.fold]
)
self.train_items.extend(
f"{dataset}/{stem}"
for stem in split.piece[split.part != self.fold]
)
else:
# single split: marked as val and train
self.val_items.extend(
f"{dataset}/{stem}" for stem in split.piece[split.part == "val"]
)
self.train_items.extend(
f"{dataset}/{stem}"
for stem in split.piece[split.part == "train"]
)
if self.no_val:
# Train on all available data (excluding the test set).
# For compatibility, validation metrics are still computed
# on the original validation set now included in training.
self.train_items.extend(self.val_items)
if self.hung_data:
# Use the training datasets from MODELING BEATS AND DOWNBEATS
# WITH A TIME-FREQUENCY TRANSFORMER (for comparability, the
# validation set stays the same, with all datasets).
regexp = re.compile(
"^(hainsworth/|ballroom/|hjdb/|beatles/|rwc/rwc_popular|simac/|smc/|harmonix/|).*$"
)
self.train_items = [
item for item in self.train_items if regexp.match(item)
]
self.val_items.sort()
self.train_items.sort()
# load validation set
if stage in ("fit", "validate"):
self.val_dataset = BeatTrackingDataset(
self.val_items,
deterministic=True,
augmentations={},
train_length=self.train_length,
data_folder=self.data_dir,
spect_fps=self.spect_fps,
)
print(
"Validation set:",
len(self.val_dataset),
"items from:",
*sorted(set(item.split("/", 1)[0] for item in self.val_items)),
)
self.initialized["validate"] = True
# load training set
if stage == "fit":
self.train_dataset = BeatTrackingDataset(
self.train_items,
deterministic=False,
augmentations=self.augmentations,
train_length=self.train_length,
data_folder=self.data_dir,
spect_fps=self.spect_fps,
length_based_oversampling_factor=self.length_based_oversampling_factor,
)
print(
"Training set:",
len(self.train_dataset),
"items from:",
*sorted(set(item.split("/", 1)[0] for item in self.train_items)),
)
self.initialized["fit"] = True
# load test set
if stage == "test":
test_annotations_dir = (
annotation_dir / self.test_set_name / "annotations" / "beats"
)
self.test_items = sorted(
f"{self.test_set_name}/{item.stem}"
for item in test_annotations_dir.glob("*.beats")
)
self.test_dataset = BeatTrackingDataset(
self.test_items,
deterministic=True,
augmentations={},
train_length=None,
data_folder=self.data_dir,
spect_fps=self.spect_fps,
)
print(
"Test set:", len(self.test_dataset), "items from:", self.test_set_name
)
self.initialized["test"] = True
# load prediction set
if stage == "predict":
if self.predict_datasplit == "test":
self.setup("test")
# we can directly use the test dataset for predictions
self.predict_dataset = self.test_dataset
else:
if self.predict_datasplit == "train":
self.setup("fit")
items = self.train_items
elif self.predict_datasplit == "val":
self.setup("validate")
items = self.val_items
# for prediction, we want to use full items (train_length=None)
self.predict_dataset = BeatTrackingDataset(
items,
deterministic=True,
augmentations={},
train_length=None,
data_folder=self.data_dir,
spect_fps=self.spect_fps,
)
def train_dataloader(self):
return DataLoader(
self.train_dataset,
num_workers=self.num_workers,
batch_size=self.batch_size,
shuffle=True,
drop_last=True,
pin_memory=True,
)
def val_dataloader(self):
# Warning: for performances, this only runs on the middle excerpt of the long pieces
# The paper results are computed after training in the predict script
return DataLoader(
self.val_dataset, batch_size=self.batch_size, num_workers=self.num_workers
)
def test_dataloader(self):
return DataLoader(self.test_dataset, batch_size=1, num_workers=self.num_workers)
def predict_dataloader(self):
return DataLoader(
self.predict_dataset, batch_size=1, num_workers=self.num_workers
)
def get_train_positive_weights(self, widen_target_mask=3):
"""
Computes the relation of negative targets to positive targets.
`widen_target_mask` reduces the number of negative targets by the given
factor times the number of positive targets (for ignoring a number of
frames around each positive label).
For example a `widen_target_mask` of 3 will ignore 7 frames, 3 for each side plus the central.
"""
# find the positive weight for the loss as a ratio between (down)beat and non-(down)beat annotation
dataset = self.train_dataset
all_frames = all_frames_db = 0
for item in dataset.items:
frames = len(dataset._get_spect(item))
all_frames += frames
if item["downbeat_mask"]:
all_frames_db += frames
beat_frames = sum(len(item["beat_value"]) for item in dataset.items)
downbeat_frames = sum(
(item["beat_value"] == 1).sum()
for item in dataset.items
if item["downbeat_mask"]
)
return {
"beat": int(
np.round(
(all_frames - beat_frames * (widen_target_mask * 2 + 1))
/ beat_frames
)
),
"downbeat": int(
np.round(
(all_frames_db - downbeat_frames * (widen_target_mask * 2 + 1))
/ downbeat_frames
)
),
}
def prepare_annotations(item, start_frame, end_frame, fps):
truth_bdb_time = item["beat_time"]
truth_bdb_value = item["beat_value"]
# convert beat time from seconds to frame
truth_bdb_frame = (truth_bdb_time * fps).round().astype(int)
# form annotations excerpt
# filter out the annotations that are earlier than the start and shift left
truth_bdb_frame -= start_frame
idx = np.searchsorted(truth_bdb_frame, 0)
truth_bdb_frame = truth_bdb_frame[idx:]
truth_bdb_value = truth_bdb_value[idx:]
# filter out the annotations that are later than the end
idx = np.searchsorted(truth_bdb_frame, end_frame - start_frame)
truth_bdb_frame = truth_bdb_frame[:idx]
truth_bdb_value = truth_bdb_value[:idx]
# create beat and downbeat separated annotations
truth_beat = truth_bdb_frame
truth_downbeat = truth_bdb_frame[truth_bdb_value == 1]
# transform beat downbeat to frame-wise annotations
framewise_truth_beat = index_to_framewise(truth_beat, end_frame - start_frame)
framewise_truth_downbeat = index_to_framewise(
truth_downbeat, end_frame - start_frame
)
# create orig beat, downbeat annotations for unquantized evaluation
truth_orig_beat = item["beat_time"]
truth_orig_downbeat = truth_bdb_time[
item["beat_value"] == 1
] # (use the full beat_value)
# filter out the annotations that are outside the excerpt, and shift them left to the excerpt time
truth_orig_beat = truth_orig_beat[
(truth_orig_beat >= start_frame / fps) & (truth_orig_beat < end_frame / fps)
] - (start_frame / fps)
truth_orig_downbeat = truth_orig_downbeat[
(truth_orig_downbeat >= start_frame / fps)
& (truth_orig_downbeat < end_frame / fps)
] - (start_frame / fps)
# convert to strings (trick to collate sequences of different lengths)
truth_orig_beat = truth_orig_beat.tobytes()
truth_orig_downbeat = truth_orig_downbeat.tobytes()
return (
framewise_truth_beat,
framewise_truth_downbeat,
truth_orig_beat,
truth_orig_downbeat,
)
================================================
FILE: beat_this/dataset/mmnpz.py
================================================
"""
Support for memory-mapping uncompressed .npz files.
"""
import struct
from collections.abc import Mapping
from zipfile import ZipFile
import numpy as np
class MemmappedNpzFile(Mapping):
"""
A dictionary-like object with lazy-loading of numpy arrays in the given
uncompressed .npz file. Upon construction, creates a memory map of the
full .npz file, returning views for the arrays within on request.
Attributes
----------
files : list of str
List of all uncompressed files in the archive with a ``.npy`` extension
(listed without the extension). These are supported as dictionary keys.
mmap : np.memmap
The memory map of the full .npz file.
arrays : dict
Preloaded or cached arrays.
Parameters
----------
fn : str or Path
The zipped archive to open.
cache : bool, optional
Whether to cache array objects in case they are requested again.
preload : bool, optional
Whether to precreate all array objects upon opening. Enforces caching.
"""
def __init__(self, fn: str, cache: bool = True, preload: bool = False):
with ZipFile(fn, mode="r") as f:
self._offsets = {
zinfo.filename[:-4]: (zinfo.header_offset, zinfo.file_size)
for zinfo in f.infolist()
if zinfo.filename.endswith(".npy") and zinfo.compress_type == 0
}
self.files = list(self._offsets.keys())
self.mmap = np.memmap(fn, mode="r")
self.cache = cache or preload
self.preload = preload
if self.preload:
self.arrays = {name: self.load(name) for name in self.files}
else:
self.arrays = {}
def load(self, name: str):
header_offset, file_size = self._offsets[name]
# parse lengths of local header file name and extra fields
# (ZipInfo is based on the global directory, not local header)
fn_len, extra_len = struct.unpack(
"<2H", self.mmap[header_offset + 26 : header_offset + 30]
)
# compute offset of start and end of data
npy_start = header_offset + 30 + fn_len + extra_len
npy_end = npy_start + file_size
# read NPY header
fp = MemoryviewIO(self.mmap)
fp.seek(npy_start)
version = np.lib.format.read_magic(fp)
np.lib.format._check_version(version)
shape, fortran, dtype = np.lib.format._read_array_header(fp, version)
# produce slice of memmap
data_start = fp.tell()
return (
self.mmap[data_start:npy_end]
.view(dtype=dtype)
.reshape(shape, order="F" if fortran else "C")
)
def close(self):
if hasattr(self, "mmap"):
del self.mmap
self.arrays = {}
def __enter__(self):
return self
def __exit__(self, exc_type, exc_value, traceback):
self.close()
def __iter__(self):
return iter(self.files)
def __len__(self):
return len(self.files)
def __getitem__(self, key: str):
if self.cache:
try:
return self.arrays[key]
except KeyError:
pass
array = self.load(key)
if self.cache:
self.arrays[key] = array
return array
def __contains__(self, key: str):
# Mapping.__contains__ calls __getitem__, which could be expensive
return key in self._offsets
class MemoryviewIO(object):
"""
Wraps an object supporting the buffer protocol to be a readonly file-like.
"""
def __init__(self, buffer):
self._buffer = memoryview(buffer).cast("B")
self._pos = 0
self.seekable = lambda: True
self.readable = lambda: True
self.writable = lambda: False
def seek(self, offset, whence=0):
if whence == 0:
self._pos = offset
elif whence == 1:
self._pos += offset
elif whence == 2:
self._pos = self._buffer.nbytes + offset
def read(self, size=-1):
data = self._buffer[
self._pos : self._pos + size if size >= 0 else None
].tobytes()
self._pos += len(data)
return data
def tell(self):
return self._pos
================================================
FILE: beat_this/inference.py
================================================
import inspect
import numpy as np
import soxr
import torch
import torch.nn.functional as F
from beat_this.model.beat_tracker import BeatThis
from beat_this.model.postprocessor import Postprocessor
from beat_this.preprocessing import LogMelSpect, load_audio
from beat_this.utils import replace_state_dict_key, save_beat_tsv
CHECKPOINT_URL = "https://cloud.cp.jku.at/public.php/dav/files/7ik4RrBKTS273gp"
def load_checkpoint(checkpoint_path: str, device: str | torch.device = "cpu") -> dict:
"""
Load a BeatThis checkpoint as a dictionary.
Args:
checkpoint_path (str, optional): The path to the checkpoint. Can be a local path, a URL, or a shortname.
device (torch.device or str): The device to load the model on.
Returns:
dict: The loaded checkpoint dictionary.
"""
try:
# try interpreting as local file name
weights_only = {"weights_only": True} if torch.__version__ >= "2" else {}
return torch.load(checkpoint_path, map_location=device, **weights_only)
except FileNotFoundError:
try:
if not (
str(checkpoint_path).startswith("https://")
or str(checkpoint_path).startswith("http://")
):
# interpret it as a name of one of our checkpoints
checkpoint_url = f"{CHECKPOINT_URL}/{checkpoint_path}.ckpt"
file_name = f"beat_this-{checkpoint_path}.ckpt"
else:
# try interpreting as a URL
checkpoint_url = checkpoint_path
file_name = None
return torch.hub.load_state_dict_from_url(
checkpoint_url,
file_name=file_name,
map_location=device,
)
except Exception:
raise ValueError(
"Could not load the checkpoint given the provided name",
checkpoint_path,
)
def load_model(
checkpoint_path: str | None = "final0", device: str | torch.device = "cpu"
) -> BeatThis:
"""
Load a BeatThis model from a checkpoint.
Args:
checkpoint_path (str, optional): The path to the checkpoint. Can be a local path, a URL, or a shortname.
device (torch.device or str): The device to load the model on.
Returns:
BeatThis: The loaded model.
"""
if checkpoint_path is not None:
checkpoint = load_checkpoint(checkpoint_path, device)
# Retrieve the model hyperparameters as it could be the small model
hparams = checkpoint["hyper_parameters"]
# Filter only those hyperparameters that apply to the model itself
hparams = {
k: v
for k, v in hparams.items()
if k in set(inspect.signature(BeatThis).parameters)
}
# Create the uninitialized model
model = BeatThis(**hparams)
# The PLBeatThis (LightningModule) state_dict contains the BeatThis
# state_dict under the "model." prefix; remove the prefix to load it
state_dict = replace_state_dict_key(checkpoint["state_dict"], "model.", "")
model.load_state_dict(state_dict)
else:
model = BeatThis()
return model.to(device).eval()
def zeropad(spect: torch.Tensor, left: int = 0, right: int = 0):
"""
Pads a tensor spectrogram matrix of shape (time x bins) with `left` frames in the beginning and `right` frames in the end.
"""
if left == 0 and right == 0:
return spect
else:
return F.pad(spect, (0, 0, left, right), "constant", 0)
def split_piece(
spect: torch.Tensor,
chunk_size: int,
border_size: int = 6,
avoid_short_end: bool = True,
):
"""
Split a tensor spectrogram matrix of shape (time x bins) into time chunks of `chunk_size` and return the chunks and starting positions.
The `border_size` is the number of frames assumed to be discarded in the predictions on either side (since the model was not trained on the input edges due to the max-pool in the loss).
To cater for this, the first and last chunk are padded by `border_size` on the beginning and end, respectively, and consecutive chunks overlap by `border_size`.
If `avoid_short_end` is true, the last chunk start is shifted left to ends at the end of the piece, therefore the last chunk can potentially overlap with previous chunks more than border_size, otherwise it will be a shorter segment.
If the piece is shorter than `chunk_size`, avoid_short_end is ignored and the piece is returned as a single shorter chunk.
Args:
spect (torch.Tensor): The input spectrogram tensor of shape (time x bins).
chunk_size (int): The size of the chunks to produce.
border_size (int, optional): The size of the border to overlap between chunks. Defaults to 6.
avoid_short_end (bool, optional): If True, the last chunk is shifted left to end at the end of the piece. Defaults to True.
"""
# generate the start and end indices
starts = np.arange(
-border_size, len(spect) - border_size, chunk_size - 2 * border_size
)
if avoid_short_end and len(spect) > chunk_size - 2 * border_size:
# if we avoid short ends, move the last index to the end of the piece - (chunk_size - border_size)
starts[-1] = len(spect) - (chunk_size - border_size)
# generate the chunks
chunks = [
zeropad(
spect[max(start, 0) : min(start + chunk_size, len(spect))],
left=max(0, -start),
right=max(0, min(border_size, start + chunk_size - len(spect))),
)
for start in starts
]
return chunks, starts
def aggregate_prediction(
pred_chunks: list,
starts: list,
full_size: int,
chunk_size: int,
border_size: int,
overlap_mode: str,
device: str | torch.device,
) -> tuple[torch.Tensor, torch.Tensor]:
"""
Aggregates the predictions for the whole piece based on the given prediction chunks.
Args:
pred_chunks (list): List of prediction chunks, where each chunk is a dictionary containing 'beat' and 'downbeat' predictions.
starts (list): List of start positions for each prediction chunk.
full_size (int): Size of the full piece.
chunk_size (int): Size of each prediction chunk.
border_size (int): Size of the border to be discarded from each prediction chunk.
overlap_mode (str): Mode for handling overlapping predictions. Can be 'keep_first' or 'keep_last'.
device (torch.device): Device to be used for the predictions.
Returns:
tuple: A tuple containing the aggregated beat predictions and downbeat predictions as torch tensors for the whole piece.
"""
if border_size > 0:
# cut the predictions to discard the border
pred_chunks = [
{
"beat": pchunk["beat"][border_size:-border_size],
"downbeat": pchunk["downbeat"][border_size:-border_size],
}
for pchunk in pred_chunks
]
# aggregate the predictions for the whole piece
piece_prediction_beat = torch.full((full_size,), -1000.0, device=device)
piece_prediction_downbeat = torch.full((full_size,), -1000.0, device=device)
if overlap_mode == "keep_first":
# process in reverse order, so predictions of earlier excerpts overwrite later ones
pred_chunks = reversed(list(pred_chunks))
starts = reversed(list(starts))
for start, pchunk in zip(starts, pred_chunks):
piece_prediction_beat[
start + border_size : start + chunk_size - border_size
] = pchunk["beat"]
piece_prediction_downbeat[
start + border_size : start + chunk_size - border_size
] = pchunk["downbeat"]
return piece_prediction_beat, piece_prediction_downbeat
def split_predict_aggregate(
spect: torch.Tensor,
chunk_size: int,
border_size: int,
overlap_mode: str,
model: torch.nn.Module,
) -> dict:
"""
Function for pieces that are longer than the training length of the model.
Split the input piece into chunks, run the model on them, and aggregate the predictions.
The spect is supposed to be a torch tensor of shape (time x bins), i.e., unbatched, and the output is also unbatched.
Args:
spect (torch.Tensor): the input piece
chunk_size (int): the length of the chunks
border_size (int): the size of the border that is discarded from the predictions
overlap_mode (str): how to handle overlaps between chunks
model (torch.nn.Module): the model to run
Returns:
dict: the model framewise predictions for the hole piece as a dictionary containing 'beat' and 'downbeat' predictions.
"""
# split the piece into chunks
chunks, starts = split_piece(
spect, chunk_size, border_size=border_size, avoid_short_end=True
)
# run the model
pred_chunks = [model(chunk.unsqueeze(0)) for chunk in chunks]
# remove the extra dimension in beat and downbeat prediction due to batch size 1
pred_chunks = [
{"beat": p["beat"][0], "downbeat": p["downbeat"][0]} for p in pred_chunks
]
piece_prediction_beat, piece_prediction_downbeat = aggregate_prediction(
pred_chunks,
starts,
spect.shape[0],
chunk_size,
border_size,
overlap_mode,
spect.device,
)
# save it to model_prediction
return {"beat": piece_prediction_beat, "downbeat": piece_prediction_downbeat}
class Spect2Frames:
"""
Class for extracting framewise beat and downbeat predictions (logits) from a spectrogram.
"""
def __init__(self, checkpoint_path="final0", device="cpu", float16=False):
super().__init__()
self.device = torch.device(device)
self.float16 = float16
self.model = load_model(checkpoint_path, self.device)
def spect2frames(self, spect):
with torch.inference_mode():
with torch.autocast(enabled=self.float16, device_type=self.device.type):
model_prediction = split_predict_aggregate(
spect=spect,
chunk_size=1500,
overlap_mode="keep_first",
border_size=6,
model=self.model,
)
return model_prediction["beat"].float(), model_prediction["downbeat"].float()
def __call__(self, spect):
return self.spect2frames(spect)
class Audio2Frames(Spect2Frames):
"""
Class for extracting framewise beat and downbeat predictions (logits) from an audio tensor.
"""
def __init__(self, checkpoint_path="final0", device="cpu", float16=False):
super().__init__(checkpoint_path, device, float16)
self.spect = LogMelSpect(device=self.device)
def signal2spect(self, signal, sr):
if signal.ndim == 2:
signal = signal.mean(1)
elif signal.ndim != 1:
raise ValueError(f"Expected 1D or 2D signal, got shape {signal.shape}")
if sr != 22050:
signal = soxr.resample(signal, in_rate=sr, out_rate=22050)
signal = torch.tensor(signal, dtype=torch.float32, device=self.device)
return self.spect(signal)
def __call__(self, signal, sr):
spect = self.signal2spect(signal, sr)
return self.spect2frames(spect)
class Audio2Beats(Audio2Frames):
"""
Class for extracting beat and downbeat positions (in seconds) from an audio tensor.
Args:
checkpoint_path (str): Path to the model checkpoint file. It can be a local path, a URL, or a key from the CHECKPOINT_URL dictionary. Default is "final0", which will load the model trained on all data except GTZAN with seed 0.
device (str): Device to use for inference. Default is "cpu".
float16 (bool): Whether to use half precision floating point arithmetic. Default is False.
dbn (bool): Whether to use the madmom DBN for post-processing. Default is False.
"""
def __init__(
self, checkpoint_path="final0", device="cpu", float16=False, dbn=False
):
super().__init__(checkpoint_path, device, float16)
self.frames2beats = Postprocessor(type="dbn" if dbn else "minimal")
def __call__(self, signal, sr):
beat_logits, downbeat_logits = super().__call__(signal, sr)
return self.frames2beats(beat_logits, downbeat_logits)
class File2Beats(Audio2Beats):
def __call__(self, audio_path):
signal, sr = load_audio(audio_path)
return super().__call__(signal, sr)
class File2File(File2Beats):
def __call__(self, audio_path, output_path):
downbeats, beats = super().__call__(audio_path)
save_beat_tsv(downbeats, beats, output_path)
================================================
FILE: beat_this/model/__init__.py
================================================
================================================
FILE: beat_this/model/beat_tracker.py
================================================
"""
Model definitions for the Beat This! beat tracker.
"""
import contextlib
from collections import OrderedDict
import torch
from einops import rearrange
from einops.layers.torch import Rearrange
from rotary_embedding_torch import RotaryEmbedding
from torch import nn
from beat_this.model import roformer
from beat_this.utils import replace_state_dict_key
class BeatThis(nn.Module):
"""
A neural network model for beat tracking. It is composed of three main components:
- a frontend that processes the input spectrogram,
- a series of transformer blocks that process the output of the frontend,
- a head that produces the final beat and downbeat predictions.
Args:
spect_dim (int): The dimension of the input spectrogram (default: 128).
transformer_dim (int): The dimension of the main transformer blocks (default: 512).
ff_mult (int): The multiplier for the feed-forward dimension in the transformer blocks (default: 4).
n_layers (int): The number of transformer blocks (default: 6).
head_dim (int): The dimension of each attention head for the partial transformers in the frontend and the transformer blocks (default: 32).
stem_dim (int): The out dimension of the stem convolutional layer (default: 32).
dropout (dict): A dictionary specifying the dropout rates for different parts of the model
(default: {"frontend": 0.1, "transformer": 0.2}).
sum_head (bool): Whether to use a SumHead for the final predictions (default: True) or plain independent projections.
partial_transformers (bool): Whether to include partial frequency- and time-transformers in the frontend (default: True)
"""
def __init__(
self,
spect_dim: int = 128,
transformer_dim: int = 512,
ff_mult: int = 4,
n_layers: int = 6,
head_dim: int = 32,
stem_dim: int = 32,
dropout: dict = {"frontend": 0.1, "transformer": 0.2},
sum_head: bool = True,
partial_transformers: bool = True,
):
super().__init__()
# shared rotary embedding for frontend blocks and transformer blocks
rotary_embed = RotaryEmbedding(head_dim)
# create the frontend
# - stem
stem = self.make_stem(spect_dim, stem_dim)
spect_dim //= 4 # frequencies were convolved with stride 4
# - three frontend blocks
frontend_blocks = []
dim = stem_dim
for _ in range(3):
frontend_blocks.append(
self.make_frontend_block(
dim,
dim * 2,
partial_transformers,
head_dim,
rotary_embed,
dropout["frontend"],
)
)
dim *= 2
spect_dim //= 2 # frequencies were convolved with stride 2
frontend_blocks = nn.Sequential(*frontend_blocks)
# - linear projection to transformer dimensionality
concat = Rearrange("b c f t -> b t (c f)")
linear = nn.Linear(dim * spect_dim, transformer_dim)
self.frontend = nn.Sequential(
OrderedDict(stem=stem, blocks=frontend_blocks, concat=concat, linear=linear)
)
# create the transformer blocks
assert (
transformer_dim % head_dim == 0
), "transformer_dim must be divisible by head_dim"
n_heads = transformer_dim // head_dim
self.transformer_blocks = roformer.Transformer(
dim=transformer_dim,
depth=n_layers,
heads=n_heads,
attn_dropout=dropout["transformer"],
ff_dropout=dropout["transformer"],
rotary_embed=rotary_embed,
ff_mult=ff_mult,
dim_head=head_dim,
norm_output=True,
)
# create the output heads
if sum_head:
self.task_heads = SumHead(transformer_dim)
else:
self.task_heads = Head(transformer_dim)
# init all weights
self.apply(self._init_weights)
@staticmethod
def make_stem(spect_dim: int, stem_dim: int) -> nn.Module:
return nn.Sequential(
OrderedDict(
rearrange_tf=Rearrange("b t f -> b f t"),
bn1d=nn.BatchNorm1d(spect_dim),
add_channel=Rearrange("b f t -> b 1 f t"),
conv2d=nn.Conv2d(
in_channels=1,
out_channels=stem_dim,
kernel_size=(4, 3),
stride=(4, 1),
padding=(0, 1),
bias=False,
),
bn2d=nn.BatchNorm2d(stem_dim),
activation=nn.GELU(),
)
)
@staticmethod
def make_frontend_block(
in_dim: int,
out_dim: int,
partial_transformers: bool = True,
head_dim: int | None = 32,
rotary_embed: RotaryEmbedding | None = None,
dropout: float = 0.1,
) -> nn.Module:
if partial_transformers and (head_dim is None or rotary_embed is None):
raise ValueError(
"Must specify head_dim and rotary_embed for using partial_transformers"
)
return nn.Sequential(
OrderedDict(
partial=(
PartialFTTransformer(
dim=in_dim,
dim_head=head_dim,
n_head=in_dim // head_dim,
rotary_embed=rotary_embed,
dropout=dropout,
)
if partial_transformers
else nn.Identity()
),
# conv block
conv2d=nn.Conv2d(
in_channels=in_dim,
out_channels=out_dim,
kernel_size=(2, 3),
stride=(2, 1),
padding=(0, 1),
bias=False,
),
# out_channels : 64, 128, 256
# freqs : 16, 8, 4 (due to the stride=2)
norm=nn.BatchNorm2d(out_dim),
activation=nn.GELU(),
)
)
@staticmethod
def _init_weights(module: nn.Module):
if isinstance(module, (nn.Linear, nn.Conv1d)):
torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
if module.bias is not None:
torch.nn.init.zeros_(module.bias)
elif isinstance(module, nn.Conv2d):
torch.nn.init.kaiming_normal_(
module.weight, mode="fan_out", nonlinearity="relu"
)
if module.bias is not None:
torch.nn.init.zeros_(module.bias)
elif isinstance(module, nn.Embedding):
torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
if module.padding_idx is not None:
with torch.no_grad():
module.weight[module.padding_idx].fill_(0)
def forward(self, x):
x = self.frontend(x)
x = self.transformer_blocks(x)
x = self.task_heads(x)
return x
def _load_from_state_dict(self, state_dict, prefix, *args, **kwargs):
# remove _orig_mod prefixes for compiled models
state_dict = replace_state_dict_key(state_dict, "_orig_mod.", "")
super()._load_from_state_dict(state_dict, prefix, *args, **kwargs)
def state_dict(self, *args, **kwargs):
state_dict = super().state_dict(*args, **kwargs)
# remove _orig_mod prefixes for compiled models
state_dict = replace_state_dict_key(state_dict, "_orig_mod.", "")
return state_dict
class PartialRoformer(nn.Module):
"""
Takes a (batch, channels, freqs, time) input, applies self-attention and
a feed-forward block either only across frequencies or only across time.
Returns a tensor of the same shape as the input.
"""
def __init__(
self,
dim: int,
dim_head: int,
n_head: int,
direction: str,
rotary_embed: RotaryEmbedding,
dropout: float,
):
super().__init__()
assert dim % dim_head == 0, "dim must be divisible by dim_head"
assert dim // dim_head == n_head, "n_head must be equal to dim // dim_head"
self.direction = direction[0].lower()
if self.direction not in "ft":
raise ValueError(f"direction must be F or T, got {direction}")
self.attn = roformer.Attention(
dim,
heads=n_head,
dim_head=dim_head,
dropout=dropout,
rotary_embed=rotary_embed,
)
self.ff = roformer.FeedForward(dim, dropout=dropout)
def forward(self, x):
b = len(x)
if self.direction == "f":
pattern = "(b t) f c"
elif self.direction == "t":
pattern = "(b f) t c"
x = rearrange(x, f"b c f t -> {pattern}")
x = x + self.attn(x)
x = x + self.ff(x)
x = rearrange(x, f"{pattern} -> b c f t", b=b)
return x
class PartialFTTransformer(nn.Module):
"""
Takes a (batch, channels, freqs, time) input, applies self-attention and
a feed-forward block once across frequencies and once across time. Same
as applying two PartialRoformer() in sequence, but encapsulated in a single
module. Returns a tensor of the same shape as the input.
"""
def __init__(
self,
dim: int,
dim_head: int,
n_head: int,
rotary_embed: RotaryEmbedding,
dropout: float,
):
super().__init__()
assert dim % dim_head == 0, "dim must be divisible by dim_head"
assert dim // dim_head == n_head, "n_head must be equal to dim // dim_head"
# frequency directed partial transformer
self.attnF = roformer.Attention(
dim,
heads=n_head,
dim_head=dim_head,
dropout=dropout,
rotary_embed=rotary_embed,
)
self.ffF = roformer.FeedForward(dim, dropout=dropout)
# time directed partial transformer
self.attnT = roformer.Attention(
dim,
heads=n_head,
dim_head=dim_head,
dropout=dropout,
rotary_embed=rotary_embed,
)
self.ffT = roformer.FeedForward(dim, dropout=dropout)
def forward(self, x):
b = len(x)
# frequency directed partial transformer
x = rearrange(x, "b c f t -> (b t) f c")
x = x + self.attnF(x)
x = x + self.ffF(x)
# time directed partial transformer
x = rearrange(x, "(b t) f c ->(b f) t c", b=b)
x = x + self.attnT(x)
x = x + self.ffT(x)
x = rearrange(x, "(b f) t c -> b c f t", b=b)
return x
class SumHead(nn.Module):
"""
A PyTorch module that produces the final beat and downbeat prediction logits.
The beats are a sum of all beats and all downbeats predictions, to reduce the prediction
of downbeats which are not beats.
"""
def __init__(self, input_dim):
super().__init__()
self.beat_downbeat_lin = nn.Linear(input_dim, 2)
def forward(self, x):
beat_downbeat = self.beat_downbeat_lin(x)
# separate beat from downbeat
beat, downbeat = rearrange(beat_downbeat, "b t c -> c b t", c=2)
# aggregate beats and downbeats prediction
# autocast to float16 disabled to avoid numerical issues causing NaNs
if hasattr(
torch.amp, "is_autocast_available"
) and not torch.amp.is_autocast_available(beat.device.type):
# but do not try disabling if the device does not support autocast
disable_autocast = contextlib.nullcontext()
else:
disable_autocast = torch.autocast(beat.device.type, enabled=False)
with disable_autocast:
beat = beat.float() + downbeat.float()
return {"beat": beat, "downbeat": downbeat}
class Head(nn.Module):
"""
A PyToch module that produces the final beat and downbeat prediction logits with independent linear layers outputs.
"""
def __init__(self, input_dim):
super().__init__()
self.beat_downbeat_lin = nn.Linear(input_dim, 2)
def forward(self, x):
beat_downbeat = self.beat_downbeat_lin(x)
# separate beat from downbeat
beat, downbeat = rearrange(beat_downbeat, "b t c -> c b t", c=2)
return {"beat": beat, "downbeat": downbeat}
================================================
FILE: beat_this/model/loss.py
================================================
"""
Loss definitions for the Beat This! beat tracker.
"""
import torch
import torch.nn.functional as F
class MaskedBCELoss(torch.nn.Module):
"""
Plain binary cross-entropy loss. Expects predictions to be given as logits,
and accepts an optional mask with zeros indicating the entries to ignore.
Args:
pos_weight (float): Weight for positive examples compared to negative
examples (default: 1)
"""
def __init__(self, pos_weight: float = 1):
super().__init__()
self.register_buffer(
"pos_weight",
torch.tensor(pos_weight, dtype=torch.get_default_dtype()),
persistent=False,
)
def forward(
self,
preds: torch.Tensor,
targets: torch.Tensor,
mask: torch.Tensor | None = None,
):
return F.binary_cross_entropy_with_logits(
preds, targets, weight=mask, pos_weight=self.pos_weight
)
class ShiftTolerantBCELoss(torch.nn.Module):
"""
BCE loss variant for sequence labeling that tolerates small shifts between
predictions and targets. This is accomplished by max-pooling the
predictions with a given tolerance and a stride of 1, so the gradient for a
positive label affects the largest prediction in a window around it.
Expects predictions to be given as logits, and accepts an optional mask
with zeros indicating the entries to ignore. Note that the edges of the
sequence will not receive a gradient, as it is assumed to be unknown
whether there is a nearby positive annotation.
Args:
pos_weight (float): Weight for positive examples compared to negative
examples (default: 1)
tolerance (int): Tolerated shift in time steps in each direction
(default: 3)
"""
def __init__(self, pos_weight: float = 1, tolerance: int = 3):
super().__init__()
self.register_buffer(
"pos_weight",
torch.tensor(pos_weight, dtype=torch.get_default_dtype()),
persistent=False,
)
self.tolerance = tolerance
def spread(self, x: torch.Tensor, factor: int = 1):
if self.tolerance == 0:
return x
return F.max_pool1d(x, 1 + 2 * factor * self.tolerance, 1)
def crop(self, x: torch.Tensor, factor: int = 1):
return x[..., factor * self.tolerance : -factor * self.tolerance or None]
def forward(
self,
preds: torch.Tensor,
targets: torch.Tensor,
mask: torch.Tensor | None = None,
):
# spread preds and crop targets to match
spreaded_preds = self.crop(self.spread(preds))
cropped_targets = self.crop(targets, factor=2)
# ignore around the positive targets
look_at = cropped_targets + (1 - self.spread(targets, factor=2))
if mask is not None: # consider padding and no-downbeat mask
look_at = look_at * self.crop(mask, factor=2)
# compute loss
return F.binary_cross_entropy_with_logits(
spreaded_preds,
cropped_targets,
weight=look_at,
pos_weight=self.pos_weight,
)
class SplittedShiftTolerantBCELoss(torch.nn.Module):
"""
Alternative implementation of ShiftTolerantBCELoss that splits the loss for
positive and negative targets. This is mainly provided as it may be a bit
easier to understand and compare with the Beat This! paper. Note that for
non-binary targets (e.g., with label smoothing), this implementation
matches the equation in the paper (Section 3.3), while ShiftTolerantBCELoss
deviates from it. For binary targets, the results are identical.
Args:
pos_weight (int): weight of positive targets
spread_preds (int): amount of temporal max-pooling applied to predictions
"""
def __init__(self, pos_weight: float = 1, tolerance: int = 3):
super().__init__()
self.tolerance = 3
self.spread_preds = tolerance
self.spread_targets = 2 * tolerance # targets are always spreaded twice as much
self.register_buffer(
"pos_weight",
torch.tensor(pos_weight, dtype=torch.get_default_dtype()),
persistent=False,
)
def spread(self, x: torch.Tensor, amount: int):
if amount:
return F.max_pool1d(x, 1 + 2 * amount, 1)
else:
return x
def crop(self, x: torch.Tensor, desired_length: int):
amount = (x.shape[-1] - desired_length) // 2
if amount > 0:
return x[..., amount:-amount]
elif amount == 0:
return x
else:
raise ValueError("Desired length must be smaller than input length")
def forward(self, preds: torch.Tensor, targets: torch.Tensor, mask: torch.Tensor):
output_length = targets.size(-1) - 2 * self.spread_targets
# compute loss for positive targets, we spread preds
preds = self.spread(preds, self.spread_preds)
# we crop preds and targets (and mask) to ignore problems at the edges due to the maxpool operation
cropped_preds = self.crop(preds, output_length)
cropped_targets = self.crop(targets, output_length)
cropped_mask = self.crop(mask, output_length)
loss_positive = F.binary_cross_entropy_with_logits(
cropped_preds,
cropped_targets,
weight=cropped_targets * cropped_mask,
pos_weight=self.pos_weight,
)
# compute loss for negative targets, we spread targets and preds (already spreaded above)
targets = self.spread(targets, self.spread_targets)
cropped_targets = self.crop(targets, output_length)
loss_negative = F.binary_cross_entropy_with_logits(
cropped_preds,
cropped_targets,
weight=(1 - cropped_targets) * cropped_mask,
pos_weight=self.pos_weight, # ensures identical results to the other implementation
)
# sum the two losses
return loss_positive + loss_negative
================================================
FILE: beat_this/model/pl_module.py
================================================
"""
Pytorch Lightning module, wraps a BeatThis model along with losses, metrics and
optimizers for training.
"""
from concurrent.futures import ThreadPoolExecutor
from typing import Any
import mir_eval
import numpy as np
import torch
from pytorch_lightning import LightningModule
import beat_this.model.loss
from beat_this.inference import split_predict_aggregate
from beat_this.model.beat_tracker import BeatThis
from beat_this.model.postprocessor import Postprocessor
from beat_this.utils import replace_state_dict_key
class PLBeatThis(LightningModule):
def __init__(
self,
spect_dim=128,
fps=50,
transformer_dim=512,
ff_mult=4,
n_layers=6,
stem_dim=32,
dropout={"frontend": 0.1, "transformer": 0.2},
lr=0.0008,
weight_decay=0.01,
pos_weights={"beat": 1, "downbeat": 1},
head_dim=32,
loss_type="shift_tolerant_weighted_bce",
warmup_steps=1000,
max_epochs=100,
use_dbn=False,
eval_trim_beats=5,
sum_head=True,
partial_transformers=True,
):
super().__init__()
self.save_hyperparameters()
self.lr = lr
self.weight_decay = weight_decay
self.fps = fps
# create model
self.model = BeatThis(
spect_dim=spect_dim,
transformer_dim=transformer_dim,
ff_mult=ff_mult,
stem_dim=stem_dim,
n_layers=n_layers,
head_dim=head_dim,
dropout=dropout,
sum_head=sum_head,
partial_transformers=partial_transformers,
)
self.warmup_steps = warmup_steps
self.max_epochs = max_epochs
# set up the losses
self.pos_weights = pos_weights
if loss_type == "shift_tolerant_weighted_bce":
self.beat_loss = beat_this.model.loss.ShiftTolerantBCELoss(
pos_weight=pos_weights["beat"]
)
self.downbeat_loss = beat_this.model.loss.ShiftTolerantBCELoss(
pos_weight=pos_weights["downbeat"]
)
elif loss_type == "weighted_bce":
self.beat_loss = beat_this.model.loss.MaskedBCELoss(
pos_weight=pos_weights["beat"]
)
self.downbeat_loss = beat_this.model.loss.MaskedBCELoss(
pos_weight=pos_weights["downbeat"]
)
elif loss_type == "bce":
self.beat_loss = beat_this.model.loss.MaskedBCELoss()
self.downbeat_loss = beat_this.model.loss.MaskedBCELoss()
elif loss_type == "splitted_shift_tolerant_weighted_bce":
self.beat_loss = beat_this.model.loss.SplittedShiftTolerantBCELoss(
pos_weight=pos_weights["beat"]
)
self.downbeat_loss = beat_this.model.loss.SplittedShiftTolerantBCELoss(
pos_weight=pos_weights["downbeat"]
)
else:
raise ValueError(
"loss_type must be one of 'shift_tolerant_weighted_bce', 'weighted_bce', 'bce'"
)
self.postprocessor = Postprocessor(
type="dbn" if use_dbn else "minimal", fps=fps
)
self.eval_trim_beats = eval_trim_beats
self.metrics = Metrics(eval_trim_beats=eval_trim_beats)
def _compute_loss(self, batch, model_prediction):
beat_mask = batch["padding_mask"]
beat_loss = self.beat_loss(
model_prediction["beat"], batch["truth_beat"].float(), beat_mask
)
# downbeat mask considers padding and also pieces which don't have downbeat annotations
downbeat_mask = beat_mask * batch["downbeat_mask"][:, None]
downbeat_loss = self.downbeat_loss(
model_prediction["downbeat"], batch["truth_downbeat"].float(), downbeat_mask
)
# sum the losses and return them in a dictionary for logging
return {
"beat": beat_loss,
"downbeat": downbeat_loss,
"total": beat_loss + downbeat_loss,
}
def _compute_metrics(self, batch, postp_beat, postp_downbeat, step="val"):
""" """
# compute for beat
metrics_beat = self._compute_metrics_target(
batch, postp_beat, target="beat", step=step
)
# compute for downbeat
metrics_downbeat = self._compute_metrics_target(
batch, postp_downbeat, target="downbeat", step=step
)
# concatenate dictionaries
metrics = {**metrics_beat, **metrics_downbeat}
return metrics
def _compute_metrics_target(self, batch, postp_target, target, step):
def compute_item(pospt_pred, truth_orig_target):
# take the ground truth from the original version, so there are no quantization errors
piece_truth_time = np.frombuffer(truth_orig_target)
# run evaluation
metrics = self.metrics(piece_truth_time, pospt_pred, step=step)
return metrics
# if the input was not batched, postp_target is an array instead of a tuple of arrays
# make it a tuple for consistency
if not isinstance(postp_target, tuple):
postp_target = (postp_target,)
with ThreadPoolExecutor() as executor:
piecewise_metrics = list(
executor.map(
compute_item,
postp_target,
batch[f"truth_orig_{target}"],
)
)
# average the beat metrics across the dictionary
batch_metric = {
key + f"_{target}": np.mean([x[key] for x in piecewise_metrics])
for key in piecewise_metrics[0].keys()
}
return batch_metric
def log_losses(self, losses, batch_size, step="train"):
# log for separate targets
for target in "beat", "downbeat":
self.log(
f"{step}_loss_{target}",
losses[target].item(),
prog_bar=False,
on_step=False,
on_epoch=True,
batch_size=batch_size,
sync_dist=True,
)
# log total loss
self.log(
f"{step}_loss",
losses["total"].item(),
prog_bar=True,
on_step=False,
on_epoch=True,
batch_size=batch_size,
sync_dist=True,
)
def log_metrics(self, metrics, batch_size, step="val"):
for key, value in metrics.items():
self.log(
f"{step}_{key}",
value,
prog_bar=key.startswith("F-measure"),
on_step=False,
on_epoch=True,
batch_size=batch_size,
sync_dist=True,
)
def training_step(self, batch, batch_idx):
# run the model
model_prediction = self.model(batch["spect"])
# compute loss
losses = self._compute_loss(batch, model_prediction)
self.log_losses(losses, len(batch["spect"]), "train")
return losses["total"]
def validation_step(self, batch, batch_idx):
# run the model
model_prediction = self.model(batch["spect"])
# compute loss
losses = self._compute_loss(batch, model_prediction)
# postprocess the predictions
postp_beat, postp_downbeat = self.postprocessor(
model_prediction["beat"],
model_prediction["downbeat"],
batch["padding_mask"],
)
# compute the metrics
metrics = self._compute_metrics(batch, postp_beat, postp_downbeat, step="val")
# log
self.log_losses(losses, len(batch["spect"]), "val")
self.log_metrics(metrics, batch["spect"].shape[0], "val")
def test_step(self, batch, batch_idx):
metrics, model_prediction, _, _ = self.predict_step(batch, batch_idx)
losses = self._compute_loss(batch, model_prediction)
# log
self.log_losses(losses, len(batch["spect"]), "test")
self.log_metrics(metrics, batch["spect"].shape[0], "test")
def predict_step(
self,
batch: Any,
batch_idx: int,
dataloader_idx: int = 0,
chunk_size: int = 1500,
overlap_mode: str = "keep_first",
) -> Any:
"""
Compute predictions and metrics for a batch (a dictionary with an "spect" key).
It splits up the audio into multiple chunks of chunk size,
which should correspond to the length of the sequence the model was trained with.
Potential overlaps between chunks can be handled in two ways:
by keeping the predictions of the excerpt coming first (overlap_mode='keep_first'), or
by keeping the predictions of the excerpt coming last (overlap_mode='keep_last').
Note that overlaps appear as the last excerpt is moved backwards
when it would extend over the end of the piece.
"""
if batch["spect"].shape[0] != 1:
raise ValueError(
"When predicting full pieces, only `batch_size=1` is supported"
)
if torch.any(~batch["padding_mask"]):
raise ValueError(
"When predicting full pieces, the Dataset must not pad inputs"
)
# compute border size according to the loss type
if hasattr(
self.beat_loss, "tolerance"
): # discard the edges that are affected by the max-pooling in the loss
border_size = 2 * self.beat_loss.tolerance
else:
border_size = 0
model_prediction = split_predict_aggregate(
batch["spect"][0], chunk_size, border_size, overlap_mode, self.model
)
# add the batch dimension back in the prediction for consistency
model_prediction = {
key: value.unsqueeze(0) for key, value in model_prediction.items()
}
# postprocess the predictions
postp_beat, postp_downbeat = self.postprocessor(
model_prediction["beat"], model_prediction["downbeat"], None
)
# compute the metrics
metrics = self._compute_metrics(batch, postp_beat, postp_downbeat, step="test")
return metrics, model_prediction, batch["dataset"], batch["spect_path"]
def configure_optimizers(self):
optimizer = torch.optim.AdamW
# only decay 2+-dimensional tensors, to exclude biases and norms
# (filtering on dimensionality idea taken from Kaparthy's nano-GPT)
params = [
{
"params": (
p for p in self.parameters() if p.requires_grad and p.ndim >= 2
),
"weight_decay": self.weight_decay,
},
{
"params": (
p for p in self.parameters() if p.requires_grad and p.ndim <= 1
),
"weight_decay": 0,
},
]
optimizer = optimizer(params, lr=self.lr)
self.lr_scheduler = CosineWarmupScheduler(
optimizer, self.warmup_steps, self.trainer.estimated_stepping_batches
)
result = dict(optimizer=optimizer)
result["lr_scheduler"] = {"scheduler": self.lr_scheduler, "interval": "step"}
return result
def _load_from_state_dict(self, state_dict, prefix, *args, **kwargs):
# remove _orig_mod prefixes for compiled models
state_dict = replace_state_dict_key(state_dict, "_orig_mod.", "")
super()._load_from_state_dict(state_dict, prefix, *args, **kwargs)
def state_dict(self, *args, **kwargs):
state_dict = super().state_dict(*args, **kwargs)
# remove _orig_mod prefixes for compiled models
state_dict = replace_state_dict_key(state_dict, "_orig_mod.", "")
return state_dict
class Metrics:
def __init__(self, eval_trim_beats: int) -> None:
self.min_beat_time = eval_trim_beats
def __call__(self, truth, preds, step) -> Any:
truth = mir_eval.beat.trim_beats(truth, min_beat_time=self.min_beat_time)
preds = mir_eval.beat.trim_beats(preds, min_beat_time=self.min_beat_time)
if (
step == "val"
): # limit the metrics that are computed during validation to speed up training
fmeasure = mir_eval.beat.f_measure(truth, preds)
cemgil = mir_eval.beat.cemgil(truth, preds)
return {"F-measure": fmeasure, "Cemgil": cemgil}
elif step == "test": # compute all metrics during testing
CMLc, CMLt, AMLc, AMLt = mir_eval.beat.continuity(truth, preds)
fmeasure = mir_eval.beat.f_measure(truth, preds)
cemgil = mir_eval.beat.cemgil(truth, preds)
return {"F-measure": fmeasure, "Cemgil": cemgil, "CMLt": CMLt, "AMLt": AMLt}
else:
raise ValueError("step must be either val or test")
class CosineWarmupScheduler(torch.optim.lr_scheduler._LRScheduler):
"""
Cosine annealing over `max_iters` steps with `warmup` linear warmup steps.
Optionally re-raises the learning rate for the final `raise_last` fraction
of total training time to `raise_to` of the full learning rate, again with
a linear warmup (useful for stochastic weight averaging).
"""
def __init__(self, optimizer, warmup, max_iters, raise_last=0, raise_to=0.5):
self.warmup = warmup
self.max_num_iters = int((1 - raise_last) * max_iters)
self.raise_to = raise_to
super().__init__(optimizer)
def get_lr(self):
lr_factor = self.get_lr_factor(step=self.last_epoch)
return [base_lr * lr_factor for base_lr in self.base_lrs]
def get_lr_factor(self, step):
if step < self.max_num_iters:
progress = step / self.max_num_iters
lr_factor = 0.5 * (1 + np.cos(np.pi * progress))
if step <= self.warmup:
lr_factor *= step / self.warmup
else:
progress = (step - self.max_num_iters) / self.warmup
lr_factor = self.raise_to * min(progress, 1)
return lr_factor
================================================
FILE: beat_this/model/postprocessor.py
================================================
from concurrent.futures import ThreadPoolExecutor
import numpy as np
import torch
import torch.nn.functional as F
from einops import rearrange
class Postprocessor:
"""Postprocessor for the beat and downbeat predictions of the model.
The postprocessor takes the (framewise) model predictions (beat and downbeats) and the padding mask,
and returns the postprocessed beat and downbeat as list of times in seconds.
The beats and downbeats can be 1D arrays (for only 1 piece) or 2D arrays, if a batch of pieces is considered.
The output dimensionality is the same as the input dimensionality.
Two types of postprocessing are implemented:
- minimal: a simple postprocessing that takes the maximum of the framewise predictions,
and removes adjacent peaks.
- dbn: a postprocessing based on the Dynamic Bayesian Network proposed by Böck et al.
Args:
type (str): the type of postprocessing to apply. Either "minimal" or "dbn". Default is "minimal".
fps (int): the frames per second of the model framewise predictions. Default is 50.
"""
def __init__(self, type: str = "minimal", fps: int = 50):
assert type in ["minimal", "dbn"]
self.type = type
self.fps = fps
if type == "dbn":
from madmom.features.downbeats import DBNDownBeatTrackingProcessor
self.dbn = DBNDownBeatTrackingProcessor(
beats_per_bar=[3, 4],
min_bpm=55.0,
max_bpm=215.0,
fps=self.fps,
transition_lambda=100,
)
def __call__(
self,
beat: torch.Tensor,
downbeat: torch.Tensor,
padding_mask: torch.Tensor | None = None,
) -> tuple[np.ndarray, np.ndarray]:
"""
Apply postprocessing to the input beat and downbeat tensors. Works with batched and unbatched inputs.
The output is a list of times in seconds, or a list of lists of times in seconds, if the input is batched.
Args:
beat (torch.Tensor): The input beat tensor.
downbeat (torch.Tensor): The input downbeat tensor.
padding_mask (torch.Tensor, optional): The padding mask tensor. Defaults to None.
Returns:
torch.Tensor: The postprocessed beat tensor.
torch.Tensor: The postprocessed downbeat tensor.
"""
batched = False if beat.ndim == 1 else True
if padding_mask is None:
padding_mask = torch.ones_like(beat, dtype=torch.bool)
# if beat and downbeat are 1D tensors, add a batch dimension
if not batched:
beat = beat.unsqueeze(0)
downbeat = downbeat.unsqueeze(0)
padding_mask = padding_mask.unsqueeze(0)
if self.type == "minimal":
postp_beat, postp_downbeat = self.postp_minimal(
beat, downbeat, padding_mask
)
elif self.type == "dbn":
postp_beat, postp_downbeat = self.postp_dbn(beat, downbeat, padding_mask)
else:
raise ValueError("Invalid postprocessing type")
# remove the batch dimension if it was added
if not batched:
postp_beat = postp_beat[0]
postp_downbeat = postp_downbeat[0]
# update the model prediction dict
return postp_beat, postp_downbeat
def postp_minimal(self, beat, downbeat, padding_mask):
# concatenate beat and downbeat in the same tensor of shape (B, T, 2)
packed_pred = rearrange(
[beat, downbeat], "c b t -> b t c", b=beat.shape[0], t=beat.shape[1], c=2
)
# set padded elements to -1000 (= probability zero even in float64) so they don't influence the maxpool
pred_logits = packed_pred.masked_fill(~padding_mask.unsqueeze(-1), -1000)
# reshape to (2*B, T) to apply max pooling
pred_logits = rearrange(pred_logits, "b t c -> (c b) t")
# pick maxima within +/- 70ms
pred_peaks = pred_logits.masked_fill(
pred_logits != F.max_pool1d(pred_logits, 7, 1, 3), -1000
)
# keep maxima with over 0.5 probability (logit > 0)
pred_peaks = pred_peaks > 0
# rearrange back to two tensors of shape (B, T)
beat_peaks, downbeat_peaks = rearrange(
pred_peaks, "(c b) t -> c b t", b=beat.shape[0], t=beat.shape[1], c=2
)
# run the piecewise operations
with ThreadPoolExecutor() as executor:
postp_beat, postp_downbeat = zip(
*executor.map(
self._postp_minimal_item, beat_peaks, downbeat_peaks, padding_mask
)
)
return postp_beat, postp_downbeat
def _postp_minimal_item(self, padded_beat_peaks, padded_downbeat_peaks, mask):
"""Function to compute the operations that must be computed piece by piece, and cannot be done in batch."""
# unpad the predictions by truncating the padding positions
beat_peaks = padded_beat_peaks[mask]
downbeat_peaks = padded_downbeat_peaks[mask]
# pass from a boolean array to a list of times in frames.
beat_frame = torch.nonzero(beat_peaks).cpu().numpy()[:, 0]
downbeat_frame = torch.nonzero(downbeat_peaks).cpu().numpy()[:, 0]
# remove adjacent peaks
beat_frame = deduplicate_peaks(beat_frame, width=1)
downbeat_frame = deduplicate_peaks(downbeat_frame, width=1)
# convert from frame to seconds
beat_time = beat_frame / self.fps
downbeat_time = downbeat_frame / self.fps
# move the downbeat to the nearest beat
if (
len(beat_time) > 0
): # skip if there are no beats, like in the first training steps
for i, d_time in enumerate(downbeat_time):
beat_idx = np.argmin(np.abs(beat_time - d_time))
downbeat_time[i] = beat_time[beat_idx]
# remove duplicate downbeat times (if some db were moved to the same position)
downbeat_time = np.unique(downbeat_time)
return beat_time, downbeat_time
def postp_dbn(self, beat, downbeat, padding_mask):
beat_prob = beat.double().sigmoid()
downbeat_prob = downbeat.double().sigmoid()
# limit lower and upper bound, since 0 and 1 create problems in the DBN
epsilon = 1e-5
beat_prob = beat_prob * (1 - epsilon) + epsilon / 2
downbeat_prob = downbeat_prob * (1 - epsilon) + epsilon / 2
with ThreadPoolExecutor() as executor:
postp_beat, postp_downbeat = zip(
*executor.map(
self._postp_dbn_item, beat_prob, downbeat_prob, padding_mask
)
)
return postp_beat, postp_downbeat
def _postp_dbn_item(self, padded_beat_prob, padded_downbeat_prob, mask):
"""Function to compute the operations that must be computed piece by piece, and cannot be done in batch."""
# unpad the predictions by truncating the padding positions
beat_prob = padded_beat_prob[mask]
downbeat_prob = padded_downbeat_prob[mask]
# build an artificial multiclass prediction, as suggested by Böck et al.
# again we limit the lower bound to avoid problems with the DBN
epsilon = 1e-5
combined_act = np.vstack(
(
np.maximum(
beat_prob.cpu().numpy() - downbeat_prob.cpu().numpy(), epsilon / 2
),
downbeat_prob.cpu().numpy(),
)
).T
# run the DBN
dbn_out = self.dbn(combined_act)
postp_beat = dbn_out[:, 0]
postp_downbeat = dbn_out[dbn_out[:, 1] == 1][:, 0]
return postp_beat, postp_downbeat
def deduplicate_peaks(peaks, width=1) -> np.ndarray:
"""
Replaces groups of adjacent peak frame indices that are each not more
than `width` frames apart by the average of the frame indices.
"""
result = []
peaks = map(int, peaks) # ensure we get ordinary Python int objects
try:
p = next(peaks)
except StopIteration:
return np.array(result)
c = 1
for p2 in peaks:
if p2 - p <= width:
c += 1
p += (p2 - p) / c # update mean
else:
result.append(p)
p = p2
c = 1
result.append(p)
return np.array(result)
================================================
FILE: beat_this/model/roformer.py
================================================
"""
Transformer with rotary position embedding, adapted from Phil Wang's repository
at https://github.com/lucidrains/BS-RoFormer (under MIT License).
"""
import torch
import torch.nn.functional as F
from einops import rearrange
from torch import nn
from torch.nn import Module, ModuleList
# helper functions
def exists(val):
return val is not None
# norm
class RMSNorm(Module):
def __init__(self, size, dim=-1):
super().__init__()
self.scale = size**0.5
if dim >= 0:
raise ValueError(f"dim must be negative, got {dim}")
self.gamma = nn.Parameter(torch.ones((size,) + (1,) * (abs(dim) - 1)))
self.dim = dim
def forward(self, x):
return F.normalize(x, dim=self.dim) * self.scale * self.gamma
# feedforward
class FeedForward(Module):
def __init__(
self,
dim,
mult=4,
dropout=0.0,
dim_out=None,
):
super().__init__()
if dim_out is None:
dim_out = dim
dim_inner = int(dim * mult)
self.activation = nn.GELU()
self.net = nn.Sequential(
RMSNorm(dim),
nn.Linear(dim, dim_inner),
self.activation,
nn.Dropout(dropout),
nn.Linear(dim_inner, dim_out),
nn.Dropout(dropout),
)
def forward(self, x):
return self.net(x)
# attention
class Attend(nn.Module):
def __init__(self, dropout=0.0, scale=None):
super().__init__()
self.dropout = dropout
self.scale = scale
def forward(self, q, k, v):
if exists(self.scale):
default_scale = q.shape[-1] ** -0.5
q = q * (self.scale / default_scale)
return F.scaled_dot_product_attention(
q, k, v, dropout_p=self.dropout if self.training else 0.0
)
class Attention(Module):
def __init__(
self,
dim,
heads=8,
dim_head=64,
dropout=0.0,
rotary_embed=None,
gating=True,
):
super().__init__()
self.heads = heads
self.scale = dim_head**-0.5
dim_inner = heads * dim_head
self.rotary_embed = rotary_embed
self.attend = Attend(dropout=dropout)
self.norm = RMSNorm(dim)
self.to_qkv = nn.Linear(dim, dim_inner * 3, bias=False)
if gating:
self.to_gates = nn.Linear(dim, heads)
else:
self.to_gates = None
self.to_out = nn.Sequential(
nn.Linear(dim_inner, dim, bias=False), nn.Dropout(dropout)
)
def forward(self, x):
x = self.norm(x)
q, k, v = rearrange(
self.to_qkv(x), "b n (qkv h d) -> qkv b h n d", qkv=3, h=self.heads
)
if exists(self.rotary_embed):
q = self.rotary_embed.rotate_queries_or_keys(q)
k = self.rotary_embed.rotate_queries_or_keys(k)
out = self.attend(q, k, v)
if exists(self.to_gates):
gates = self.to_gates(x)
out = out * rearrange(gates, "b n h -> b h n 1").sigmoid()
out = rearrange(out, "b h n d -> b n (h d)")
return self.to_out(out)
# Roformer
class Transformer(Module):
def __init__(
self,
*,
dim,
depth,
dim_head=32,
heads=16,
attn_dropout=0.1,
ff_dropout=0.1,
ff_mult=4,
norm_output=True,
rotary_embed=None,
gating=True,
):
super().__init__()
self.layers = ModuleList([])
for _ in range(depth):
ff = FeedForward(dim=dim, mult=ff_mult, dropout=ff_dropout)
self.layers.append(
ModuleList(
[
Attention(
dim=dim,
dim_head=dim_head,
heads=heads,
dropout=attn_dropout,
rotary_embed=rotary_embed,
gating=gating,
),
ff,
]
)
)
self.norm = RMSNorm(dim) if norm_output else nn.Identity()
def forward(self, x):
for attn, ff in self.layers:
x = attn(x) + x
x = ff(x) + x
x = self.norm(x)
return x
================================================
FILE: beat_this/preprocessing.py
================================================
import numpy as np
import torch
import torchaudio
def load_audio(path, dtype="float64"):
try:
waveform, samplerate = torchaudio.load(path, channels_first=False)
waveform = np.asanyarray(waveform.squeeze().numpy(), dtype=dtype)
return waveform, samplerate
except Exception:
# in case torchaudio fails, try soundfile
try:
import soundfile as sf
return sf.read(path, dtype=dtype)
except Exception:
# some files are not readable by soundfile, try madmom
try:
import madmom
return madmom.io.load_audio_file(str(path), dtype=dtype)
except Exception:
raise RuntimeError(f'Could not load audio from "{path}".')
class LogMelSpect(torch.nn.Module):
def __init__(
self,
sample_rate=22050,
n_fft=1024,
hop_length=441,
f_min=30,
f_max=11000,
n_mels=128,
mel_scale="slaney",
normalized="frame_length",
power=1,
log_multiplier=1000,
device="cpu",
):
super().__init__()
self.spect_class = torchaudio.transforms.MelSpectrogram(
sample_rate=sample_rate,
n_fft=n_fft,
hop_length=hop_length,
f_min=f_min,
f_max=f_max,
n_mels=n_mels,
mel_scale=mel_scale,
normalized=normalized,
power=power,
).to(device)
self.log_multiplier = log_multiplier
def forward(self, x):
"""Input is a waveform as a monodimensional array of shape T,
output is a 2D log mel spectrogram of shape (F,128)."""
return torch.log1p(self.log_multiplier * self.spect_class(x).T)
================================================
FILE: beat_this/utils.py
================================================
from itertools import chain
from pathlib import Path
import numpy as np
def index_to_framewise(index, length):
"""Convert an index to a framewise sequence"""
sequence = np.zeros(length, dtype=bool)
sequence[index] = True
return sequence
def filename_to_augmentation(filename):
"""Convert a filename to an augmentation factor."""
parts = Path(filename).stem.split("_")
augmentations = {}
for part in parts[1:]:
if part.startswith("ps"):
augmentations["shift"] = int(part[2:])
elif part.startswith("ts"):
augmentations["stretch"] = int(part[2:])
return augmentations
def infer_beat_numbers(beats: np.ndarray, downbeats: np.ndarray) -> np.ndarray:
"""
From beat and downbeat times, infer a number for each beat such that each downbeat
is associated with a 1 and beats in between are counted upwards.
The function requires that all downbeats are also listed as beats.
Args:
beats (numpy.ndarray): Array of beat positions in seconds (including downbeats).
downbeats (numpy.ndarray): Array of downbeat positions in seconds.
Returns:
numbers (numpy.ndarray): Array of integer beat numbers.
"""
# check if all downbeats are beats
if not np.all(np.isin(downbeats, beats)):
raise ValueError("Not all downbeats are beats.")
# handle pickup measure, by considering the beat count of the first full measure
if len(downbeats) >= 2:
# find the number of beats between the first two downbeats
first_downbeat, second_downbeat = np.searchsorted(beats, downbeats[:2])
beats_in_first_measure = second_downbeat - first_downbeat
# find the number of beats before the first downbeat
pickup_beats = first_downbeat
# derive where to start counting
if pickup_beats < beats_in_first_measure:
start_counter = beats_in_first_measure - pickup_beats
else:
print(
"WARNING: There are more beats in the pickup measure than in the first measure. The beat count will start from 2 without trying to estimate the length of the pickup measure."
)
start_counter = 1
else:
print(
"WARNING: There are less than two downbeats in the predictions. Something may be wrong. The beat count will start from 2 without trying to estimate the length of the pickup measure."
)
start_counter = 1
# assemble the beat numbers
numbers = []
counter = start_counter
downbeats = chain(downbeats, [-1])
next_downbeat = next(downbeats)
for beat in beats:
if beat == next_downbeat:
counter = 1
next_downbeat = next(downbeats)
else:
counter += 1
numbers.append(counter)
return np.asarray(numbers)
def save_beat_tsv(beats: np.ndarray, downbeats: np.ndarray, outpath: str) -> None:
"""
Save beat information to a tab-separated file in the standard .beats format:
each line has a time in seconds, a tab, and a beat number (1 = downbeat).
The function requires that all downbeats are also listed as beats.
Args:
beats (numpy.ndarray): Array of beat positions in seconds (including downbeats).
downbeats (numpy.ndarray): Array of downbeat positions in seconds.
outpath (str): Path to the output TSV file.
Returns:
None
"""
# infer beat numbers
numbers = infer_beat_numbers(beats, downbeats)
# write the beat file
Path(outpath).parent.mkdir(parents=True, exist_ok=True)
try:
with open(outpath, "w") as f:
f.writelines(f"{beat}\t{number}\n" for beat, number in zip(beats, numbers))
except KeyboardInterrupt:
outpath.unlink() # avoid half-written files
def replace_state_dict_key(state_dict: dict, old: str, new: str):
"""Replaces `old` in all keys of `state_dict` with `new`."""
keys = list(state_dict.keys()) # take snapshot of the keys
for key in keys:
if old in key:
state_dict[key.replace(old, new)] = state_dict.pop(key)
return state_dict
================================================
FILE: beat_this_example.ipynb
================================================
{
"nbformat": 4,
"nbformat_minor": 0,
"metadata": {
"colab": {
"provenance": [],
"authorship_tag": "ABX9TyOW4OkTmphTrvw2IQLr+kxP",
"include_colab_link": true
},
"kernelspec": {
"name": "python3",
"display_name": "Python 3"
},
"language_info": {
"name": "python"
}
},
"cells": [
{
"cell_type": "markdown",
"metadata": {
"id": "view-in-github",
"colab_type": "text"
},
"source": [
"<a href=\"https://colab.research.google.com/github/CPJKU/beat_this/blob/main/beat_this_example.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
]
},
{
"cell_type": "markdown",
"source": [
"# Beat This! inference example\n",
"\n",
"We first need to install and load the package."
],
"metadata": {
"id": "87X_GXfoGwmj"
}
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "sxhsMCKdLOLO",
"collapsed": true
},
"outputs": [],
"source": [
"# install the beat_this package\n",
"!pip install https://github.com/CPJKU/beat_this/archive/main.zip\n",
"# on Google Colab, this one is faster:\n",
"#!pip install --no-deps rotary-embedding-torch https://github.com/CPJKU/beat_this/archive/main.zip\n",
"\n",
"# load the Python class for beat tracking\n",
"from beat_this.inference import File2Beats\n",
"from beat_this.inference import File2File"
]
},
{
"cell_type": "markdown",
"source": [
"## Run on demo file\n",
"\n",
"Now that all the dependencies have been installed and imported, let's run our system.\n",
"\n",
"In the next cell we:\n",
"- define the audio file we want to use as input. For now we use the example provided in the beat_this repo, but this can be changed (see instructions later);\n",
"- load the File2Beats class that produce a list of beats and downbeats given an audio file;\n",
"- apply the class to the audio file\n",
"- print the position in seconds of the first 20 beats and first 20 downbeats.\n"
],
"metadata": {
"id": "_0oYbH6P6Ji7"
}
},
{
"cell_type": "code",
"source": [
"!wget -c \"https://github.com/CPJKU/beat_this/raw/main/tests/It%20Don't%20Mean%20A%20Thing%20-%20Kings%20of%20Swing.mp3\"\n",
"audio_path = \"/content/It Don't Mean A Thing - Kings of Swing.mp3\"\n",
"\n",
"file2beats = File2Beats(checkpoint_path=\"final0\", dbn=False)\n",
"beats, downbeats = file2beats(audio_path)\n",
"\n",
"print(\"First 20 beats\", beats[:20])\n",
"print(\"First 20 downbeats\", downbeats[:20])"
],
"metadata": {
"id": "DHT6v-a-TbZx"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "markdown",
"source": [
"We can sonify the beats and downbeats as click on top of the audio file."
],
"metadata": {
"id": "lRjJFiexDGdn"
}
},
{
"cell_type": "code",
"source": [
"import IPython.display as ipd\n",
"import librosa\n",
"import numpy as np\n",
"import soundfile as sf\n",
"\n",
"audio, sr = sf.read(audio_path)\n",
"# make it mono if stereo\n",
"if len(audio.shape) > 1:\n",
" audio = np.mean(audio, axis=1)\n",
"\n",
"# sonify the beats and downbeats\n",
"# remove the beats that are also downbeats for a nicer sonification\n",
"beats = [b for b in beats if b not in downbeats]\n",
"audio_beat = librosa.clicks(times = beats, sr=sr, click_freq=1000, length=len(audio))\n",
"audio_downbeat = librosa.clicks(times = downbeats, sr=sr, click_freq=1500, length=len(audio))\n",
"\n",
"ipd.display(ipd.Audio(audio + audio_beat + audio_downbeat, rate=sr))"
],
"metadata": {
"id": "otG0NS_uCXSo"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "markdown",
"source": [
"## Run on your own file\n",
"\n",
"If you want to run on your own audio files follow the following instructions:\n",
"1. Click on the folder icon in the left vertical menu.\n",
"2. Click on the \"Upload to session storage\" icon with the upward pointing arrow.\n",
"\n",
" This will add an audio file to the current colab runtime (it could take some time, and you may need to refresh the file manager using the dedicated button to see the new file). You can copy the audio path by clicking on the three dots next to the file, then \"copy path\".\n",
"\n",
" For example, if you upload a file called `my_song.mp3`, the path will be `/content/my_song.mp3`.\n",
"\n",
"3. change the `audio_path` in the cell above with the path of your uploaded audio"
],
"metadata": {
"id": "hn83Sn1pWmt5"
}
},
{
"cell_type": "markdown",
"source": [
"You can also produce a list of beat and downbeat as tsv file, that you can download and import in Sonic Visualizer.\n",
"\n",
"To do this this, use the File2File function as below:"
],
"metadata": {
"id": "kP2gyplIEcWT"
}
},
{
"cell_type": "code",
"source": [
"file2file = File2File(checkpoint_path=\"final0\", dbn=False)\n",
"file2file(audio_path,output_path=\"output.beats\")"
],
"metadata": {
"id": "kTQK-d4JEbL7"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "markdown",
"source": [
"As you can see, the system is fast enough to work in a reasonable time even on CPU.\n",
"\n",
"For even faster inference, you can start a GPU session in Colab!"
],
"metadata": {
"id": "1Y1d-DvXFtVz"
}
},
{
"cell_type": "markdown",
"source": [
"## Batch processing multiple files\n",
"\n",
"To process multiple of your own audio files, upload them as described above, then run the `beat_this` command line tool:"
],
"metadata": {
"id": "vpoM0RvQdAMF"
}
},
{
"cell_type": "code",
"source": [
"!beat_this --model final0 /content/"
],
"metadata": {
"id": "qNOLbBplc_Nq"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "markdown",
"source": [
"It will produce a `.beats` file for every audio file that you can download again."
],
"metadata": {
"id": "_xNY_9DEdSEt"
}
}
]
}
================================================
FILE: hubconf.py
================================================
dependencies = [
"torch",
"torchaudio",
"numpy",
"rotary_embedding_torch",
"einops",
"soxr",
]
from beat_this.inference import (
load_model as beat_this,
BeatThis,
Spect2Frames,
Audio2Frames,
Audio2Beats,
File2Beats,
File2File,
)
================================================
FILE: launch_scripts/clean_checkpoints.py
================================================
import argparse
from pathlib import Path
import torch
def main(args):
# check if output path exists
if Path(args.output_path).exists():
print(f"Output path {args.output_path} already exists. Exiting.")
return
# load the lightning checkpoit
checkpoint = torch.load(args.input_path, map_location="cpu")
# clean and keep only the keys "state_dict" and "datamodule" to save space
checkpoint = {
k: v
for k, v in checkpoint.items()
if k
in [
"state_dict",
"datamodule_hyper_parameters",
"hyper_parameters",
"pytorch-lightning_version",
]
}
# remove the "data_dir" key from "datamodule_hyper_parameters" because it is a
# Posix path and creates problems when loading in Windows.
if "data_dir" in checkpoint["datamodule_hyper_parameters"]:
del checkpoint["datamodule_hyper_parameters"]["data_dir"]
# save the cleaned checkpoint
torch.save(checkpoint, args.output_path)
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--input-path", type=str, required=True)
parser.add_argument("--output-path", type=str, required=True)
args = parser.parse_args()
main(args)
================================================
FILE: launch_scripts/compute_paper_metrics.py
================================================
#!/usr/bin/env python3
import argparse
from pathlib import Path
import numpy as np
from pytorch_lightning import Trainer, seed_everything
from beat_this.dataset import BeatDataModule
from beat_this.inference import load_checkpoint
from beat_this.model.pl_module import PLBeatThis
from beat_this.utils import infer_beat_numbers
# for repeatability
seed_everything(0, workers=True)
def main(args):
if len(args.models) == 1:
print("Single model prediction for", args.models[0])
# single model prediction
checkpoint_path = args.models[0]
checkpoint = load_checkpoint(checkpoint_path)
# create datamodule
datamodule = datamodule_setup(checkpoint, args.num_workers, args.datasplit)
# create model and trainer
model, trainer = plmodel_setup(
checkpoint, args.eval_trim_beats, args.dbn, args.gpu
)
# predict
metrics, dataset, preds, piece = compute_predictions(
model,
trainer,
datamodule.predict_dataloader(),
return_preds=args.dump_predictions,
)
# compute averaged metrics
averaged_metrics = {k: np.mean(v) for k, v in metrics.items()}
# compute metrics averaged by dataset
dataset_metrics = {
k: {d: np.mean(v[dataset == d]) for d in np.unique(dataset)}
for k, v in metrics.items()
}
# print for dataset
print("Metrics")
for k, v in averaged_metrics.items():
print(f"{k}: {v}")
print("Dataset metrics")
for k, v in dataset_metrics.items():
print(k)
for d, value in v.items():
print(f"{d}: {value}")
print("------")
# dump predictions
if args.dump_predictions:
write_predictions(args.dump_predictions, preds, piece)
else: # multiple models
if args.aggregation_type == "mean-std":
if args.dump_predictions:
print(
"cannot dump predictions when doing inference for multiple models"
)
return
# computing result variability for the same dataset and different model seeds
# create datamodule only once, as we assume it is the same for all models
checkpoint = load_checkpoint(args.models[0])
datamodule = datamodule_setup(checkpoint, args.num_workers, args.datasplit)
# create model and trainer
all_metrics = []
for checkpoint_path in args.models:
checkpoint = load_checkpoint(checkpoint_path)
model, trainer = plmodel_setup(
checkpoint, args.eval_trim_beats, args.dbn, args.gpu
)
metrics, dataset, preds, piece = compute_predictions(
model, trainer, datamodule.predict_dataloader()
)
# compute averaged metrics for one model
averaged_metrics = {k: np.mean(v) for k, v in metrics.items()}
all_metrics.append(averaged_metrics)
# compute mean and standard deviations for all model averages
all_metrics_mean = {
k: np.mean([m[k] for m in all_metrics]) for k in all_metrics[0]
}
all_metrics_std = {
k: np.std([m[k] for m in all_metrics]) for k in all_metrics[0]
}
all_metrics_stats = {
k: (all_metrics_mean[k], all_metrics_std[k])
for k, v in all_metrics[0].items()
}
# print all metrics
print("Metrics")
for k, v in all_metrics_stats.items():
# round to 3 decimal places
print(f"{k}: {round(v[0],3)} +- {round(v[1],3)}")
elif args.aggregation_type == "k-fold":
# computing results in the K-fold setting. Every fold has a different dataset
all_piece_metrics = []
all_piece_dataset = []
all_piece_preds = []
all_piece = []
# create datamodule for each model
for i_model, checkpoint_path in enumerate(args.models):
print(f"Model {i_model+1}/{len(args.models)}")
checkpoint = load_checkpoint(checkpoint_path)
datamodule = datamodule_setup(
checkpoint, args.num_workers, args.datasplit
)
# create model and trainer
model, trainer = plmodel_setup(
checkpoint, args.eval_trim_beats, args.dbn, args.gpu
)
# predict
metrics, dataset, preds, piece = compute_predictions(
model,
trainer,
datamodule.predict_dataloader(),
return_preds=args.dump_predictions,
)
all_piece_metrics.append(metrics)
all_piece_dataset.append(dataset)
all_piece_preds.extend(preds)
all_piece.append(piece)
# aggregate across folds
all_piece_metrics = {
k: np.concatenate([m[k] for m in all_piece_metrics])
for k in all_piece_metrics[0]
}
all_piece_dataset = np.concatenate(all_piece_dataset)
all_piece = np.concatenate(all_piece)
# double check that there are no errors in the fold and there are not repeated pieces
assert len(all_piece) == len(
np.unique(all_piece)
), "There are repeated pieces in the folds"
dataset_metrics = {
k: {
d: np.mean(v[all_piece_dataset == d])
for d in np.unique(all_piece_dataset)
}
for k, v in all_piece_metrics.items()
}
# print for dataset
print("Dataset metrics")
for k, v in dataset_metrics.items():
print(k)
for d, value in v.items():
print(f"{d}: {round(value,3)}")
print("------")
# dump predictions
if args.dump_predictions:
write_predictions(args.dump_predictions, all_piece_preds, all_piece)
else:
raise ValueError(f"Unknown aggregation type {args.aggregation_type}")
def datamodule_setup(checkpoint, num_workers, datasplit):
# Load the datamodule
print("Creating datamodule")
data_dir = Path(__file__).parent.parent.relative_to(Path.cwd()) / "data"
datamodule_hparams = checkpoint["datamodule_hyper_parameters"]
# update the hparams with the ones from the arguments
if num_workers is not None:
datamodule_hparams["num_workers"] = num_workers
datamodule_hparams["predict_datasplit"] = datasplit
datamodule_hparams["data_dir"] = data_dir
datamodule = BeatDataModule(**datamodule_hparams)
datamodule.setup(stage="predict")
return datamodule
def plmodel_setup(checkpoint, eval_trim_beats, dbn, gpu):
"""
Set up the pytorch lightning model and trainer for evaluation.
Args:
checkpoint_path (dict): The dict containing the checkpoint to load.
eval_trim_beats (int or None): The number of beats to trim during evaluation. If None, the setting is taken from the pretrained model.
dbn (bool or None): Whether to use the Dynamic Bayesian Network (DBN) module during evaluation. If None, the default behavior from the pretrained model is used.
gpu (int): The index of the GPU device to use for training.
Returns:
tuple: A tuple containing the initialized pytorch lightning model and trainer.
"""
if eval_trim_beats is not None:
checkpoint["hyper_parameters"]["eval_trim_beats"] = eval_trim_beats
if dbn is not None:
checkpoint["hyper_parameters"]["use_dbn"] = dbn
model = PLBeatThis(**checkpoint["hyper_parameters"])
model.load_state_dict(checkpoint["state_dict"])
# set correct device and accelerator
if gpu >= 0:
devices = [gpu]
accelerator = "gpu"
else:
devices = 1
accelerator = "cpu"
# create trainer
trainer = Trainer(
accelerator=accelerator,
devices=devices,
logger=None,
deterministic=True,
precision="16-mixed",
)
return model, trainer
def compute_predictions(model, trainer, predict_dataloader, return_preds=False):
print("Computing predictions ...")
out = trainer.predict(model, predict_dataloader)
metrics = [o[0] for o in out]
if return_preds:
preds = [model.postprocessor(o[1]["beat"][0], o[1]["downbeat"][0]) for o in out]
else:
preds = None
dataset = np.asarray([o[2][0] for o in out])
piece = np.asarray([o[3][0] for o in out])
# convert metrics from list of per-batch dictionaries to a single dictionary with np arrays as values
metrics = {k: np.asarray([m[k] for m in metrics]) for k in metrics[0]}
return metrics, dataset, preds, piece
def write_predictions(fn, preds, piece):
np.savez(
fn,
**{
name: np.vstack([beats, infer_beat_numbers(beats, downbeats)]).T
for name, (beats, downbeats) in zip(piece, preds)
},
)
if __name__ == "__main__":
parser = argparse.ArgumentParser(
description="Computes predictions for a given model and dataset, "
"prints metrics, and optionally dumps predictions to a given file."
)
parser.add_argument(
"--models",
type=str,
nargs="+",
required=True,
help="Local checkpoint files to use",
)
parser.add_argument(
"--datasplit",
type=str,
choices=("train", "val", "test"),
default="val",
help="data split to use: train, val or test " "(default: %(default)s)",
)
parser.add_argument("--gpu", type=int, default=0)
parser.add_argument(
"--num_workers", type=int, default=8, help="number of data loading workers "
)
parser.add_argument(
"--eval_trim_beats",
metavar="SECONDS",
type=float,
default=None,
help="Override whether to skip the first given seconds "
"per piece in evaluating (default: as stored in model)",
)
parser.add_argument(
"--dbn",
default=None,
action=argparse.BooleanOptionalAction,
help="override the option to use madmom postprocessing dbn",
)
parser.add_argument(
"--aggregation-type",
type=str,
choices=("mean-std", "k-fold"),
default="mean-std",
help="Type of aggregation to use for multiple models; ignored if only one model is given",
)
parser.add_argument(
"--dump-predictions",
metavar="FILENAME",
type=str,
default=None,
help="File to write predictions to, in .npz format (optional)",
)
args = parser.parse_args()
main(args)
================================================
FILE: launch_scripts/preprocess_audio.py
================================================
#!/usr/bin/env python3
import argparse
import concurrent.futures
import os
from pathlib import Path
from zipfile import ZipFile
import numpy as np
import pandas as pd
import soxr
import torch
import torchaudio
from pedalboard import Pedalboard, PitchShift, time_stretch
from tqdm import tqdm
from beat_this.dataset.augment import precomputed_augmentation_filenames
from beat_this.preprocessing import LogMelSpect, load_audio
os.environ["CUDA_VISIBLE_DEVICES"] = "-1"
BASEPATH = Path(__file__).parent.parent.relative_to(Path.cwd())
def save_audio(path, waveform, samplerate, resample_from=None):
if resample_from and resample_from != samplerate:
waveform = soxr.resample(waveform, in_rate=resample_from, out_rate=samplerate)
try:
waveform = torch.as_tensor(np.asarray(waveform, dtype=np.float64))
torchaudio.save(
path, torch.atleast_2d(waveform), samplerate, bits_per_sample=16
)
except KeyboardInterrupt:
path.unlink() # avoid half-written files
raise
def save_spectrogram(path, spectrogram, dtype=np.float16):
try:
np.save(path, np.asarray(spectrogram, dtype=dtype))
except KeyboardInterrupt:
path.unlink() # avoid half-written files
raise
class SpectCreation:
def __init__(self, pitch_shift, time_stretch, audio_sr, mel_args, verbose=False):
"""
Initialize the SpectCreation class. This assume that the audio files have been preprocessed with all the requested augmentations and are stored in the `mono_tracks` directory with the proper naming defined in AudioPreprocessing.
Args:
pitch_shift (tuple or None): A tuple specifying the minimum and maximum (inclusive) pitch shift values considered from the available audio files.
If None, pitch shifting augmentation files will not be considered.
time_stretch (tuple or None): A tuple specifying the min/max and stride percentage to consider from the available audio files.
If None, time stretching augmentation files will not be considered.
audio_sr (int): The sample rate of the audio.
mel_args (dict): A dictionary of arguments to be passed to the MelSpectrogram class.
verbose (bool, optional): Whether to print verbose information. Defaults to False.
"""
super(SpectCreation, self).__init__()
# define the directories
self.audio_dir = BASEPATH / "data" / "audio"
self.mono_tracks_dir = self.audio_dir / "mono_tracks"
self.spectrograms_dir = self.audio_dir / "spectrograms"
self.annotations_dir = BASEPATH / "data" / "annotations"
if verbose:
print("Audio dir: ", self.audio_dir.absolute())
print("Mono tracks dir: ", self.mono_tracks_dir.absolute())
print("Spectrograms dir: ", self.spectrograms_dir.absolute())
print("Annotations dir: ", self.annotations_dir.absolute())
self.verbose = verbose
# remember the audio metadata
self.audio_sr = audio_sr
# create the mel spectrogram class
self.logspect_class = LogMelSpect(audio_sr, **mel_args)
# define the augmentations
self.augmentations = {}
if pitch_shift is not None:
self.augmentations["pitch"] = {"min": pitch_shift[0], "max": pitch_shift[1]}
if time_stretch is not None:
self.augmentations["tempo"] = {
"min": -time_stretch[0],
"max": time_stretch[0],
"stride": time_stretch[1],
}
# compute the names to consider according to the augmentations
self.filenames = precomputed_augmentation_filenames(self.augmentations, "wav")
def create_spects(self):
print("Creating spectrograms ...")
processed = 0
with concurrent.futures.ThreadPoolExecutor() as executor:
futures = []
for dataset_dir in self.mono_tracks_dir.iterdir():
for piece_dir in dataset_dir.iterdir():
futures.append(
executor.submit(
self.create_spect_piece,
piece_dir,
Path(dataset_dir.name)
/ "annotations"
/ "beats"
/ f"{piece_dir.name}.beats",
dataset_dir.name,
)
)
for future in tqdm(
concurrent.futures.as_completed(futures), total=len(futures)
):
if future.result():
processed += 1
print(f"Created {processed} spectrograms in {self.spectrograms_dir}")
def create_spect_piece(self, preprocessed_audio_folder, beat_path, dataset_name):
"""
Create spectrogram for a single audio piece.
This method creates a spectrogram for a single audio piece located in the `preprocessed_audio_folder`.
The beat annotations for the audio piece are loaded from the `beat_path` file.
The created spectrogram is saved in the `spectrograms_dir` directory.
Args:
preprocessed_audio_folder (Path): The path to the preprocessed audio folder.
beat_path (Path): The path to the beat annotations file.
dataset_name (str): The name of the dataset.
Returns:
metadata (list): A list containing the metadata of the created spectrogram.
"""
for filename in self.filenames:
if not (self.annotations_dir / beat_path).exists():
print(
f"beat annotation {beat_path} not found for {preprocessed_audio_folder}"
)
return
audio_path = preprocessed_audio_folder / filename
spect_path = (
self.spectrograms_dir
/ dataset_name
/ preprocessed_audio_folder.name
/ f"{Path(filename).stem}.npy"
)
if spect_path.exists():
if self.verbose:
print(f"Skipping {spect_path} because it exists")
else:
if self.verbose:
print(f"Computing {spect_path}")
waveform, sr = load_audio(audio_path)
assert (
sr == self.audio_sr
), f"Sample rate mismatch: {sr} != {self.audio_sr}"
# compute the mel spectrogram and scale the values with log(1 + 1000 * x)
spect = self.logspect_class(torch.tensor(waveform, dtype=torch.float32))
# save the spectrogram as numpy array
spect_path.parent.mkdir(parents=True, exist_ok=True)
save_spectrogram(spect_path, spect.numpy())
return True
class AudioPreprocessing(object):
def __init__(
self,
orig_audio_paths,
out_sr=22050,
aug_sr=44100,
ext="wav",
pitch_shift=(-5, 6),
time_stretch=(20, 4),
verbose=False,
):
"""
Class for converting audio files to mono, resampling, and applying augmentations.
Only use this if you want to start from new audio files, otherwise use the spectrograms provided in the repo.
Args:
orig_audio_paths (Path): The path to the file with the original audio paths for each dataset.
out_sr (int, optional): The output sample rate. Defaults to 22050.
aug_sr (int, optional): The sample rate for the augmentations. Defaults to 44100.
ext (str, optional): The extension of the audio files. Defaults to 'wav'.
pitch_shift (tuple, optional): A tuple specifying the minimum and maximum (inclusive) pitch shift values considered. Defaults to (-5, 6).
time_stretch (tuple, optional): A tuple specifying the min/max (inclusive) time stretch and stride in percentage considered. Defaults to (20, 4).
verbose (bool, optional): Whether to print verbose information. Defaults to False.
"""
super(AudioPreprocessing, self).__init__()
self.audio_dir = BASEPATH / "data" / "audio"
self.annotation_dir = BASEPATH / "data" / "annotations"
# load data_dir from audio_path.csv which has the format: dataset_name, audio_path
self.audio_dirs = {
row[0]: row[1] for row in pd.read_csv(orig_audio_paths, header=None).values
}
# check if annotations exists, otherwise tell how to obtain them
if not self.annotation_dir.exists():
raise RuntimeError(
f"{self.annotation_dir} missing, check instructions "
"in README.md how to obtain the annotations."
)
print(f"Annotations ready in {self.annotation_dir}")
self.out_sr = out_sr
self.aug_sr = aug_sr
self.ext = ext
self.pitch_shift = pitch_shift
if time_stretch:
# interpret tuple as (maximum percentage, stride)
time_stretch = range(
-time_stretch[0],
time_stretch[0] + 1,
time_stretch[1] if len(time_stretch) > 1 else 1,
)
self.time_stretch = time_stretch
self.verbose = verbose
def preprocess_audio(self):
print("Preprocessing audio files ...")
processed = 0
with concurrent.futures.ThreadPoolExecutor() as executor:
futures = []
for dataset_name, audio_dir in self.audio_dirs.items():
for audio_path in Path(audio_dir).iterdir():
if audio_path.stem[:12] in ("gtzan_speech", "gtzan_music_"):
continue
futures.append(
executor.submit(
self.process_audio_file, dataset_name, audio_path
)
)
for future in tqdm(
concurrent.futures.as_completed(futures), total=len(futures)
):
if future.result():
processed += 1
print("Processed", processed, "audio files")
def process_audio_file(self, dataset_name, audio_path):
annotation_dir = Path(self.annotation_dir, dataset_name, "annotations")
# load annotations
beat_path = Path(annotation_dir, "beats", audio_path.stem + ".beats")
if not beat_path.exists():
print(
f"beat annotation {beat_path} not found for {audio_path}",
)
return False
# create a folder with the name of the track
folder_path = Path(self.audio_dir, "mono_tracks", dataset_name, audio_path.stem)
# derive the name of the unaugmented file
mono_path = folder_path / f"track.{self.ext}"
# derive the name of all augmented files
augmentations = {
"pitch": {"min": self.pitch_shift[0], "max": self.pitch_shift[1]},
"tempo": {
"min": -self.time_stretch[0],
"max": self.time_stretch[0],
"stride": self.time_stretch[1],
},
}
augmentations_path = precomputed_augmentation_filenames(augmentations, self.ext)
# stop here if all files exists
if mono_path.exists() and all(
(folder_path / aug).exists() for aug in augmentations_path
):
if self.verbose:
print(f"All files in {folder_path} exists, skipping")
return True
# load audio
try:
waveform, sr = load_audio(audio_path)
except Exception as e:
print("Problem with loading waveform", audio_path, e)
return
folder_path.mkdir(parents=True, exist_ok=True)
if (
waveform.ndim == 1
and sr == self.out_sr
and audio_path.suffix == f".{self.ext}"
):
# shortcut: copy original file to mono path location
os.system("cp '{}' '{}'".format(audio_path, mono_path))
else:
# we need to do some conversions for the unaugmented file
if waveform.ndim != 1:
waveform = np.mean(waveform, axis=1)
if not mono_path.exists():
if sr != self.out_sr:
waveform_out = soxr.resample(
waveform, in_rate=sr, out_rate=self.out_sr
)
else:
waveform_out = waveform
# save mono file
save_audio(mono_path, waveform_out, self.out_sr)
if (self.pitch_shift or self.time_stretch) and (sr != self.aug_sr):
waveform = soxr.resample(waveform, in_rate=sr, out_rate=self.aug_sr)
# handle the requested augmentations
# pedalboard requires float32, convert
waveform = np.asarray(waveform, dtype=np.float32)
shifts = (
range(self.pitch_shift[0], self.pitch_shift[1] + 1)
if self.pitch_shift
else [0]
)
stretches = self.time_stretch if self.time_stretch else [0]
for shift in shifts: # pitch augmentation
augment_audio_file(
folder_path,
waveform,
aug_type="shift",
amount=shift,
aug_sr=self.aug_sr,
out_sr=self.out_sr,
ext=self.ext,
verbose=self.verbose,
)
for stretch in stretches: # tempo augmentation
augment_audio_file(
folder_path,
waveform,
aug_type="stretch",
amount=stretch,
aug_sr=self.aug_sr,
out_sr=self.out_sr,
ext=self.ext,
verbose=self.verbose,
)
return True
def augment_audio_file(
folder_path, waveform, aug_type, amount, aug_sr, out_sr, ext, verbose
):
# figure out the file name
if aug_type == "stretch":
stretch = amount
shift = 0
elif aug_type == "shift":
shift = amount
stretch = 0
else:
raise ValueError(f"Unknown augmentation mode {aug_type}")
suffix = ""
if shift != 0:
suffix = suffix + f"_ps{shift}"
if stretch != 0:
suffix = suffix + f"_ts{stretch}"
out_path = Path(folder_path, f"track{suffix}.{ext}")
# skip if it exists
if out_path.exists():
if verbose:
print(f"{out_path} exists, skipping")
return
# otherwise compute it and write it out
# time stretch or pitch shift alone
if aug_type == "shift":
if verbose:
print(f"computing {out_path} with {shift=}")
# pitch shift alone
board = Pedalboard(
[
PitchShift(semitones=shift),
]
)
# apply pedalboard
augmented = board(waveform, aug_sr)
else: # type == stretch
if verbose:
print(f"computing {out_path} with {stretch=}")
augmented = time_stretch(
waveform,
aug_sr,
stretch_factor=1 + stretch / 100,
pitch_shift_in_semitones=0.0,
).squeeze()
# save to file
if verbose:
print(f"writing {out_path}")
save_audio(out_path, augmented, out_sr, resample_from=aug_sr)
def create_npz(spect_dir, npz_file, augmentations, verbose):
"""Assemble spectrograms from a directory into an .npz file."""
if npz_file.exists():
if verbose:
print(f"{npz_file} already exists, skipping")
return
with ZipFile(npz_file, "w") as z:
for subdir in tqdm(sorted(spect_dir.iterdir()), leave=False):
if subdir.is_dir():
for fn in precomputed_augmentation_filenames(augmentations):
z.write(subdir / fn, subdir.name + "/" + fn)
def ints(value):
"""Parse a string containing a colon-separated tuple of integers."""
return value and tuple(map(int, value.split(":")))
def main(orig_audio_paths, pitch_shift, time_stretch, verbose):
# preprocess audio
dp = AudioPreprocessing(
orig_audio_paths=orig_audio_paths,
out_sr=22050,
aug_sr=44100,
pitch_shift=pitch_shift,
time_stretch=time_stretch,
verbose=verbose,
)
dp.preprocess_audio()
# compute spectrograms
mel_args = dict(
n_fft=1024,
hop_length=441,
f_min=30,
f_max=11000,
n_mels=128,
mel_scale="slaney",
normalized="frame_length",
power=1,
)
sc = SpectCreation(
pitch_shift=pitch_shift,
time_stretch=time_stretch,
audio_sr=22050,
mel_args=mel_args,
verbose=verbose,
)
sc.create_spects()
# assemble into NPZ files
print("Creating .npz spectrogram bundles...")
spect_dirs = [child for child in sc.spectrograms_dir.iterdir() if child.is_dir()]
for spect_dir in tqdm(spect_dirs):
create_npz(
spect_dir,
spect_dir.with_suffix(".npz"),
{} if spect_dir.name == "gtzan" else sc.augmentations,
verbose,
)
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument(
"--orig_audio_paths",
type=str,
help="path to the file with the original audio paths for each dataset (default: %(default)s)",
default="data/audio_paths.csv",
)
parser.add_argument(
"--pitch_shift",
metavar="LOW:HIGH",
type=str,
default="-5:6",
help="pitch shift in semitones (default: %(default)s)",
)
parser.add_argument(
"--time_stretch",
metavar="MAX:STRIDE",
type=str,
default="20:4",
help="time stretch in percentage and stride (default: %(default)s)",
)
parser.add_argument("--verbose", action="store_true", help="verbose output")
args = parser.parse_args()
main(
args.orig_audio_paths,
ints(args.pitch_shift),
ints(args.time_stretch),
args.verbose,
)
================================================
FILE: launch_scripts/train.py
================================================
import argparse
from pathlib import Path
import torch
from pytorch_lightning import Trainer, seed_everything
from pytorch_lightning.callbacks import LearningRateMonitor, ModelCheckpoint
from pytorch_lightning.loggers import WandbLogger
from beat_this.dataset import BeatDataModule
from beat_this.model.pl_module import PLBeatThis
def main(args):
# for repeatability
seed_everything(args.seed, workers=True)
print("Starting a new run with the following parameters:")
print(args)
params_str = f"{'noval ' if not args.val else ''}{'hung ' if args.hung_data else ''}{'fold' + str(args.fold) + ' ' if args.fold is not None else ''}{args.loss}-h{args.transformer_dim}-aug{args.tempo_augmentation}{args.pitch_augmentation}{args.mask_augmentation}{' nosumH ' if not args.sum_head else ''}{' nopartialT ' if not args.partial_transformers else ''}"
if args.logger == "wandb":
if args.resume_checkpoint and args.resume_id:
wandb_args = dict(id=args.resume_id, resume="must")
else:
wandb_args = {}
logger = WandbLogger(
project="beat_this", name=f"{args.name} {params_str}".strip(), **wandb_args
)
else:
logger = None
if args.force_flash_attention:
print("Forcing the use of the flash attention.")
torch.backends.cuda.enable_flash_sdp(True)
torch.backends.cuda.enable_mem_efficient_sdp(False)
torch.backends.cuda.enable_math_sdp(False)
data_dir = Path(__file__).parent.parent.relative_to(Path.cwd()) / "data"
checkpoint_dir = (
Path(__file__).parent.parent.relative_to(Path.cwd()) / "checkpoints"
)
augmentations = {}
if args.tempo_augmentation:
augmentations["tempo"] = {"min": -20, "max": 20, "stride": 4}
if args.pitch_augmentation:
augmentations["pitch"] = {"min": -5, "max": 6}
if args.mask_augmentation:
# kind, min_count, max_count, min_len, max_len, min_parts, max_parts
augmentations["mask"] = {
"kind": "permute",
"min_count": 1,
"max_count": 6,
"min_len": 0.1,
"max_len": 2,
"min_parts": 5,
"max_parts": 9,
}
datamodule = BeatDataModule(
data_dir,
batch_size=args.batch_size,
train_length=args.train_length,
spect_fps=args.fps,
num_workers=args.num_workers,
test_dataset="gtzan",
length_based_oversampling_factor=args.length_based_oversampling_factor,
augmentations=augmentations,
hung_data=args.hung_data,
no_val=not args.val,
fold=args.fold,
)
datamodule.setup(stage="fit")
# compute positive weights
pos_weights = datamodule.get_train_positive_weights(widen_target_mask=3)
print("Using positive weights: ", pos_weights)
dropout = {
"frontend": args.frontend_dropout,
"transformer": args.transformer_dropout,
}
pl_model = PLBeatThis(
spect_dim=128,
fps=50,
transformer_dim=args.transformer_dim,
ff_mult=4,
n_layers=args.n_layers,
stem_dim=32,
dropout=dropout,
lr=args.lr,
weight_decay=args.weight_decay,
pos_weights=pos_weights,
head_dim=32,
loss_type=args.loss,
warmup_steps=args.warmup_steps,
max_epochs=args.max_epochs,
use_dbn=args.dbn,
eval_trim_beats=args.eval_trim_beats,
sum_head=args.sum_head,
partial_transformers=args.partial_transformers,
)
for part in args.compile:
if hasattr(pl_model.model, part):
setattr(pl_model.model, part, torch.compile(getattr(pl_model.model, part)))
print("Will compile model", part)
else:
raise ValueError("The model is missing the part", part, "to compile")
callbacks = [LearningRateMonitor(logging_interval="step")]
# save only the last model
callbacks.append(
ModelCheckpoint(
every_n_epochs=1,
dirpath=str(checkpoint_dir),
filename=f"{args.name} S{args.seed} {params_str}".strip(),
)
)
trainer = Trainer(
max_epochs=args.max_epochs,
accelerator="auto",
devices=[args.gpu],
num_sanity_val_steps=1,
logger=logger,
callbacks=callbacks,
log_every_n_steps=1,
precision="16-mixed",
accumulate_grad_batches=args.accumulate_grad_batches,
check_val_every_n_epoch=args.val_frequency,
)
trainer.fit(pl_model, datamodule, ckpt_path=args.resume_checkpoint)
trainer.test(pl_model, datamodule)
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--name", type=str, default="")
parser.add_argument("--gpu", type=int, default=0)
parser.add_argument(
"--force-flash-attention", default=False, action=argparse.BooleanOptionalAction
)
parser.add_argument(
"--compile",
action="store",
nargs="*",
type=str,
default=["frontend", "transformer_blocks", "task_heads"],
help="Which model parts to compile, among frontend, transformer_encoder, task_heads",
)
parser.add_argument("--n-layers", type=int, default=6)
parser.add_argument("--transformer-dim", type=int, default=512)
parser.add_argument(
"--frontend-dropout",
type=float,
default=0.1,
help="dropout rate to apply in the frontend",
)
parser.add_argument(
"--transformer-dropout",
type=float,
default=0.2,
help="dropout rate to apply in the main transformer blocks",
)
parser.add_argument("--lr", type=float, default=0.0008)
parser.add_argument("--weight-decay", type=float, default=0.01)
parser.add_argument("--logger", type=str, choices=["wandb", "none"], default="none")
parser.add_argument("--num-workers", type=int, default=8)
parser.add_argument("--n-heads", type=int, default=16)
parser.add_argument("--fps", type=int, default=50, help="The spectrograms fps.")
parser.add_argument(
"--loss",
type=str,
default="shift_tolerant_weighted_bce",
choices=[
"shift_tolerant_weighted_bce",
"fast_shift_tolerant_weighted_bce",
"weighted_bce",
"bce",
],
help="The loss to use",
)
parser.add_argument(
"--warmup-steps", type=int, default=1000, help="warmup steps for optimizer"
)
parser.add_argument(
"--max-epochs", type=int, default=100, help="max epochs for training"
)
parser.add_argument(
"--batch-size", type=int, default=8, help="batch size for training"
)
parser.add_argument("--accumulate-grad-batches", type=int, default=8)
parser.add_argument(
"--train-length",
type=int,
default=1500,
help="maximum seq length for training in frames",
)
parser.add_argument(
"--dbn",
default=False,
action=argparse.BooleanOptionalAction,
help="use madmom postprocessing DBN",
)
parser.add_argument(
"--eval-trim-beats",
metavar="SECONDS",
type=float,
default=5,
help="Skip the first given seconds per piece in evaluating (default: %(default)s)",
)
parser.add_argument(
"--val-frequency",
metavar="N",
type=int,
default=5,
help="validate every N epochs (default: %(default)s)",
)
parser.add_argument(
"--tempo-augmentation",
default=True,
action=argparse.BooleanOptionalAction,
help="Use precomputed tempo aumentation",
)
parser.add_argument(
"--pitch-augmentation",
default=True,
action=argparse.BooleanOptionalAction,
help="Use precomputed pitch aumentation",
)
parser.add_argument(
"--mask-augmentation",
default=True,
action=argparse.BooleanOptionalAction,
help="Use online mask aumentation",
)
parser.add_argument(
"--sum-head",
default=True,
action=argparse.BooleanOptionalAction,
help="Use SumHead instead of two separate Linear heads",
)
parser.add_argument(
"--partial-transformers",
default=True,
action=argparse.BooleanOptionalAction,
help="Use Partial transformers in the frontend",
)
parser.add_argument(
"--length-based-oversampling-factor",
type=float,
default=0.65,
help="The factor to oversample the long pieces in the dataset. Set to 0 to only take one excerpt for each piece.",
)
parser.add_argument(
"--val",
default=True,
action=argparse.BooleanOptionalAction,
help="Train on all data, including validation data, escluding test data. The validation metrics will still be computed, but they won't carry any meaning.",
)
parser.add_argument(
"--hung-data",
default=False,
action=argparse.BooleanOptionalAction,
help="Limit the training to Hung et al. data. The validation will still be computed on all datasets.",
)
parser.add_argument(
"--fold",
type=int,
default=None,
help="If given, the CV fold number to *not* train on (0-based).",
)
parser.add_argument(
"--seed",
type=int,
default=0,
help="Seed for the random number generators.",
)
parser.add_argument(
"--resume-checkpoint",
type=str,
default=None,
help="Resume training from a local checkpoint.",
)
parser.add_argument(
"--resume-id",
type=str,
default=None,
help="When resuming with --resume-checkpoint, optionally provide the wandb id to continue logging to.",
)
args = parser.parse_args()
main(args)
================================================
FILE: pyproject.toml
================================================
[build-system]
requires = ["setuptools"]
build-backend = "setuptools.build_meta"
[project]
name = "beat-this"
version = "1.1.0"
description = "Beat This! beat tracker"
readme = "README.md"
classifiers = [
"Intended Audience :: Science/Research",
"Topic :: Multimedia :: Sound/Audio :: Analysis",
"Development Status :: 5 - Production/Stable",
"Programming Language :: Python :: 3",
]
authors = [
{name = "Francesco Foscarin", email = "francesco.foscarin@jku.at"},
{name = "Jan Schlüter", email = "jan.schlueter@jku.at"},
]
requires-python = ">=3"
dependencies = [
"numpy>=1.20",
"torch>=2",
"torchaudio",
"einops",
"rotary-embedding-torch",
"soxr",
]
license = "MIT"
license-files = ["LICENSE"]
[tool.setuptools.package-dir]
beat_this = "beat_this"
[project.urls]
Repository = "https://github.com/CPJKU/beat_this"
Issues = "https://github.com/CPJKU/beat_this/issues"
Changelog = "https://github.com/CPJKU/beat_this/blob/main/CHANGELOG.md"
[project.scripts]
beat_this = "beat_this.cli:main"
================================================
FILE: requirements.txt
================================================
# This is a set of known working versions for inference, documented for a
# distant future.
# We recommend following the requirements section in our README.md instead.
einops==0.8.0
numpy==1.26.4
rotary_embedding_torch==0.6.4
soxr==0.3.7
torch==2.3.1
torchaudio==2.3.1
tqdm==4.66.4
================================================
FILE: tests/test_inference.py
================================================
from pathlib import Path
import numpy as np
import soundfile as sf
import torch
from beat_this.inference import Audio2Frames, File2Beats
def test_File2Beat():
f2b = File2Beats()
audio_path = Path("tests/It Don't Mean A Thing - Kings of Swing.mp3")
beat, downbeat = f2b(audio_path)
assert isinstance(beat, np.ndarray)
assert isinstance(downbeat, np.ndarray)
def test_Audio2Frames():
a2f = Audio2Frames()
audio_path = Path("tests/It Don't Mean A Thing - Kings of Swing.mp3")
# load audio
audio, sr = sf.read(audio_path)
beat, downbeat = a2f(audio, sr)
assert isinstance(beat, torch.Tensor)
assert isinstance(downbeat, torch.Tensor)
gitextract_r0d2sw0_/
├── .github/
│ └── workflows/
│ └── pypi.yml
├── .gitignore
├── CHANGELOG.md
├── LICENSE
├── README.md
├── beat_this/
│ ├── __init__.py
│ ├── cli.py
│ ├── dataset/
│ │ ├── __init__.py
│ │ ├── augment.py
│ │ ├── dataset.py
│ │ └── mmnpz.py
│ ├── inference.py
│ ├── model/
│ │ ├── __init__.py
│ │ ├── beat_tracker.py
│ │ ├── loss.py
│ │ ├── pl_module.py
│ │ ├── postprocessor.py
│ │ └── roformer.py
│ ├── preprocessing.py
│ └── utils.py
├── beat_this_example.ipynb
├── hubconf.py
├── launch_scripts/
│ ├── clean_checkpoints.py
│ ├── compute_paper_metrics.py
│ ├── preprocess_audio.py
│ └── train.py
├── pyproject.toml
├── requirements.txt
└── tests/
└── test_inference.py
SYMBOL INDEX (181 symbols across 17 files)
FILE: beat_this/cli.py
function get_parser (line 22) | def get_parser():
function derive_output_path (line 92) | def derive_output_path(input_path, suffix, append, output=None, parent=N...
function run (line 114) | def run(
function main (line 194) | def main():
FILE: beat_this/dataset/augment.py
function augment_pitchtempo (line 5) | def augment_pitchtempo(item, augmentations):
function augment_pitch (line 42) | def augment_pitch(item, pitch_params):
function augment_tempo (line 50) | def augment_tempo(item, tempo_params):
function stretch_annotations (line 60) | def stretch_annotations(item, percentage):
function shift_annotations (line 71) | def shift_annotations(item, semitones):
function stretch_filename (line 76) | def stretch_filename(item, percentage):
function shift_filename (line 85) | def shift_filename(item, semitones):
function number_of_precomputed_augmentations (line 94) | def number_of_precomputed_augmentations(augmentations):
function precomputed_augmentation_filenames (line 105) | def precomputed_augmentation_filenames(augmentations, ext="npy"):
function augment_mask_ (line 129) | def augment_mask_(spect, augmentations: dict, fps: int):
function apply_mask_excerpt (line 177) | def apply_mask_excerpt(excerpt, kind, min_parts, max_parts):
FILE: beat_this/dataset/dataset.py
class BeatTrackingDataset (line 23) | class BeatTrackingDataset(Dataset):
method __init__ (line 38) | def __init__(
method _load_dataset_infos (line 81) | def _load_dataset_infos(self, datasets):
method _load_spect_bundles (line 88) | def _load_spect_bundles(self, datasets):
method _load_dataset_item (line 96) | def _load_dataset_item(self, item_name):
method _get_spect (line 146) | def _get_spect(self, item):
method get_frame_count (line 154) | def get_frame_count(self, index):
method get_beat_count (line 158) | def get_beat_count(self, index):
method get_downbeat_count (line 162) | def get_downbeat_count(self, index):
method __len__ (line 166) | def __len__(self):
method __getitem__ (line 169) | def __getitem__(self, index):
class BeatDataModule (line 247) | class BeatDataModule(pl.LightningDataModule):
method __init__ (line 268) | def __init__(
method setup (line 305) | def setup(self, stage):
method train_dataloader (line 448) | def train_dataloader(self):
method val_dataloader (line 458) | def val_dataloader(self):
method test_dataloader (line 465) | def test_dataloader(self):
method predict_dataloader (line 468) | def predict_dataloader(self):
method get_train_positive_weights (line 473) | def get_train_positive_weights(self, widen_target_mask=3):
function prepare_annotations (line 512) | def prepare_annotations(item, start_frame, end_frame, fps):
FILE: beat_this/dataset/mmnpz.py
class MemmappedNpzFile (line 12) | class MemmappedNpzFile(Mapping):
method __init__ (line 38) | def __init__(self, fn: str, cache: bool = True, preload: bool = False):
method load (line 54) | def load(self, name: str):
method close (line 78) | def close(self):
method __enter__ (line 83) | def __enter__(self):
method __exit__ (line 86) | def __exit__(self, exc_type, exc_value, traceback):
method __iter__ (line 89) | def __iter__(self):
method __len__ (line 92) | def __len__(self):
method __getitem__ (line 95) | def __getitem__(self, key: str):
method __contains__ (line 106) | def __contains__(self, key: str):
class MemoryviewIO (line 111) | class MemoryviewIO(object):
method __init__ (line 116) | def __init__(self, buffer):
method seek (line 123) | def seek(self, offset, whence=0):
method read (line 131) | def read(self, size=-1):
method tell (line 138) | def tell(self):
FILE: beat_this/inference.py
function load_checkpoint (line 16) | def load_checkpoint(checkpoint_path: str, device: str | torch.device = "...
function load_model (line 56) | def load_model(
function zeropad (line 90) | def zeropad(spect: torch.Tensor, left: int = 0, right: int = 0):
function split_piece (line 100) | def split_piece(
function aggregate_prediction (line 138) | def aggregate_prediction(
function split_predict_aggregate (line 188) | def split_predict_aggregate(
class Spect2Frames (line 233) | class Spect2Frames:
method __init__ (line 238) | def __init__(self, checkpoint_path="final0", device="cpu", float16=Fal...
method spect2frames (line 244) | def spect2frames(self, spect):
method __call__ (line 256) | def __call__(self, spect):
class Audio2Frames (line 260) | class Audio2Frames(Spect2Frames):
method __init__ (line 265) | def __init__(self, checkpoint_path="final0", device="cpu", float16=Fal...
method signal2spect (line 269) | def signal2spect(self, signal, sr):
method __call__ (line 279) | def __call__(self, signal, sr):
class Audio2Beats (line 284) | class Audio2Beats(Audio2Frames):
method __init__ (line 295) | def __init__(
method __call__ (line 301) | def __call__(self, signal, sr):
class File2Beats (line 306) | class File2Beats(Audio2Beats):
method __call__ (line 307) | def __call__(self, audio_path):
class File2File (line 312) | class File2File(File2Beats):
method __call__ (line 313) | def __call__(self, audio_path, output_path):
FILE: beat_this/model/beat_tracker.py
class BeatThis (line 18) | class BeatThis(nn.Module):
method __init__ (line 38) | def __init__(
method make_stem (line 109) | def make_stem(spect_dim: int, stem_dim: int) -> nn.Module:
method make_frontend_block (line 129) | def make_frontend_block(
method _init_weights (line 171) | def _init_weights(module: nn.Module):
method forward (line 188) | def forward(self, x):
method _load_from_state_dict (line 194) | def _load_from_state_dict(self, state_dict, prefix, *args, **kwargs):
method state_dict (line 199) | def state_dict(self, *args, **kwargs):
class PartialRoformer (line 206) | class PartialRoformer(nn.Module):
method __init__ (line 213) | def __init__(
method forward (line 238) | def forward(self, x):
class PartialFTTransformer (line 251) | class PartialFTTransformer(nn.Module):
method __init__ (line 259) | def __init__(
method forward (line 290) | def forward(self, x):
class SumHead (line 304) | class SumHead(nn.Module):
method __init__ (line 311) | def __init__(self, input_dim):
method forward (line 315) | def forward(self, x):
class Head (line 333) | class Head(nn.Module):
method __init__ (line 338) | def __init__(self, input_dim):
method forward (line 342) | def forward(self, x):
FILE: beat_this/model/loss.py
class MaskedBCELoss (line 9) | class MaskedBCELoss(torch.nn.Module):
method __init__ (line 19) | def __init__(self, pos_weight: float = 1):
method forward (line 27) | def forward(
class ShiftTolerantBCELoss (line 38) | class ShiftTolerantBCELoss(torch.nn.Module):
method __init__ (line 56) | def __init__(self, pos_weight: float = 1, tolerance: int = 3):
method spread (line 65) | def spread(self, x: torch.Tensor, factor: int = 1):
method crop (line 70) | def crop(self, x: torch.Tensor, factor: int = 1):
method forward (line 73) | def forward(
class SplittedShiftTolerantBCELoss (line 95) | class SplittedShiftTolerantBCELoss(torch.nn.Module):
method __init__ (line 109) | def __init__(self, pos_weight: float = 1, tolerance: int = 3):
method spread (line 120) | def spread(self, x: torch.Tensor, amount: int):
method crop (line 126) | def crop(self, x: torch.Tensor, desired_length: int):
method forward (line 135) | def forward(self, preds: torch.Tensor, targets: torch.Tensor, mask: to...
FILE: beat_this/model/pl_module.py
class PLBeatThis (line 21) | class PLBeatThis(LightningModule):
method __init__ (line 22) | def __init__(
method _compute_loss (line 99) | def _compute_loss(self, batch, model_prediction):
method _compute_metrics (line 116) | def _compute_metrics(self, batch, postp_beat, postp_downbeat, step="va...
method _compute_metrics_target (line 132) | def _compute_metrics_target(self, batch, postp_target, target, step):
method log_losses (line 164) | def log_losses(self, losses, batch_size, step="train"):
method log_metrics (line 187) | def log_metrics(self, metrics, batch_size, step="val"):
method training_step (line 199) | def training_step(self, batch, batch_idx):
method validation_step (line 207) | def validation_step(self, batch, batch_idx):
method test_step (line 224) | def test_step(self, batch, batch_idx):
method predict_step (line 231) | def predict_step(
method configure_optimizers (line 279) | def configure_optimizers(self):
method _load_from_state_dict (line 308) | def _load_from_state_dict(self, state_dict, prefix, *args, **kwargs):
method state_dict (line 313) | def state_dict(self, *args, **kwargs):
class Metrics (line 320) | class Metrics:
method __init__ (line 321) | def __init__(self, eval_trim_beats: int) -> None:
method __call__ (line 324) | def __call__(self, truth, preds, step) -> Any:
class CosineWarmupScheduler (line 342) | class CosineWarmupScheduler(torch.optim.lr_scheduler._LRScheduler):
method __init__ (line 350) | def __init__(self, optimizer, warmup, max_iters, raise_last=0, raise_t...
method get_lr (line 356) | def get_lr(self):
method get_lr_factor (line 360) | def get_lr_factor(self, step):
FILE: beat_this/model/postprocessor.py
class Postprocessor (line 9) | class Postprocessor:
method __init__ (line 24) | def __init__(self, type: str = "minimal", fps: int = 50):
method __call__ (line 39) | def __call__(
method postp_minimal (line 85) | def postp_minimal(self, beat, downbeat, padding_mask):
method _postp_minimal_item (line 113) | def _postp_minimal_item(self, padded_beat_peaks, padded_downbeat_peaks...
method postp_dbn (line 138) | def postp_dbn(self, beat, downbeat, padding_mask):
method _postp_dbn_item (line 153) | def _postp_dbn_item(self, padded_beat_prob, padded_downbeat_prob, mask):
function deduplicate_peaks (line 176) | def deduplicate_peaks(peaks, width=1) -> np.ndarray:
FILE: beat_this/model/roformer.py
function exists (line 15) | def exists(val):
class RMSNorm (line 22) | class RMSNorm(Module):
method __init__ (line 23) | def __init__(self, size, dim=-1):
method forward (line 31) | def forward(self, x):
class FeedForward (line 38) | class FeedForward(Module):
method __init__ (line 39) | def __init__(
method forward (line 60) | def forward(self, x):
class Attend (line 67) | class Attend(nn.Module):
method __init__ (line 68) | def __init__(self, dropout=0.0, scale=None):
method forward (line 73) | def forward(self, q, k, v):
class Attention (line 83) | class Attention(Module):
method __init__ (line 84) | def __init__(
method forward (line 114) | def forward(self, x):
class Transformer (line 138) | class Transformer(Module):
method __init__ (line 139) | def __init__(
method forward (line 176) | def forward(self, x):
FILE: beat_this/preprocessing.py
function load_audio (line 6) | def load_audio(path, dtype="float64"):
class LogMelSpect (line 27) | class LogMelSpect(torch.nn.Module):
method __init__ (line 28) | def __init__(
method forward (line 56) | def forward(self, x):
FILE: beat_this/utils.py
function index_to_framewise (line 7) | def index_to_framewise(index, length):
function filename_to_augmentation (line 14) | def filename_to_augmentation(filename):
function infer_beat_numbers (line 26) | def infer_beat_numbers(beats: np.ndarray, downbeats: np.ndarray) -> np.n...
function save_beat_tsv (line 79) | def save_beat_tsv(beats: np.ndarray, downbeats: np.ndarray, outpath: str...
function replace_state_dict_key (line 105) | def replace_state_dict_key(state_dict: dict, old: str, new: str):
FILE: launch_scripts/clean_checkpoints.py
function main (line 7) | def main(args):
FILE: launch_scripts/compute_paper_metrics.py
function main (line 17) | def main(args):
function datamodule_setup (line 159) | def datamodule_setup(checkpoint, num_workers, datasplit):
function plmodel_setup (line 174) | def plmodel_setup(checkpoint, eval_trim_beats, dbn, gpu):
function compute_predictions (line 213) | def compute_predictions(model, trainer, predict_dataloader, return_preds...
function write_predictions (line 228) | def write_predictions(fn, preds, piece):
FILE: launch_scripts/preprocess_audio.py
function save_audio (line 24) | def save_audio(path, waveform, samplerate, resample_from=None):
function save_spectrogram (line 37) | def save_spectrogram(path, spectrogram, dtype=np.float16):
class SpectCreation (line 45) | class SpectCreation:
method __init__ (line 46) | def __init__(self, pitch_shift, time_stretch, audio_sr, mel_args, verb...
method create_spects (line 89) | def create_spects(self):
method create_spect_piece (line 114) | def create_spect_piece(self, preprocessed_audio_folder, beat_path, dat...
class AudioPreprocessing (line 161) | class AudioPreprocessing(object):
method __init__ (line 162) | def __init__(
method preprocess_audio (line 215) | def preprocess_audio(self):
method process_audio_file (line 236) | def process_audio_file(self, dataset_name, audio_path):
function augment_audio_file (line 332) | def augment_audio_file(
function create_npz (line 383) | def create_npz(spect_dir, npz_file, augmentations, verbose):
function ints (line 396) | def ints(value):
function main (line 401) | def main(orig_audio_paths, pitch_shift, time_stretch, verbose):
FILE: launch_scripts/train.py
function main (line 13) | def main(args):
FILE: tests/test_inference.py
function test_File2Beat (line 10) | def test_File2Beat():
function test_Audio2Frames (line 18) | def test_Audio2Frames():
Condensed preview — 29 files, each showing path, character count, and a content snippet. Download the .json file or copy for the full structured content (186K chars).
[
{
"path": ".github/workflows/pypi.yml",
"chars": 1850,
"preview": "# This workflow will upload a Python Package to PyPI when a release is created\n# For more information see: https://docs."
},
{
"path": ".gitignore",
"chars": 113,
"preview": "__pycache__/\n*.py[cod]\n*$py.class\n\ndata/\ncheckpoints/\nlightning_logs/\nwandb/\n.vscode/\nbeat_this.egg-info/\nbuild/\n"
},
{
"path": "CHANGELOG.md",
"chars": 813,
"preview": "# Changelog\n\nAll notable changes to this project are documented below.\n\nThe format is based on [Keep a Changelog](https:"
},
{
"path": "LICENSE",
"chars": 1113,
"preview": "MIT License\n\nCopyright (c) 2024 Institute of Computational Perception, JKU Linz, Austria\n\nPermission is hereby granted, "
},
{
"path": "README.md",
"chars": 16734,
"preview": "# Beat This!\nOfficial implementation of the beat tracker from the ISMIR 2024 paper \"[Beat This! Accurate Beat Tracking W"
},
{
"path": "beat_this/__init__.py",
"chars": 0,
"preview": ""
},
{
"path": "beat_this/cli.py",
"chars": 6365,
"preview": "#!/usr/bin/env python3\n\"\"\"\nBeat This! command line inference tool.\n\"\"\"\n\nimport argparse\nimport sys\nfrom pathlib import P"
},
{
"path": "beat_this/dataset/__init__.py",
"chars": 53,
"preview": "from beat_this.dataset.dataset import BeatDataModule\n"
},
{
"path": "beat_this/dataset/augment.py",
"chars": 8100,
"preview": "import numpy as np\nimport torch\n\n\ndef augment_pitchtempo(item, augmentations):\n \"\"\"\n Apply a randomly chosen pitch"
},
{
"path": "beat_this/dataset/dataset.py",
"chars": 23533,
"preview": "import concurrent.futures\nimport itertools\nimport json\nimport re\nfrom pathlib import Path\n\nimport numpy as np\nimport pan"
},
{
"path": "beat_this/dataset/mmnpz.py",
"chars": 4296,
"preview": "\"\"\"\nSupport for memory-mapping uncompressed .npz files.\n\"\"\"\n\nimport struct\nfrom collections.abc import Mapping\nfrom zipf"
},
{
"path": "beat_this/inference.py",
"chars": 12781,
"preview": "import inspect\n\nimport numpy as np\nimport soxr\nimport torch\nimport torch.nn.functional as F\n\nfrom beat_this.model.beat_t"
},
{
"path": "beat_this/model/__init__.py",
"chars": 0,
"preview": ""
},
{
"path": "beat_this/model/beat_tracker.py",
"chars": 12582,
"preview": "\"\"\"\nModel definitions for the Beat This! beat tracker.\n\"\"\"\n\nimport contextlib\nfrom collections import OrderedDict\n\nimpor"
},
{
"path": "beat_this/model/loss.py",
"chars": 6089,
"preview": "\"\"\"\nLoss definitions for the Beat This! beat tracker.\n\"\"\"\n\nimport torch\nimport torch.nn.functional as F\n\n\nclass MaskedBC"
},
{
"path": "beat_this/model/pl_module.py",
"chars": 14137,
"preview": "\"\"\"\nPytorch Lightning module, wraps a BeatThis model along with losses, metrics and\noptimizers for training.\n\"\"\"\n\nfrom c"
},
{
"path": "beat_this/model/postprocessor.py",
"chars": 8404,
"preview": "from concurrent.futures import ThreadPoolExecutor\n\nimport numpy as np\nimport torch\nimport torch.nn.functional as F\nfrom "
},
{
"path": "beat_this/model/roformer.py",
"chars": 4374,
"preview": "\"\"\"\nTransformer with rotary position embedding, adapted from Phil Wang's repository\nat https://github.com/lucidrains/BS-"
},
{
"path": "beat_this/preprocessing.py",
"chars": 1767,
"preview": "import numpy as np\nimport torch\nimport torchaudio\n\n\ndef load_audio(path, dtype=\"float64\"):\n try:\n waveform, sa"
},
{
"path": "beat_this/utils.py",
"chars": 4150,
"preview": "from itertools import chain\nfrom pathlib import Path\n\nimport numpy as np\n\n\ndef index_to_framewise(index, length):\n \"\""
},
{
"path": "beat_this_example.ipynb",
"chars": 6952,
"preview": "{\n \"nbformat\": 4,\n \"nbformat_minor\": 0,\n \"metadata\": {\n \"colab\": {\n \"provenance\": [],\n \"authorship_tag\":"
},
{
"path": "hubconf.py",
"chars": 283,
"preview": "dependencies = [\n \"torch\",\n \"torchaudio\",\n \"numpy\",\n \"rotary_embedding_torch\",\n \"einops\",\n \"soxr\",\n]\n\n"
},
{
"path": "launch_scripts/clean_checkpoints.py",
"chars": 1276,
"preview": "import argparse\nfrom pathlib import Path\n\nimport torch\n\n\ndef main(args):\n # check if output path exists\n if Path(a"
},
{
"path": "launch_scripts/compute_paper_metrics.py",
"chars": 11075,
"preview": "#!/usr/bin/env python3\nimport argparse\nfrom pathlib import Path\n\nimport numpy as np\nfrom pytorch_lightning import Traine"
},
{
"path": "launch_scripts/preprocess_audio.py",
"chars": 18310,
"preview": "#!/usr/bin/env python3\nimport argparse\nimport concurrent.futures\nimport os\nfrom pathlib import Path\nfrom zipfile import "
},
{
"path": "launch_scripts/train.py",
"chars": 9921,
"preview": "import argparse\nfrom pathlib import Path\n\nimport torch\nfrom pytorch_lightning import Trainer, seed_everything\nfrom pytor"
},
{
"path": "pyproject.toml",
"chars": 1042,
"preview": "[build-system]\nrequires = [\"setuptools\"]\nbuild-backend = \"setuptools.build_meta\"\n\n[project]\nname = \"beat-this\"\nversion ="
},
{
"path": "requirements.txt",
"chars": 281,
"preview": "# This is a set of known working versions for inference, documented for a\n# distant future.\n# We recommend following the"
},
{
"path": "tests/test_inference.py",
"chars": 684,
"preview": "from pathlib import Path\n\nimport numpy as np\nimport soundfile as sf\nimport torch\n\nfrom beat_this.inference import Audio2"
}
]
About this extraction
This page contains the full source code of the CPJKU/beat_this GitHub repository, extracted and formatted as plain text for AI agents and large language models (LLMs). The extraction includes 29 files (172.9 KB), approximately 41.2k tokens, and a symbol index with 181 extracted functions, classes, methods, constants, and types. Use this with OpenClaw, Claude, ChatGPT, Cursor, Windsurf, or any other AI tool that accepts text input. You can copy the full output to your clipboard or download it as a .txt file.
Extracted by GitExtract — free GitHub repo to text converter for AI. Built by Nikandr Surkov.