[
  {
    "path": ".github/workflows/pypi.yml",
    "content": "# This workflow will upload a Python Package to PyPI when a release is created\n# For more information see: https://docs.github.com/en/actions/automating-builds-and-tests/building-and-testing-python#publishing-to-package-registries\n\n# This workflow uses actions that are not certified by GitHub.\n# They are provided by a third-party and are governed by\n# separate terms of service, privacy policy, and support\n# documentation.\n\nname: Upload Python Package\n\non:\n  release:\n    types: [published]\n\npermissions:\n  contents: read\n\njobs:\n  release-build:\n    runs-on: ubuntu-latest\n\n    steps:\n      - uses: actions/checkout@v5\n\n      - uses: actions/setup-python@v5\n        with:\n          python-version: \"3.x\"\n\n      - name: Build release distributions\n        run: |\n          python -m pip install build\n          python -m build\n\n      - name: Upload distributions\n        uses: actions/upload-artifact@v4\n        with:\n          name: release-dists\n          path: dist/\n\n  pypi-publish:\n    runs-on: ubuntu-latest\n    needs:\n      - release-build\n    permissions:\n      # IMPORTANT: this permission is mandatory for trusted publishing\n      id-token: write\n\n    # Dedicated environments with protections for publishing are strongly recommended.\n    # For more information, see: https://docs.github.com/en/actions/deployment/targeting-different-environments/using-environments-for-deployment#deployment-protection-rules\n    environment:\n      name: pypi\n      url: https://pypi.org/project/beat-this/${{ github.event.release.name }}\n\n    steps:\n      - name: Retrieve release distributions\n        uses: actions/download-artifact@v5\n        with:\n          name: release-dists\n          path: dist/\n\n      - name: Publish release distributions to PyPI\n        uses: pypa/gh-action-pypi-publish@release/v1\n        with:\n          packages-dir: dist/\n"
  },
  {
    "path": ".gitignore",
    "content": "__pycache__/\n*.py[cod]\n*$py.class\n\ndata/\ncheckpoints/\nlightning_logs/\nwandb/\n.vscode/\nbeat_this.egg-info/\nbuild/\n"
  },
  {
    "path": "CHANGELOG.md",
    "content": "# Changelog\n\nAll notable changes to this project are documented below.\n\nThe format is based on [Keep a Changelog](https://keepachangelog.com/en/1.1.0/),\nand this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).\n\n## [1.1.0] - 2026-04-14\n\n- Clarified installation instructions for madmom and mir_eval\n- Load checkpoints with `weights_only=True` when supported\n- Fix checkpoint downloads after server-side update\n- Provide separate `infer_beat_numbers()` function\n- Command-line tool: Support saving raw activations / logits\n- Training script: Support resuming from previous checkpoint\n- Migrate to pyproject.toml (thanks to @JacobLinCool)\n- Support non-CUDA accelerator chips (thanks to @tillt)\n- Published on PyPI (thanks to @MarvinSchenkel)\n\n## [1.0] - 2024-10-18\n\n- Initial release\n"
  },
  {
    "path": "LICENSE",
    "content": "MIT License\n\nCopyright (c) 2024 Institute of Computational Perception, JKU Linz, Austria\n\nPermission is hereby granted, free of charge, to any person obtaining a copy\nof this software and associated documentation files (the \"Software\"), to deal\nin the Software without restriction, including without limitation the rights\nto use, copy, modify, merge, publish, distribute, sublicense, and/or sell\ncopies of the Software, and to permit persons to whom the Software is\nfurnished to do so, subject to the following conditions:\n\nThe above copyright notice and this permission notice shall be included in all\ncopies or substantial portions of the Software.\n\nTHE SOFTWARE IS PROVIDED \"AS IS\", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR\nIMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,\nFITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE\nAUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER\nLIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,\nOUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE\nSOFTWARE.\n"
  },
  {
    "path": "README.md",
    "content": "# Beat This!\nOfficial implementation of the beat tracker from the ISMIR 2024 paper \"[Beat This! Accurate Beat Tracking Without DBN Postprocessing](https://arxiv.org/abs/2407.21658)\" by Francesco Foscarin, Jan Schlüter and Gerhard Widmer.\n\n* [Inference](#inference)\n* [Available models](#available-models)\n* [Data](#data)\n* [Reproducing metrics from the paper](#reproducing-metrics-from-the-paper)\n* [Training](#training)\n* [Reusing the loss](#reusing-the-loss)\n* [Reusing the model](#reusing-the-model)\n* [Citation](#citation)\n\n\n## Inference\n\nTo predict beats for audio files, you can either use our command line tool or call the beat tracker from Python. Both options have the same requirements; only the online demo works without installing anything.\n\n### Online demo\n\nTo process a small set of audio files without installing anything, [open our example notebook in Google Colab](https://colab.research.google.com/github/CPJKU/beat_this/blob/main/beat_this_example.ipynb) and follow the instructions.\n\n### Requirements\n\nThe beat tracker requires Python with a set of packages installed:\n1. [Install PyTorch](https://pytorch.org/get-started/locally/) 2.0 or later following the instructions for your platform.\n2. Install further modules with `pip install tqdm einops soxr rotary-embedding-torch`. (If using conda, we still recommend pip. You may try installing `soxr-python` and `einops` from conda-forge, but `rotary-embedding-torch` is only on PyPI.)\n3. To read audio formats other than `.wav`, install `ffmpeg` or another supported backend for `torchaudio`. (`ffmpeg` can be installed via conda or via your operating system.)\n\nFinally, install our beat tracker with:\n```bash\npip install beat-this\n```\nFor the development version, use:\n```bash\npip install https://github.com/CPJKU/beat_this/archive/main.zip\n```\n\n### Command line\n\nAlong with the Python package, a command line application called `beat_this` is installed. 
For full documentation of the command line options, run:\n```bash\nbeat_this --help\n```\nThe basic usage is:\n```bash\nbeat_this path/to/audio.file -o path/to/output.beats\n```\nTo process multiple files, specify multiple input files or directories, and give an output directory instead:\n```bash\nbeat_this path/to/*.mp3 path/to/whole_directory/ -o path/to/output_directory\n```\nBy default, the beat tracker uses the first GPU in your system and falls back to the CPU if PyTorch does not have CUDA access. With `--gpu=2`, it uses the third GPU, and `--gpu=-1` forces the CPU. For recent GPUs, passing `--float16` may improve speed.\nIf you have many files to process, you can distribute the load over multiple processes by running the same command multiple times with `--touch-first`, `--skip-existing` and potentially different options for `--gpu`:\n```bash\nfor gpu in {0..3}; do beat_this input_dir -o output_dir --touch-first --skip-existing --gpu=$gpu & done\n```\nIf you want to use the DBN for postprocessing, add `--dbn`. The DBN parameters are the default ones from madmom. 
This requires installing the `madmom` package (with `pip install git+https://github.com/CPJKU/madmom.git`, as the current version on PyPI only supports Python<3.10 and numpy<1.20).\n\n### Python class\n\nIf you are a Python user, you can directly use the `beat_this.inference` module.\n\nFirst, create an instance of the `File2Beats` class, which encapsulates the model along with pre- and postprocessing:\n```python\nfrom beat_this.inference import File2Beats\nfile2beats = File2Beats(checkpoint_path=\"final0\", device=\"cuda\", dbn=False)\n```\nTo obtain a list of beats and downbeats for an audio file, run:\n```python\naudio_path = \"path/to/audio.file\"\nbeats, downbeats = file2beats(audio_path)\n```\nOptionally, you can produce a `.beats` file (e.g., for importing into [Sonic Visualiser](https://www.sonicvisualiser.org/)):\n```python\nfrom beat_this.utils import save_beat_tsv\noutpath = \"path/to/output.beats\"\nsave_beat_tsv(beats, downbeats, outpath)\n```\nIf you already have an audio tensor loaded, use `Audio2Beats` instead of `File2Beats` and pass the tensor and its sample rate. We also provide `Audio2Frames` for framewise logits and `Spect2Frames` for spectrogram inputs.\n\n\n## Available models\n\nModels are available for manual download at [our cloud space](https://cloud.cp.jku.at/index.php/s/7ik4RrBKTS273gp), but are also downloaded automatically by the above inference code. By default, inference uses `final0`, but you can select another model via a command line option (`--model`) or Python parameter (`checkpoint_path`).\n\nMain models:\n* `final0`, `final1`, `final2`: Our main model, trained on all data except the GTZAN dataset, with three different seeds. This corresponds to \"Our system\" in Table 2 of the paper. About 78 MB per model.\n* `small0`, `small1`, `small2`: A smaller model, again trained on all data except GTZAN, with three different seeds. This corresponds to \"smaller model\" in Table 2 of the paper. 
About 8.1 MB per model.\n* `single_final0`, `single_final1`, `single_final2`: Our main model, trained on the single split described in Section 4.1 of the paper, with three different seeds. This corresponds to \"Our system\" in Table 3 of the paper. About 78 MB per model.\n* `fold0`, `fold1`, `fold2`, `fold3`, `fold4`, `fold5`, `fold6`, `fold7`: Our main model, trained in the 8-fold cross-validation setting with a single seed per fold. This corresponds to \"Our\" in Table 1 of the paper. About 78 MB per model.\n\nOther models, available mainly for result reproducibility:\n* `hung0`, `hung1`, `hung2`: A model trained on all the data used by the \"Modeling Beats and Downbeats with a Time-Frequency Transformer\" system by Hung et al. (except the GTZAN dataset), with three different seeds. This corresponds to \"limited to data of [10]\" in Table 2 of the paper.\n* The other models used for the ablation studies in Table 3, all trained with 3 seeds on the single split described in Section 4.1 of the paper:\n    * `single_notempoaug0`, `single_notempoaug1`, `single_notempoaug2`\n    * `single_nosumhead0`, `single_nosumhead1`, `single_nosumhead2`\n    * `single_nomaskaug0`, `single_nomaskaug1`, `single_nomaskaug2`\n    * `single_nopartialt0`, `single_nopartialt1`, `single_nopartialt2`\n    * `single_noshifttol0`, `single_noshifttol1`, `single_noshifttol2`\n    * `single_nopitchaug0`, `single_nopitchaug1`, `single_nopitchaug2`\n    * `single_noshifttolnoweights0`, `single_noshifttolnoweights1`, `single_noshifttolnoweights2`\n\n\nPlease be aware that the results may be unfairly good if you run inference on any file from the training datasets. 
For example, an evaluation with `final*` or `small*` can only be performed fairly on GTZAN or other datasets we didn't consider in our paper.\n\nIf you need to run an evaluation on one of the datasets we used other than GTZAN, consider targeting the validation part of the single split (with `single_final*`), or of the 8-fold cross-validation (with `fold*`).\n\nAll the models are provided as PyTorch Lightning checkpoints, stripped of the optimizer state to reduce their size. This is useful for reproducing the paper results or verifying the hyperparameters (stored in the checkpoint under `hyper_parameters` and `datamodule_hyper_parameters`).\nDuring inference, PyTorch Lightning is not used, and the checkpoints are converted and loaded into vanilla PyTorch modules.\n\n## Data\n\n### Annotations\nAll annotations we used to train our models are available [in a separate GitHub repo](https://github.com/CPJKU/beat_this_annotations). Note that if you want to obtain the exact paper results, you should use [version 1.0](https://github.com/CPJKU/beat_this_annotations/releases/tag/v1.0). Other releases with corrected annotations may be published in the future.\n\nTo use the annotations for training or evaluation, you first need to download and extract or clone the annotations repo to `data/annotations`:\n```bash\nmkdir -p data\ngit clone https://github.com/CPJKU/beat_this_annotations data/annotations\n# cd data/annotations; git checkout v1.0  # optional\n```\n\n### Spectrograms\nThe spectrograms used for training are released [as a Zenodo dataset](https://zenodo.org/records/13922116). They are distributed as a separate .zip file per dataset, each holding a .npz file with the spectrograms. For evaluation of the test set, download `gtzan.zip`; for training and evaluation of the validation set, download all files (except `beat_this_annotations.zip`). Extract all .zip files into `data/audio/spectrograms`, so that you have, for example, `data/audio/spectrograms/gtzan.npz`. 
As an alternative, the code also supports directories of .npy files such as `data/audio/spectrograms/gtzan/gtzan_blues_00000/track.npy`, which you can obtain by unzipping `gtzan.npz`.\n\n### Recreating spectrograms\nIf you have access to the original audio files, or want to add another dataset, create a text file `data/audio_paths.tsv` that has, on each line, the name of a dataset, a tab character, and the path to the audio directory. The corresponding annotations must also be present under `data/annotations`. Install pandas and pedalboard:\n```bash\npip install pandas pedalboard\n```\nThen run:\n```bash\npython launch_scripts/preprocess_audio.py\n```\nThe script creates monophonic 22 kHz wave files in `data/audio/mono_tracks`, converts those to spectrograms in `data/audio/spectrograms`, and creates spectrogram bundles. Intermediate files are kept and will not be recreated when rerunning the script.\n\n\n## Reproducing metrics from the paper\n\n### Requirements\n\nIn addition to the [inference requirements](#requirements), computing evaluation metrics requires installing PyTorch Lightning, pandas, and `mir_eval` (the latter from source, as the current version on PyPI only supports numpy<1.20).\n```bash\npip install pytorch_lightning pandas\npip install https://github.com/mir-evaluation/mir_eval/archive/main.zip\n```\nYou must also obtain and set up the annotations and spectrogram datasets [as indicated above](#data). 
Specifically, the GTZAN dataset suffices for commands that include `--datasplit test`, while all other datasets are required for commands that include `--datasplit val`.\n\n\n### Command line\n\n#### Compute results on the test set (GTZAN), corresponding to Table 2 in the paper.\n\nMain results for our system:\n```bash\npython launch_scripts/compute_paper_metrics.py --models final0 final1 final2 --datasplit test\n```\n\nSmaller model:\n```bash\npython launch_scripts/compute_paper_metrics.py --models small0 small1 small2 --datasplit test\n```\n\nHung data:\n```bash\npython launch_scripts/compute_paper_metrics.py --models hung0 hung1 hung2 --datasplit test\n```\n\nWith DBN (this requires installing the madmom package):\n```bash\npython launch_scripts/compute_paper_metrics.py --models final0 final1 final2 --datasplit test --dbn\n```\n\n#### Compute 8-fold cross-validation results, corresponding to Table 1 in the paper.\n\n```bash\npython launch_scripts/compute_paper_metrics.py --models fold0 fold1 fold2 fold3 fold4 fold5 fold6 fold7 --datasplit val --aggregation-type k-fold\n```\n\n#### Compute ablation studies on the validation set of the single split, corresponding to Table 3 in the paper.\n\nOur system:\n```bash\npython launch_scripts/compute_paper_metrics.py --models single_final0 single_final1 single_final2 --datasplit val\n```\n\nNo sum head:\n```bash\npython launch_scripts/compute_paper_metrics.py --models single_nosumhead0 single_nosumhead1 single_nosumhead2 --datasplit val\n```\n\nNo tempo augmentation:\n```bash\npython launch_scripts/compute_paper_metrics.py --models single_notempoaug0 single_notempoaug1 single_notempoaug2 --datasplit val\n```\n\nNo mask augmentation:\n```bash\npython launch_scripts/compute_paper_metrics.py --models single_nomaskaug0 single_nomaskaug1 single_nomaskaug2 --datasplit val\n```\n\nNo partial transformers:\n```bash\npython launch_scripts/compute_paper_metrics.py --models single_nopartialt0 single_nopartialt1 single_nopartialt2 
--datasplit val\n```\n\nNo shift tolerance:\n```bash\npython launch_scripts/compute_paper_metrics.py --models single_noshifttol0 single_noshifttol1 single_noshifttol2 --datasplit val\n```\n\nNo pitch augmentation:\n```bash\npython launch_scripts/compute_paper_metrics.py --models single_nopitchaug0 single_nopitchaug1 single_nopitchaug2 --datasplit val\n```\n\nNo shift tolerance and no weights:\n```bash\npython launch_scripts/compute_paper_metrics.py --models single_noshifttolnoweights0 single_noshifttolnoweights1 single_noshifttolnoweights2  --datasplit val\n```\n\n## Training\n\n### Requirements\n\nThe training requirements match the [evaluation requirements](#requirements-1) for the validation set. All 16 datasets and annotations must be [correctly set up](#data).\n\n### Command line\n\n#### Train models listed in Table 2 in the paper.\n\nMain results for our system (final0, final1, final2):\n```bash\nfor seed in 0 1 2; do\n    python launch_scripts/train.py --seed=$seed --no-val\ndone\n```\n\nSmaller model (small0, small1, small2):\n```bash\nfor seed in 0 1 2; do\n    python launch_scripts/train.py --seed=$seed --no-val --transformer-dim=128\ndone\n```\n\nHung data (hung0, hung1, hung2):\n```bash\nfor seed in 0 1 2; do\n    python launch_scripts/train.py --seed=$seed --no-val --hung-data\ndone\n```\n\n#### Train models with 8-fold cross-validation, corresponding to Table 1 in the paper.\n\n```bash\nfor fold in {0..7}; do\n    python launch_scripts/train.py --fold=$fold\ndone\n```\n\n#### Train models for the ablation studies, corresponding to Table 3 in the paper.\n\nOur system (single_final0, single_final1, single_final2):\n```bash\nfor seed in 0 1 2; do\n    python launch_scripts/train.py --seed=$seed\ndone\n```\n\nNo sum head (single_nosumhead0, single_nosumhead1, single_nosumhead2):\n```bash\nfor seed in 0 1 2; do\n    python launch_scripts/train.py --seed=$seed --no-sum-head\ndone\n```\n\nNo tempo augmentation (single_notempoaug0, single_notempoaug1, 
single_notempoaug2):\n```bash\nfor seed in 0 1 2; do\n    python launch_scripts/train.py --seed=$seed --no-tempo-augmentation\ndone\n```\n\nNo mask augmentation (single_nomaskaug0, single_nomaskaug1, single_nomaskaug2):\n```bash\nfor seed in 0 1 2; do\n    python launch_scripts/train.py --seed=$seed --no-mask-augmentation\ndone\n```\n\nNo partial transformers (single_nopartialt0, single_nopartialt1, single_nopartialt2):\n```bash\nfor seed in 0 1 2; do\n    python launch_scripts/train.py --seed=$seed --no-partial-transformers\ndone\n```\n\nNo shift tolerance (single_noshifttol0, single_noshifttol1, single_noshifttol2):\n```bash\nfor seed in 0 1 2; do\n    python launch_scripts/train.py --seed=$seed --loss weighted_bce\ndone\n```\n\nNo pitch augmentation (single_nopitchaug0, single_nopitchaug1, single_nopitchaug2):\n```bash\nfor seed in 0 1 2; do\n    python launch_scripts/train.py --seed=$seed --no-pitch-augmentation\ndone\n```\n\nNo shift tolerance and no weights (single_noshifttolnoweights0, single_noshifttolnoweights1, single_noshifttolnoweights2):\n```bash\nfor seed in 0 1 2; do\n    python launch_scripts/train.py --seed=$seed --loss bce\ndone\n```\n\n\n## Reusing the loss\n\nTo reuse our shift-tolerant binary cross-entropy loss, just copy the `ShiftTolerantBCELoss` class from [`loss.py`](beat_this/model/loss.py); it has no dependencies on the rest of the package.\n\n\n## Reusing the model\n\nTo reuse the BeatThis model, you have multiple options:\n\n### From the package\n\nAfter installing the `beat_this` package, you can directly import the model class:\n```python\nfrom beat_this.model.beat_tracker import BeatThis\n```\nInstantiating this class gives you an untrained model that maps spectrograms to frame-wise beat and downbeat logits. 
For a pretrained model, use `load_model`:\n```python\nfrom beat_this.inference import load_model\nbeat_this = load_model('final0', device='cuda')\n```\n\n### From torch.hub\n\nTo quickly try the model without installing the package, just install the [requirements for inference](#requirements) and run:\n```python\nimport torch\nbeat_this = torch.hub.load('CPJKU/beat_this', 'beat_this', 'final0', device='cuda')\n```\n\n### Copy and paste\n\nTo copy the BeatThis model into your own project, you will need the [`beat_tracker.py`](beat_this/model/beat_tracker.py) and [`roformer.py`](beat_this/model/roformer.py) files. If you remove the `BeatThis.state_dict()` and `BeatThis._load_from_state_dict()` methods, which serve as a workaround for compiled models, there are no other internal dependencies, only external ones (`einops`, `rotary-embedding-torch`).\n\n\n## Citation\n\n```bibtex\n@inproceedings{foscarin2024beatthis,\n    author = {Francesco Foscarin and Jan Schl{\\\"u}ter and Gerhard Widmer},\n    title = {Beat this! Accurate beat tracking without {DBN} postprocessing},\n    year = 2024,\n    month = nov,\n    booktitle = {Proceedings of the 25th International Society for Music Information Retrieval Conference (ISMIR)},\n    address = {San Francisco, CA, United States},\n}\n```\n"
  },
  {
    "path": "beat_this/__init__.py",
    "content": ""
  },
  {
    "path": "beat_this/cli.py",
    "content": "#!/usr/bin/env python3\n\"\"\"\nBeat This! command line inference tool.\n\"\"\"\n\nimport argparse\nimport sys\nfrom pathlib import Path\n\nimport numpy as np\nimport torch\n\ntry:\n    import tqdm\nexcept ImportError:\n    tqdm = None\n\nfrom beat_this.inference import File2File, load_audio\nfrom beat_this.utils import save_beat_tsv\n\n\ndef get_parser():\n    parser = argparse.ArgumentParser(\n        description=\"Detects beats in given audio files with a Beat This! model.\"\n    )\n    parser.add_argument(\n        \"inputs\",\n        type=str,\n        nargs=\"+\",\n        help=\"An audio file to process, or a directory of such files. Can be given multiple times.\",\n    )\n    parser.add_argument(\n        \"--model\",\n        type=str,\n        help=\"Name, path or URL of the checkpoint to use; downloaded if needed (default: %(default)s).\",\n        default=\"final0\",\n    )\n    parser.add_argument(\n        \"--output\",\n        \"-o\",\n        type=str,\n        default=None,\n        help=\"Output file name for a single input file, or output directory for multiple input files. If omitted, outputs are saved next to each input file by replacing or appending a suffix (see --suffix and --append).\",\n    )\n    parser.add_argument(\n        \"--suffix\",\n        \"-s\",\n        type=str,\n        default=\".beats\",\n        help=\"Suffix for output file names (default: %(default)s). Also see --append. Ignored if an explicit output file name is given.\",\n    )\n    parser.add_argument(\n        \"--append\",\n        action=\"store_true\",\n        help=\"If given, append the suffix to output file names instead of replacing the existing suffix. 
Ignored if an explicit output file name is given.\",\n    )\n    parser.add_argument(\n        \"--skip-existing\",\n        action=\"store_true\",\n        help=\"If given, skip existing output files instead of overwriting them.\",\n    )\n    parser.add_argument(\n        \"--touch-first\",\n        action=\"store_true\",\n        help=\"If given, create an empty output file before processing. Combined with --skip-existing, this allows running multiple processes in parallel on the same set of files.\",\n    )\n    parser.add_argument(\n        \"--dbn\",\n        default=False,\n        action=argparse.BooleanOptionalAction,\n        help=\"Whether to use madmom's DBN for postprocessing.\",\n    )\n    parser.add_argument(\n        \"--gpu\",\n        type=int,\n        default=0,\n        help=\"Which GPU to use (not the number of GPUs), or -1 for CPU. Ignored if CUDA is not available. (default: %(default)s)\",\n    )\n    parser.add_argument(\n        \"--float16\",\n        action=\"store_true\",\n        help=\"If given, uses half precision floating point arithmetic. Required for flash attention on GPU. (default: %(default)s)\",\n    )\n    parser.add_argument(\n        \"--activations\",\n        action=\"store_true\",\n        help=\"If given, saves the raw activations with a .npy suffix.\",\n    )\n    return parser\n\n\ndef derive_output_path(input_path, suffix, append, output=None, parent=None):\n    \"\"\"\n    Determine the output file name for `input_path` using the given\n    suffix. 
If given, `output` is the base directory for outputs, and\n    `parent` is the directory that was given on the command line.\n    \"\"\"\n    # output directory\n    if output is None:\n        output_path = input_path\n    else:\n        if parent is not None:\n            input_path = input_path.relative_to(parent)\n        else:\n            input_path = input_path.name\n        output_path = output / input_path\n    # suffix\n    if append:\n        return output_path.parent / (output_path.name + suffix)\n    else:\n        return output_path.with_suffix(suffix)\n\n\ndef run(\n    inputs,\n    model,\n    output,\n    suffix,\n    append,\n    skip_existing,\n    touch_first,\n    dbn,\n    gpu,\n    float16,\n    activations,\n):\n    # determine device\n    if torch.cuda.is_available() and gpu >= 0:\n        device = torch.device(f\"cuda:{gpu}\")\n    else:\n        device = torch.device(\"cpu\")\n\n    # prepare model\n    file2file = File2File(model, device, float16, dbn)\n    if activations:\n\n        def process(audiofile, outfile):\n            wav, sr = load_audio(audiofile)\n            spect = file2file.signal2spect(wav, sr)\n            beat_logits, downbeat_logits = file2file.spect2frames(spect)\n            np.save(\n                outfile.with_suffix(\".npy\"),\n                np.vstack([beat_logits.cpu().numpy(), downbeat_logits.cpu().numpy()]),\n            )\n            beats, downbeats = file2file.frames2beats(beat_logits, downbeat_logits)\n            save_beat_tsv(beats, downbeats, outfile)\n\n    else:\n        process = file2file\n\n    # process inputs\n    inputs = [Path(item) for item in inputs]\n    if output is not None:\n        output = Path(output)\n    if len(inputs) == 1 and not inputs[0].is_dir():\n        # special case: single input file\n        if output is None or output.is_dir():\n            output = derive_output_path(inputs[0], suffix, append, output)\n        process(inputs[0], output)\n    else:\n        # 
multiple inputs: first collect tasks so we can have a progress bar\n        tasks = []\n        for item in inputs:\n            if item.is_dir():\n                for fn in item.rglob(\"*\"):\n                    if not fn.name.endswith(suffix) and not fn.is_dir():\n                        output_path = derive_output_path(\n                            fn, suffix, append, output, parent=item\n                        )\n                        if not skip_existing or not output_path.exists():\n                            tasks.append((fn, output_path))\n            else:\n                tasks.append((item, derive_output_path(item, suffix, append, output)))\n        # then process all of them\n        if tqdm is not None:\n            tasks = tqdm.tqdm(tasks)\n        for item, output in tasks:\n            if touch_first:\n                try:\n                    output.touch(exist_ok=not skip_existing)\n                except FileExistsError:\n                    continue\n            elif skip_existing and output.exists():\n                continue\n            try:\n                process(item, output)\n            except Exception:\n                print(\n                    f'Could not process \"{item}\". Rerun with this file alone for details.',\n                    file=sys.stderr,\n                )\n\n\ndef main():\n    run(**vars(get_parser().parse_args()))\n\n\nif __name__ == \"__main__\":\n    sys.exit(main())\n"
  },
  {
    "path": "beat_this/dataset/__init__.py",
    "content": "from beat_this.dataset.dataset import BeatDataModule\n"
  },
  {
    "path": "beat_this/dataset/augment.py",
    "content": "import numpy as np\nimport torch\n\n\ndef augment_pitchtempo(item, augmentations):\n    \"\"\"\n    Apply a randomly chosen pitch or tempo augmentation to the item.\n\n    Parameters:\n    item: dict\n        A dictionary representing the item to be augmented. It should contain the following keys:\n        - 'spect_path': The path to the unaugmented spectrogram file.\n        If pitch or tempo augmentation is applied, the 'spect_path' key will be updated.\n\n    augmentations: dict\n        A dictionary containing the augmentations to be applied. It can contain either or both of the following keys:\n        - 'pitch': A dictionary with 'min' and 'max' keys specifying the range of pitch shifting in semitones.\n        - 'tempo': A dictionary with 'min' and 'max' keys specifying the range of tempo changes in percent, and a 'stride' key specifying the step size.\n\n    Returns:\n    item: dict\n        The item after applying the augmentation. If a pitch or tempo augmentation was applied, the 'spect_path' key\n        and the annotations will be updated.\n    \"\"\"\n    # Handle pitch and tempo augmentations\n    if \"pitch\" in augmentations and \"tempo\" in augmentations:\n        # if both pitch and tempo are enabled, pick one of them\n        if np.random.randint(2) == 0:\n            # pitch\n            item = augment_pitch(item, augmentations[\"pitch\"])\n        else:\n            # tempo\n            item = augment_tempo(item, augmentations[\"tempo\"])\n    elif \"pitch\" in augmentations:\n        item = augment_pitch(item, augmentations[\"pitch\"])\n    elif \"tempo\" in augmentations:\n        item = augment_tempo(item, augmentations[\"tempo\"])\n\n    return item\n\n\ndef augment_pitch(item, pitch_params):\n    \"\"\"Apply pitch shifting to the item.\"\"\"\n    semitones = np.random.randint(pitch_params[\"min\"], pitch_params[\"max\"] + 1)\n    item = shift_filename(item, semitones)\n    item = shift_annotations(item, semitones)\n    return item\n\n\ndef augment_tempo(item, 
tempo_params):\n    \"\"\"Apply time stretching to the item.\"\"\"\n    percentage = np.random.choice(\n        np.arange(tempo_params[\"min\"], tempo_params[\"max\"] + 1, tempo_params[\"stride\"])\n    )\n    item = stretch_filename(item, percentage)\n    item = stretch_annotations(item, percentage)\n    return item\n\n\ndef stretch_annotations(item, percentage):\n    \"\"\"Apply time stretching to the item's annotations.\"\"\"\n    if not percentage:\n        return item\n    # percentage is the amount by which the *tempo* changes\n    factor = 1.0 + percentage / 100\n    item = dict(item)\n    item[\"beat_time\"] = item[\"beat_time\"] / factor\n    return item\n\n\ndef shift_annotations(item, semitones):\n    \"\"\"Apply pitch shifting to the item's annotations (a no-op, as pitch shifting does not change beat times).\"\"\"\n    return item\n\n\ndef stretch_filename(item, percentage):\n    \"\"\"Derive filename of precomputed time stretched version.\"\"\"\n    spect_path = item[\"spect_path\"]\n    if percentage:\n        stem = spect_path.stem + f\"_ts{percentage}\"\n        spect_path = spect_path.with_stem(stem)\n    return {**item, \"spect_path\": spect_path}\n\n\ndef shift_filename(item, semitones):\n    \"\"\"Derive filename of precomputed pitch shifted version.\"\"\"\n    spect_path = item[\"spect_path\"]\n    if semitones:\n        stem = spect_path.stem + f\"_ps{semitones}\"\n        spect_path = spect_path.with_stem(stem)\n    return {**item, \"spect_path\": spect_path}\n\n\ndef number_of_precomputed_augmentations(augmentations):\n    \"\"\"Return the number of augmentations that correspond to a precomputed file.\"\"\"\n    counter = 1  # the unaugmented file\n    for method, params in augmentations.items():\n        if method == \"pitch\":\n            counter += params[\"max\"] - params[\"min\"]\n        elif method == \"tempo\":\n            counter += (params[\"max\"] - params[\"min\"]) // params[\"stride\"]\n    return counter\n\n\ndef precomputed_augmentation_filenames(augmentations, ext=\"npy\"):\n    \"\"\"Return the 
filenames of the precomputed augmentations.\n\n    Parameters:\n    augmentations: dict\n        A dictionary containing the augmentations to be applied. It can contain either or both of the following keys:\n        - 'pitch': A dictionary with 'min' and 'max' keys specifying the range (including boundaries) of pitch shifting in semitones.\n        - 'tempo': A dictionary with 'min' and 'max' keys specifying the range (including boundaries) of tempo changes in percent, and a 'stride' key specifying the step size.\n    \"\"\"\n    filenames = [f\"track.{ext}\"]\n    for method, params in augmentations.items():\n        if method == \"pitch\":\n            for semitones in range(params[\"min\"], params[\"max\"] + 1):\n                if semitones == 0:\n                    continue\n                filenames.append(f\"track_ps{semitones}.{ext}\")\n        elif method == \"tempo\":\n            for percentage in range(params[\"min\"], params[\"max\"] + 1, params[\"stride\"]):\n                if percentage == 0:\n                    continue\n                filenames.append(f\"track_ts{percentage}.{ext}\")\n    return filenames\n\n\ndef augment_mask_(spect, augmentations: dict, fps: int):\n    \"\"\"\n    Apply the given masking operations to the spectrogram. The spectrogram is modified in place.\n\n    Parameters:\n    spect: ndarray\n        The input spectrogram to which the mask will be applied. It is a 2D array where the first dimension\n        represents time frames and the second dimension represents frequency bins.\n\n    augmentations: dict\n        A dictionary containing all the augmentations. If there is no \"mask\" key, this function returns the\n        unmodified spectrogram. If the \"mask\" key is present, its value is another dictionary which must include\n        the following keys:\n        - 'kind': The type of mask to apply. 
Choices: 'permute' and 'zero'.\n        - 'min_count' and 'max_count': The minimum and maximum number of times the mask should be applied.\n        - 'min_len' and 'max_len': The minimum and maximum length of the mask, expressed in seconds.\n        - 'min_parts' and 'max_parts': The minimum and maximum number of parts into which each masked section is split.\n            These are then randomly reordered. Only used if 'kind'='permute'.\n\n    fps: int\n        The frames per second of the spectrogram. This is used to convert 'min_len' and 'max_len' from seconds to frames.\n\n    Returns:\n    spect: ndarray\n        The spectrogram after applying the mask.\n\n    \"\"\"\n    if \"mask\" in augmentations:\n        mask_params = augmentations[\"mask\"]\n        count = np.random.randint(\n            mask_params[\"min_count\"], mask_params[\"max_count\"] + 1\n        )\n        # convert min_len and max_len to frames\n        min_len = int(mask_params[\"min_len\"] * fps)\n        max_len = int(mask_params[\"max_len\"] * fps)\n        # apply the masking as many times as specified by count\n        for _ in range(count):\n            length = np.random.randint(min_len, max_len + 1)\n            start = np.random.randint(0, len(spect) - length)\n            apply_mask_excerpt(\n                spect[start : start + length],\n                mask_params[\"kind\"],\n                mask_params[\"min_parts\"],\n                mask_params[\"max_parts\"],\n            )\n    return spect\n\n\ndef apply_mask_excerpt(excerpt, kind, min_parts, max_parts):\n    \"\"\"Apply a mask operation of the given kind in place to the given array or tensor.\"\"\"\n    if kind == \"permute\":\n        num_parts = np.random.randint(min_parts, max_parts + 1)\n        choices = len(excerpt)\n        num_parts = min(num_parts, choices + 1)\n        positions = np.random.choice(choices, num_parts - 1, replace=False)\n        positions.sort()\n        if isinstance(excerpt, np.ndarray):\n   
         parts = np.split(excerpt, positions)\n        else:\n            # split the tensor at the sorted positions (also works when num_parts == 1)\n            bounds = [0, *positions.tolist(), len(excerpt)]\n            parts = [excerpt[a:b] for a, b in zip(bounds[:-1], bounds[1:])]\n        parts = [parts[idx] for idx in np.random.permutation(num_parts)]\n        if isinstance(excerpt, np.ndarray):\n            excerpt[:] = np.concatenate(parts)\n        else:\n            excerpt[:] = torch.cat(parts)\n    elif kind == \"zero\":\n        excerpt[:] = 0\n    else:\n        raise ValueError(f\"Unsupported mask operation: {kind}\")\n"
  },
  {
    "path": "beat_this/dataset/dataset.py",
    "content": "import concurrent.futures\nimport itertools\nimport json\nimport re\nfrom pathlib import Path\n\nimport numpy as np\nimport pandas as pd\nimport pytorch_lightning as pl\nimport torch\nfrom torch.utils.data import DataLoader, Dataset\n\nfrom beat_this.dataset.augment import (\n    augment_mask_,\n    augment_pitchtempo,\n    precomputed_augmentation_filenames,\n)\nfrom beat_this.utils import index_to_framewise\n\nfrom .mmnpz import MemmappedNpzFile\n\n\nclass BeatTrackingDataset(Dataset):\n    \"\"\"\n    A PyTorch Dataset for beat tracking. This dataset loads preprocessed spectrograms and beat annotations\n    from a given data folder and provides them for training or evaluation.\n\n    Args:\n        item_names (list of str): A list of dataset items such as \"gtzan/gtzan_rock_00099\".\n        data_folder (Path or str): The base folder where the data is stored.\n        spect_fps (int, optional): The frames per second of the spectrograms. Defaults to 50.\n        train_length (int, optional): The length of the training sequences in frames. If None the entire piece is used. Defaults to 1500.\n        deterministic (bool, optional): If True, the dataset always returns the same sequence for a given index.\n            Defaults to False.\n        augmentations (dict, optional): A dictionary of data augmentations to apply. Possible keys are \"tempo\", \"pitch\", and \"mask\". 
Defaults to an empty dictionary.\n        length_based_oversampling_factor (int, optional): If nonzero (and train_length is not None), each item is repeated\n            round(factor * number of frames / train_length) times, at least once, so that long pieces are sampled more often.\n            Defaults to 0.\n    \"\"\"\n\n    def __init__(\n        self,\n        item_names: list[str],\n        data_folder,\n        spect_fps=50,\n        train_length=1500,\n        deterministic=False,\n        augmentations={},\n        length_based_oversampling_factor=0,\n    ):\n        data_folder = Path(data_folder)\n        self.spect_basepath = data_folder / \"audio\" / \"spectrograms\"\n        self.annotation_basepath = data_folder / \"annotations\"\n        self.fps = spect_fps\n        self.train_length = train_length\n        self.deterministic = deterministic\n        self.augmentations = augmentations\n        self.length_based_oversampling_factor = length_based_oversampling_factor\n        datasets = sorted(set(name.split(\"/\", 1)[0] for name in item_names))\n        # load dataset info\n        self.dataset_info = self._load_dataset_infos(datasets)\n        # load .npz spectrogram bundles, if any\n        self.spects = self._load_spect_bundles(datasets)\n        # load the annotations in parallel\n        with concurrent.futures.ThreadPoolExecutor() as executor:\n            items = executor.map(self._load_dataset_item, item_names)\n        items = [item for item in items if item is not None]\n        if self.length_based_oversampling_factor and self.train_length is not None:\n            # oversample the dataset according to the audio lengths, so that long pieces are sampled more often\n            oversampled_items = []\n            for item in items:\n                oversampling_factor = np.round(\n                    self.length_based_oversampling_factor\n                    * len(self._get_spect(item))\n                    / self.train_length\n                ).astype(int)\n                oversampling_factor = max(oversampling_factor, 1)\n                oversampled_items.extend(itertools.repeat(item, oversampling_factor))\n            print(\n                f\"Training set oversampled from {len(items)} to {len(oversampled_items)} excerpts.\"\n            )\n         
   items = oversampled_items\n        self.items = items\n\n    def _load_dataset_infos(self, datasets):\n        dataset_info = {}\n        for dataset in datasets:\n            with open(self.annotation_basepath / dataset / \"info.json\") as f:\n                dataset_info[dataset] = json.load(f)\n        return dataset_info\n\n    def _load_spect_bundles(self, datasets):\n        spects = {}\n        for dataset in datasets:\n            npz_file = (self.spect_basepath / dataset).with_suffix(\".npz\")\n            if npz_file.exists():\n                spects[dataset] = MemmappedNpzFile(npz_file)\n        return spects\n\n    def _load_dataset_item(self, item_name):\n        # stop if not all the augmented audio files are there\n        dataset, remainder = item_name.split(\"/\", 1)\n        for aug_filename in precomputed_augmentation_filenames(self.augmentations):\n            if (f\"{remainder}/{aug_filename[:-4]}\") not in self.spects.get(\n                dataset, ()\n            ) and not (self.spect_basepath / item_name / aug_filename).exists():\n                print(\n                    f\"Skipping {item_name} because not all necessary spectrograms are there.\"\n                )\n                return\n\n        # load beat and produce a default if beat values are not found\n        dataset, stem = item_name.split(\"/\", 1)\n        annotation_path = (\n            self.annotation_basepath\n            / dataset\n            / \"annotations\"\n            / \"beats\"\n            / (stem + \".beats\")\n        )\n        beat_annotation = np.loadtxt(annotation_path)\n        if beat_annotation.ndim == 2:\n            beat_time = beat_annotation[:, 0]\n            beat_value = beat_annotation[:, 1].astype(int)\n        else:\n            beat_time = beat_annotation\n            beat_value = np.zeros_like(beat_time, dtype=np.int32)\n\n        # stop if the annotations that are supposed to be there are not there\n        if 
self.dataset_info[dataset][\"has_downbeats\"]:\n            if beat_annotation.ndim != 2:\n                print(\n                    f\"Skipping {item_name} because it has {beat_annotation.ndim} columns but downbeat is supposed to be there.\"\n                )\n                return\n\n        # create a downbeat mask to handle the case where the downbeat is not annotated\n        downbeat_mask = self.dataset_info[dataset][\"has_downbeats\"]\n        # take care of different subsections of rwc for the dataset name\n        if dataset == \"rwc\":\n            dataset = \"rwc_\" + stem.split(\"_\", 2)[1]\n        return {\n            \"spect_path\": Path(item_name) / \"track.npy\",\n            \"beat_time\": beat_time,\n            \"beat_value\": beat_value,\n            \"downbeat_mask\": downbeat_mask,\n            \"dataset\": dataset,\n        }\n\n    def _get_spect(self, item):\n        try:\n            dataset, filename = str(item[\"spect_path\"]).split(\"/\", 1)\n            spect = self.spects[dataset][filename[:-4]]\n        except KeyError:\n            spect = np.load(self.spect_basepath / item[\"spect_path\"], mmap_mode=\"r\")\n        return spect\n\n    def get_frame_count(self, index):\n        \"\"\"Return number of frames of given item.\"\"\"\n        return len(self._get_spect(self.items[index]))\n\n    def get_beat_count(self, index):\n        \"\"\"Return number of beats (including downbeats) of given item.\"\"\"\n        return len(self.items[index][\"beat_time\"])\n\n    def get_downbeat_count(self, index):\n        \"\"\"Return number of downbeats of given item.\"\"\"\n        return (self.items[index][\"beat_value\"] == 1).sum()\n\n    def __len__(self):\n        return len(self.items)\n\n    def __getitem__(self, index):\n        if isinstance(index, (int, np.int64)):  # when index is a single int\n            item = self.items[index]\n\n            # select a pitch shift and time stretch\n            item = augment_pitchtempo(item, 
self.augmentations)\n\n            # load spectrogram\n            spect = self._get_spect(item)\n\n            # define the excerpt to use\n            original_length = len(spect)\n            if self.train_length is not None:\n                longer = original_length - self.train_length\n            else:\n                longer = 0\n            if longer > 0:  # if the piece is longer than the desired length\n                if self.deterministic:\n                    # select the middle of the excerpt\n                    start_frame = longer // 2\n                else:\n                    start_frame = np.random.randint(0, longer)\n                end_frame = start_frame + self.train_length\n            else:\n                start_frame = 0\n                end_frame = original_length\n\n            # obtain a view of the excerpt\n            spect = spect[start_frame:end_frame]\n\n            if \"mask\" in self.augmentations:\n                # copy the spectrogram and apply mask augmentation\n                spect = np.copy(spect)\n                spect = augment_mask_(spect, self.augmentations, self.fps)\n            else:\n                # only ensure we have a writeable array (so PyTorch is happy)\n                spect = np.require(spect, requirements=\"W\")\n\n            # prepare annotations\n            (\n                framewise_truth_beat,\n                framewise_truth_downbeat,\n                truth_orig_beat,\n                truth_orig_downbeat,\n            ) = prepare_annotations(item, start_frame, end_frame, self.fps)\n\n            # restructure the item dict with the correct training information\n            item = {\n                \"spect\": spect,\n                \"spect_path\": str(item[\"spect_path\"]),\n                \"dataset\": item[\"dataset\"],\n                \"start_frame\": start_frame,\n                \"truth_beat\": framewise_truth_beat,\n                \"truth_downbeat\": framewise_truth_downbeat,\n         
       \"downbeat_mask\": torch.as_tensor(item[\"downbeat_mask\"]),\n                \"padding_mask\": (\n                    np.ones(self.train_length, dtype=bool)\n                    if self.train_length is not None\n                    else np.ones(original_length, dtype=bool)\n                ),\n                \"truth_orig_beat\": truth_orig_beat,\n                \"truth_orig_downbeat\": truth_orig_downbeat,\n            }\n\n            # pad all framewise tensors if needed\n            if longer < 0:\n                item[\"spect\"] = np.pad(\n                    item[\"spect\"], [(0, -longer), (0, 0)], constant_values=0\n                )\n                for k in \"truth_beat\", \"truth_downbeat\":\n                    item[k] = np.pad(item[k], [(0, -longer)], constant_values=0)\n                item[\"padding_mask\"][longer:] = 0\n            return item\n\n        else:  # when index is a list of ints\n            return [self[i] for i in index]\n\n\nclass BeatDataModule(pl.LightningDataModule):\n    \"\"\"\n    A PyTorch Lightning DataModule for beat tracking. This DataModule handles the loading and preprocessing of the\n    BeatTrackingDataset and prepares it for use with a PyTorch Lightning model.\n    It can produce cross-validation or single train/val/test splits.\n\n    Args:\n        data_dir (Path or str): The parent directory where the data (spectrograms and beat labels) is stored.\n        batch_size (int, optional): The size of the batches to be generated by the DataLoader. Defaults to 8.\n        train_length (int, optional): The length of the subsequences in frames. If None, the entire pieces are returned. Defaults to 1500.\n        num_workers (int, optional): The number of worker processes to use for data loading. Defaults to 20.\n        augmentations (dict, optional): A dictionary of data augmentations to apply. 
Defaults to {\"pitch\": {\"min\": -5, \"max\": 6}, \"tempo\": {\"min\": -20, \"max\": 20, \"stride\": 4}}.\n        test_dataset (str, optional): The name of the dataset to use for testing. Defaults to \"gtzan\".\n        hung_data (bool, optional): If True, only use the datasets from the Hung et al. paper for training; validation is still on all datasets. Defaults to False.\n        no_val (bool, optional): If True, train on all train+val data and do not use a separate validation set; for compatibility reasons, the validation metrics are still computed, but are not meaningful. Defaults to False.\n        spect_fps (int, optional): The frames per second of the spectrograms. Defaults to 50.\n        length_based_oversampling_factor (int, optional): The factor by which to oversample the train dataset based on sequence length. Defaults to 0.\n        fold (int, optional): The fold number for cross-validation. If None, the single split is used. Defaults to None.\n        predict_datasplit (str, optional): The split to use for prediction (\"train\", \"val\" or \"test\"). The prediction dataset always uses full pieces. 
Defaults to \"test\".\n    \"\"\"\n\n    def __init__(\n        self,\n        data_dir,\n        batch_size=8,\n        train_length=1500,\n        num_workers=20,\n        augmentations={\n            \"pitch\": {\"min\": -5, \"max\": 6},\n            \"tempo\": {\"min\": -20, \"max\": 20, \"stride\": 4},\n        },\n        test_dataset=\"gtzan\",\n        hung_data=False,\n        no_val=False,\n        spect_fps=50,\n        length_based_oversampling_factor=0,\n        fold=None,\n        predict_datasplit=\"test\",\n    ):\n        super().__init__()\n        self.save_hyperparameters()\n        self.initialized = {}\n        # remember all arguments\n        self.data_dir = Path(data_dir)\n        self.batch_size = batch_size\n        self.train_length = train_length\n        self.num_workers = num_workers\n        if not set(augmentations.keys()).issubset({\"mask\", \"pitch\", \"tempo\"}):\n            raise ValueError(f\"Unsupported augmentations: {augmentations.keys()}\")\n        self.augmentations = augmentations\n        self.test_set_name = test_dataset\n        self.hung_data = hung_data\n        self.no_val = no_val\n        self.spect_fps = spect_fps\n        self.length_based_oversampling_factor = length_based_oversampling_factor\n        self.fold = fold\n        self.predict_datasplit = predict_datasplit\n\n    def setup(self, stage):\n        if self.initialized.get(stage, False):\n            return\n\n        # set up the paths\n        annotation_dir = self.data_dir / \"annotations\"\n\n        # load train/val splits\n        if stage in (\"fit\", \"validate\"):\n            self.val_items = []\n            self.train_items = []\n            split_file = \"8-folds.split\" if self.fold is not None else \"single.split\"\n            for dataset_dir in annotation_dir.iterdir():\n                if not dataset_dir.is_dir() or not (dataset_dir / split_file).exists():\n                    continue\n                dataset = dataset_dir.name\n    
            if dataset == self.test_set_name:\n                    continue\n                split = pd.read_csv(\n                    dataset_dir / split_file,\n                    header=None,\n                    names=[\"piece\", \"part\"],\n                    sep=\"\\t\",\n                )\n                if self.fold is not None:\n                    # CV: use given fold for validation, rest for training\n                    self.val_items.extend(\n                        f\"{dataset}/{stem}\"\n                        for stem in split.piece[split.part == self.fold]\n                    )\n                    self.train_items.extend(\n                        f\"{dataset}/{stem}\"\n                        for stem in split.piece[split.part != self.fold]\n                    )\n                else:\n                    # single split: pieces are marked as val or train\n                    self.val_items.extend(\n                        f\"{dataset}/{stem}\" for stem in split.piece[split.part == \"val\"]\n                    )\n                    self.train_items.extend(\n                        f\"{dataset}/{stem}\"\n                        for stem in split.piece[split.part == \"train\"]\n                    )\n            if self.no_val:\n                # Train on all available data (excluding the test set).\n                # For compatibility, validation metrics are still computed\n                # on the original validation set now included in training.\n                self.train_items.extend(self.val_items)\n            if self.hung_data:\n                # Use the training datasets from MODELING BEATS AND DOWNBEATS\n                # WITH A TIME-FREQUENCY TRANSFORMER (for comparability, the\n                # validation set stays the same, with all datasets).\n                # (no empty alternative in the group, otherwise every item would match)\n                regexp = re.compile(\n                    \"^(hainsworth/|ballroom/|hjdb/|beatles/|rwc/rwc_popular|simac/|smc/|harmonix/).*$\"\n                )\n                self.train_items = 
[\n                    item for item in self.train_items if regexp.match(item)\n                ]\n            self.val_items.sort()\n            self.train_items.sort()\n\n        # load validation set\n        if stage in (\"fit\", \"validate\"):\n            self.val_dataset = BeatTrackingDataset(\n                self.val_items,\n                deterministic=True,\n                augmentations={},\n                train_length=self.train_length,\n                data_folder=self.data_dir,\n                spect_fps=self.spect_fps,\n            )\n            print(\n                \"Validation set:\",\n                len(self.val_dataset),\n                \"items from:\",\n                *sorted(set(item.split(\"/\", 1)[0] for item in self.val_items)),\n            )\n            self.initialized[\"validate\"] = True\n\n        # load training set\n        if stage == \"fit\":\n            self.train_dataset = BeatTrackingDataset(\n                self.train_items,\n                deterministic=False,\n                augmentations=self.augmentations,\n                train_length=self.train_length,\n                data_folder=self.data_dir,\n                spect_fps=self.spect_fps,\n                length_based_oversampling_factor=self.length_based_oversampling_factor,\n            )\n            print(\n                \"Training set:\",\n                len(self.train_dataset),\n                \"items from:\",\n                *sorted(set(item.split(\"/\", 1)[0] for item in self.train_items)),\n            )\n            self.initialized[\"fit\"] = True\n\n        # load test set\n        if stage == \"test\":\n            test_annotations_dir = (\n                annotation_dir / self.test_set_name / \"annotations\" / \"beats\"\n            )\n            self.test_items = sorted(\n                f\"{self.test_set_name}/{item.stem}\"\n                for item in test_annotations_dir.glob(\"*.beats\")\n            )\n            self.test_dataset 
= BeatTrackingDataset(\n                self.test_items,\n                deterministic=True,\n                augmentations={},\n                train_length=None,\n                data_folder=self.data_dir,\n                spect_fps=self.spect_fps,\n            )\n            print(\n                \"Test set:\", len(self.test_dataset), \"items from:\", self.test_set_name\n            )\n            self.initialized[\"test\"] = True\n\n        # load prediction set\n        if stage == \"predict\":\n            if self.predict_datasplit == \"test\":\n                self.setup(\"test\")\n                # we can directly use the test dataset for predictions\n                self.predict_dataset = self.test_dataset\n            else:\n                if self.predict_datasplit == \"train\":\n                    self.setup(\"fit\")\n                    items = self.train_items\n                elif self.predict_datasplit == \"val\":\n                    self.setup(\"validate\")\n                    items = self.val_items\n                else:\n                    raise ValueError(\n                        f\"Unsupported predict_datasplit: {self.predict_datasplit}\"\n                    )\n                # for prediction, we want to use full items (train_length=None)\n                self.predict_dataset = BeatTrackingDataset(\n                    items,\n                    deterministic=True,\n                    augmentations={},\n                    train_length=None,\n                    data_folder=self.data_dir,\n                    spect_fps=self.spect_fps,\n                )\n\n    def train_dataloader(self):\n        return DataLoader(\n            self.train_dataset,\n            num_workers=self.num_workers,\n            batch_size=self.batch_size,\n            shuffle=True,\n            drop_last=True,\n            pin_memory=True,\n        )\n\n    def val_dataloader(self):\n        # Warning: for performance reasons, this only runs on the middle excerpt of long pieces\n        # The paper results are computed after training in the predict script\n        return DataLoader(\n            self.val_dataset, 
batch_size=self.batch_size, num_workers=self.num_workers\n        )\n\n    def test_dataloader(self):\n        return DataLoader(self.test_dataset, batch_size=1, num_workers=self.num_workers)\n\n    def predict_dataloader(self):\n        return DataLoader(\n            self.predict_dataset, batch_size=1, num_workers=self.num_workers\n        )\n\n    def get_train_positive_weights(self, widen_target_mask=3):\n        \"\"\"\n        Computes the relation of negative targets to positive targets.\n        `widen_target_mask` reduces the number of negative targets by the given\n        factor times the number of positive targets (for ignoring a number of\n        frames around each positive label).\n        For example a `widen_target_mask` of 3 will ignore 7 frames, 3 for each side plus the central.\n        \"\"\"\n        # find the positive weight for the loss as a ratio between (down)beat and non-(down)beat annotation\n        dataset = self.train_dataset\n        all_frames = all_frames_db = 0\n        for item in dataset.items:\n            frames = len(dataset._get_spect(item))\n            all_frames += frames\n            if item[\"downbeat_mask\"]:\n                all_frames_db += frames\n        beat_frames = sum(len(item[\"beat_value\"]) for item in dataset.items)\n        downbeat_frames = sum(\n            (item[\"beat_value\"] == 1).sum()\n            for item in dataset.items\n            if item[\"downbeat_mask\"]\n        )\n\n        return {\n            \"beat\": int(\n                np.round(\n                    (all_frames - beat_frames * (widen_target_mask * 2 + 1))\n                    / beat_frames\n                )\n            ),\n            \"downbeat\": int(\n                np.round(\n                    (all_frames_db - downbeat_frames * (widen_target_mask * 2 + 1))\n                    / downbeat_frames\n                )\n            ),\n        }\n\n\ndef prepare_annotations(item, start_frame, end_frame, fps):\n    
truth_bdb_time = item[\"beat_time\"]\n    truth_bdb_value = item[\"beat_value\"]\n    # convert beat time from seconds to frame\n    truth_bdb_frame = (truth_bdb_time * fps).round().astype(int)\n    # form annotations excerpt\n    # filter out the annotations that are earlier than the start and shift left\n    truth_bdb_frame -= start_frame\n    idx = np.searchsorted(truth_bdb_frame, 0)\n    truth_bdb_frame = truth_bdb_frame[idx:]\n    truth_bdb_value = truth_bdb_value[idx:]\n    # filter out the annotations that are later than the end\n    idx = np.searchsorted(truth_bdb_frame, end_frame - start_frame)\n    truth_bdb_frame = truth_bdb_frame[:idx]\n    truth_bdb_value = truth_bdb_value[:idx]\n    # create beat and downbeat separated annotations\n    truth_beat = truth_bdb_frame\n    truth_downbeat = truth_bdb_frame[truth_bdb_value == 1]\n    # transform beat downbeat to frame-wise annotations\n    framewise_truth_beat = index_to_framewise(truth_beat, end_frame - start_frame)\n    framewise_truth_downbeat = index_to_framewise(\n        truth_downbeat, end_frame - start_frame\n    )\n    # create orig beat, downbeat annotations for unquantized evaluation\n    truth_orig_beat = item[\"beat_time\"]\n    truth_orig_downbeat = truth_bdb_time[\n        item[\"beat_value\"] == 1\n    ]  # (use the full beat_value)\n    # filter out the annotations that are outside the excerpt, and shift them left to the excerpt time\n    truth_orig_beat = truth_orig_beat[\n        (truth_orig_beat >= start_frame / fps) & (truth_orig_beat < end_frame / fps)\n    ] - (start_frame / fps)\n    truth_orig_downbeat = truth_orig_downbeat[\n        (truth_orig_downbeat >= start_frame / fps)\n        & (truth_orig_downbeat < end_frame / fps)\n    ] - (start_frame / fps)\n    # convert to strings (trick to collate sequences of different lengths)\n    truth_orig_beat = truth_orig_beat.tobytes()\n    truth_orig_downbeat = truth_orig_downbeat.tobytes()\n    return (\n        framewise_truth_beat,\n     
   framewise_truth_downbeat,\n        truth_orig_beat,\n        truth_orig_downbeat,\n    )\n"
  },
  {
    "path": "beat_this/dataset/mmnpz.py",
    "content": "\"\"\"\nSupport for memory-mapping uncompressed .npz files.\n\"\"\"\n\nimport struct\nfrom collections.abc import Mapping\nfrom zipfile import ZipFile\n\nimport numpy as np\n\n\nclass MemmappedNpzFile(Mapping):\n    \"\"\"\n    A dictionary-like object with lazy-loading of numpy arrays in the given\n    uncompressed .npz file. Upon construction, creates a memory map of the\n    full .npz file, returning views for the arrays within on request.\n\n    Attributes\n    ----------\n    files : list of str\n        List of all uncompressed files in the archive with a ``.npy`` extension\n        (listed without the extension). These are supported as dictionary keys.\n    mmap : np.memmap\n        The memory map of the full .npz file.\n    arrays : dict\n        Preloaded or cached arrays.\n\n    Parameters\n    ----------\n    fn : str or Path\n        The zipped archive to open.\n    cache : bool, optional\n        Whether to cache array objects in case they are requested again.\n    preload : bool, optional\n        Whether to precreate all array objects upon opening. 
Enforces caching.\n    \"\"\"\n\n    def __init__(self, fn: str, cache: bool = True, preload: bool = False):\n        with ZipFile(fn, mode=\"r\") as f:\n            self._offsets = {\n                zinfo.filename[:-4]: (zinfo.header_offset, zinfo.file_size)\n                for zinfo in f.infolist()\n                if zinfo.filename.endswith(\".npy\") and zinfo.compress_type == 0\n            }\n        self.files = list(self._offsets.keys())\n        self.mmap = np.memmap(fn, mode=\"r\")\n        self.cache = cache or preload\n        self.preload = preload\n        if self.preload:\n            self.arrays = {name: self.load(name) for name in self.files}\n        else:\n            self.arrays = {}\n\n    def load(self, name: str):\n        header_offset, file_size = self._offsets[name]\n        # parse lengths of local header file name and extra fields\n        # (ZipInfo is based on the global directory, not local header)\n        fn_len, extra_len = struct.unpack(\n            \"<2H\", self.mmap[header_offset + 26 : header_offset + 30]\n        )\n        # compute offset of start and end of data\n        npy_start = header_offset + 30 + fn_len + extra_len\n        npy_end = npy_start + file_size\n        # read NPY header\n        fp = MemoryviewIO(self.mmap)\n        fp.seek(npy_start)\n        version = np.lib.format.read_magic(fp)\n        np.lib.format._check_version(version)\n        shape, fortran, dtype = np.lib.format._read_array_header(fp, version)\n        # produce slice of memmap\n        data_start = fp.tell()\n        return (\n            self.mmap[data_start:npy_end]\n            .view(dtype=dtype)\n            .reshape(shape, order=\"F\" if fortran else \"C\")\n        )\n\n    def close(self):\n        if hasattr(self, \"mmap\"):\n            del self.mmap\n        self.arrays = {}\n\n    def __enter__(self):\n        return self\n\n    def __exit__(self, exc_type, exc_value, traceback):\n        self.close()\n\n    def __iter__(self):\n 
       return iter(self.files)\n\n    def __len__(self):\n        return len(self.files)\n\n    def __getitem__(self, key: str):\n        if self.cache:\n            try:\n                return self.arrays[key]\n            except KeyError:\n                pass\n        array = self.load(key)\n        if self.cache:\n            self.arrays[key] = array\n        return array\n\n    def __contains__(self, key: str):\n        # Mapping.__contains__ calls __getitem__, which could be expensive\n        return key in self._offsets\n\n\nclass MemoryviewIO(object):\n    \"\"\"\n    Wraps an object supporting the buffer protocol to be a readonly file-like.\n    \"\"\"\n\n    def __init__(self, buffer):\n        self._buffer = memoryview(buffer).cast(\"B\")\n        self._pos = 0\n        self.seekable = lambda: True\n        self.readable = lambda: True\n        self.writable = lambda: False\n\n    def seek(self, offset, whence=0):\n        if whence == 0:\n            self._pos = offset\n        elif whence == 1:\n            self._pos += offset\n        elif whence == 2:\n            self._pos = self._buffer.nbytes + offset\n\n    def read(self, size=-1):\n        data = self._buffer[\n            self._pos : self._pos + size if size >= 0 else None\n        ].tobytes()\n        self._pos += len(data)\n        return data\n\n    def tell(self):\n        return self._pos\n"
  },
  {
    "path": "beat_this/inference.py",
    "content": "import inspect\n\nimport numpy as np\nimport soxr\nimport torch\nimport torch.nn.functional as F\n\nfrom beat_this.model.beat_tracker import BeatThis\nfrom beat_this.model.postprocessor import Postprocessor\nfrom beat_this.preprocessing import LogMelSpect, load_audio\nfrom beat_this.utils import replace_state_dict_key, save_beat_tsv\n\nCHECKPOINT_URL = \"https://cloud.cp.jku.at/public.php/dav/files/7ik4RrBKTS273gp\"\n\n\ndef load_checkpoint(checkpoint_path: str, device: str | torch.device = \"cpu\") -> dict:\n    \"\"\"\n    Load a BeatThis checkpoint as a dictionary.\n\n    Args:\n        checkpoint_path (str, optional): The path to the checkpoint. Can be a local path, a URL, or a shortname.\n        device (torch.device or str): The device to load the model on.\n\n    Returns:\n        dict: The loaded checkpoint dictionary.\n    \"\"\"\n    try:\n        # try interpreting as local file name\n        weights_only = {\"weights_only\": True} if torch.__version__ >= \"2\" else {}\n        return torch.load(checkpoint_path, map_location=device, **weights_only)\n    except FileNotFoundError:\n        try:\n            if not (\n                str(checkpoint_path).startswith(\"https://\")\n                or str(checkpoint_path).startswith(\"http://\")\n            ):\n                # interpret it as a name of one of our checkpoints\n                checkpoint_url = f\"{CHECKPOINT_URL}/{checkpoint_path}.ckpt\"\n                file_name = f\"beat_this-{checkpoint_path}.ckpt\"\n            else:\n                # try interpreting as a URL\n                checkpoint_url = checkpoint_path\n                file_name = None\n            return torch.hub.load_state_dict_from_url(\n                checkpoint_url,\n                file_name=file_name,\n                map_location=device,\n            )\n        except Exception:\n            raise ValueError(\n                \"Could not load the checkpoint given the provided name\",\n                
checkpoint_path,\n            )\n\n\ndef load_model(\n    checkpoint_path: str | None = \"final0\", device: str | torch.device = \"cpu\"\n) -> BeatThis:\n    \"\"\"\n    Load a BeatThis model from a checkpoint.\n\n    Args:\n        checkpoint_path (str or None, optional): The path to the checkpoint. Can be a local path, a URL, or a shortname. Defaults to \"final0\". If None, a randomly initialized model with default hyperparameters is returned.\n        device (torch.device or str): The device to load the model on.\n\n    Returns:\n        BeatThis: The loaded model, in evaluation mode.\n    \"\"\"\n    if checkpoint_path is not None:\n        checkpoint = load_checkpoint(checkpoint_path, device)\n        # Retrieve the model hyperparameters, as the checkpoint may hold a smaller model variant\n        hparams = checkpoint[\"hyper_parameters\"]\n        # Keep only the hyperparameters accepted by the BeatThis constructor\n        hparams = {\n            k: v\n            for k, v in hparams.items()\n            if k in set(inspect.signature(BeatThis).parameters)\n        }\n        # Create the uninitialized model\n        model = BeatThis(**hparams)\n        # The PLBeatThis (LightningModule) state_dict contains the BeatThis\n        # state_dict under the \"model.\" prefix; remove the prefix to load it\n        state_dict = replace_state_dict_key(checkpoint[\"state_dict\"], \"model.\", \"\")\n        model.load_state_dict(state_dict)\n    else:\n        model = BeatThis()\n    return model.to(device).eval()\n\n\ndef zeropad(spect: torch.Tensor, left: int = 0, right: int = 0):\n    \"\"\"\n    Zero-pads a spectrogram tensor of shape (time x bins) with `left` frames at the beginning and `right` frames at the end.\n    \"\"\"\n    if left == 0 and right == 0:\n        return spect\n    else:\n        return F.pad(spect, (0, 0, left, right), \"constant\", 0)\n\n\ndef split_piece(\n    spect: torch.Tensor,\n    chunk_size: int,\n    border_size: int = 6,\n    avoid_short_end: bool = True,\n):\n    \"\"\"\n    Split a tensor spectrogram matrix of shape (time x bins) into time chunks of `chunk_size` and return the 
chunks and starting positions.\n    The `border_size` is the number of frames assumed to be discarded in the predictions on either side (since the model was not trained on the input edges due to the max-pool in the loss).\n    To cater for this, the first and last chunk are padded by `border_size` on the beginning and end, respectively, and consecutive chunks overlap by 2 * `border_size`, so that the predictions kept from consecutive chunks are contiguous.\n    If `avoid_short_end` is true, the start of the last chunk is shifted left so that it ends at the end of the piece; the last chunk can therefore overlap with the previous chunk by more than 2 * `border_size`. Otherwise, the last chunk may be shorter than `chunk_size`.\n    If the piece is shorter than `chunk_size`, `avoid_short_end` is ignored and the piece is returned as a single shorter chunk.\n\n    Args:\n        spect (torch.Tensor): The input spectrogram tensor of shape (time x bins).\n        chunk_size (int): The size of the chunks to produce.\n        border_size (int, optional): The number of frames discarded from the predictions on either side of each chunk. Defaults to 6.\n        avoid_short_end (bool, optional): If True, the last chunk is shifted left to end at the end of the piece. 
Defaults to True.\n    \"\"\"\n    # generate the start and end indices\n    starts = np.arange(\n        -border_size, len(spect) - border_size, chunk_size - 2 * border_size\n    )\n    if avoid_short_end and len(spect) > chunk_size - 2 * border_size:\n        # if we avoid short ends, move the last index to the end of the piece - (chunk_size - border_size)\n        starts[-1] = len(spect) - (chunk_size - border_size)\n    # generate the chunks\n    chunks = [\n        zeropad(\n            spect[max(start, 0) : min(start + chunk_size, len(spect))],\n            left=max(0, -start),\n            right=max(0, min(border_size, start + chunk_size - len(spect))),\n        )\n        for start in starts\n    ]\n    return chunks, starts\n\n\ndef aggregate_prediction(\n    pred_chunks: list,\n    starts: list,\n    full_size: int,\n    chunk_size: int,\n    border_size: int,\n    overlap_mode: str,\n    device: str | torch.device,\n) -> tuple[torch.Tensor, torch.Tensor]:\n    \"\"\"\n    Aggregates the predictions for the whole piece based on the given prediction chunks.\n\n    Args:\n        pred_chunks (list): List of prediction chunks, where each chunk is a dictionary containing 'beat' and 'downbeat' predictions.\n        starts (list): List of start positions for each prediction chunk.\n        full_size (int): Size of the full piece.\n        chunk_size (int): Size of each prediction chunk.\n        border_size (int): Size of the border to be discarded from each prediction chunk.\n        overlap_mode (str): Mode for handling overlapping predictions. 
Can be 'keep_first' or 'keep_last'.\n        device (torch.device): Device to be used for the predictions.\n\n    Returns:\n        tuple: A tuple containing the aggregated beat predictions and downbeat predictions as torch tensors for the whole piece.\n    \"\"\"\n    if border_size > 0:\n        # cut the predictions to discard the border\n        pred_chunks = [\n            {\n                \"beat\": pchunk[\"beat\"][border_size:-border_size],\n                \"downbeat\": pchunk[\"downbeat\"][border_size:-border_size],\n            }\n            for pchunk in pred_chunks\n        ]\n    # aggregate the predictions for the whole piece\n    piece_prediction_beat = torch.full((full_size,), -1000.0, device=device)\n    piece_prediction_downbeat = torch.full((full_size,), -1000.0, device=device)\n    if overlap_mode == \"keep_first\":\n        # process in reverse order, so predictions of earlier excerpts overwrite later ones\n        pred_chunks = reversed(list(pred_chunks))\n        starts = reversed(list(starts))\n    for start, pchunk in zip(starts, pred_chunks):\n        piece_prediction_beat[\n            start + border_size : start + chunk_size - border_size\n        ] = pchunk[\"beat\"]\n        piece_prediction_downbeat[\n            start + border_size : start + chunk_size - border_size\n        ] = pchunk[\"downbeat\"]\n    return piece_prediction_beat, piece_prediction_downbeat\n\n\ndef split_predict_aggregate(\n    spect: torch.Tensor,\n    chunk_size: int,\n    border_size: int,\n    overlap_mode: str,\n    model: torch.nn.Module,\n) -> dict:\n    \"\"\"\n    Function for pieces that are longer than the training length of the model.\n    Split the input piece into chunks, run the model on them, and aggregate the predictions.\n    The spect is supposed to be a torch tensor of shape (time x bins), i.e., unbatched, and the output is also unbatched.\n\n    Args:\n        spect (torch.Tensor): the input piece\n        chunk_size (int): the length of 
the chunks\n        border_size (int): the size of the border that is discarded from the predictions\n        overlap_mode (str): how to handle overlaps between chunks ('keep_first' or 'keep_last')\n        model (torch.nn.Module): the model to run\n\n    Returns:\n        dict: the model's framewise predictions for the whole piece as a dictionary containing 'beat' and 'downbeat' predictions.\n    \"\"\"\n    # split the piece into chunks\n    chunks, starts = split_piece(\n        spect, chunk_size, border_size=border_size, avoid_short_end=True\n    )\n    # run the model on each chunk (as a batch of size 1)\n    pred_chunks = [model(chunk.unsqueeze(0)) for chunk in chunks]\n    # remove the extra dimension in beat and downbeat prediction due to batch size 1\n    pred_chunks = [\n        {\"beat\": p[\"beat\"][0], \"downbeat\": p[\"downbeat\"][0]} for p in pred_chunks\n    ]\n    piece_prediction_beat, piece_prediction_downbeat = aggregate_prediction(\n        pred_chunks,\n        starts,\n        spect.shape[0],\n        chunk_size,\n        border_size,\n        overlap_mode,\n        spect.device,\n    )\n    # return the aggregated predictions as a dictionary\n    return {\"beat\": piece_prediction_beat, \"downbeat\": piece_prediction_downbeat}\n\n\nclass Spect2Frames:\n    \"\"\"\n    Class for extracting framewise beat and downbeat predictions (logits) from a spectrogram.\n    \"\"\"\n\n    def __init__(self, checkpoint_path=\"final0\", device=\"cpu\", float16=False):\n        super().__init__()\n        self.device = torch.device(device)\n        self.float16 = float16\n        self.model = load_model(checkpoint_path, self.device)\n\n    def spect2frames(self, spect):\n        with torch.inference_mode():\n            with torch.autocast(enabled=self.float16, device_type=self.device.type):\n                model_prediction = split_predict_aggregate(\n                    spect=spect,\n                    chunk_size=1500,\n                    overlap_mode=\"keep_first\",\n                    border_size=6,\n                    model=self.model,\n   
             )\n        return model_prediction[\"beat\"].float(), model_prediction[\"downbeat\"].float()\n\n    def __call__(self, spect):\n        return self.spect2frames(spect)\n\n\nclass Audio2Frames(Spect2Frames):\n    \"\"\"\n    Class for extracting framewise beat and downbeat predictions (logits) from an audio signal\n    (a 1D array of samples, or a 2D array of shape (samples, channels) that is downmixed to mono).\n    \"\"\"\n\n    def __init__(self, checkpoint_path=\"final0\", device=\"cpu\", float16=False):\n        super().__init__(checkpoint_path, device, float16)\n        self.spect = LogMelSpect(device=self.device)\n\n    def signal2spect(self, signal, sr):\n        if signal.ndim == 2:\n            # downmix (samples, channels) to mono\n            signal = signal.mean(1)\n        elif signal.ndim != 1:\n            raise ValueError(f\"Expected 1D or 2D signal, got shape {signal.shape}\")\n        if sr != 22050:\n            # resample to 22050 Hz before computing the spectrogram\n            signal = soxr.resample(signal, in_rate=sr, out_rate=22050)\n        signal = torch.tensor(signal, dtype=torch.float32, device=self.device)\n        return self.spect(signal)\n\n    def __call__(self, signal, sr):\n        spect = self.signal2spect(signal, sr)\n        return self.spect2frames(spect)\n\n\nclass Audio2Beats(Audio2Frames):\n    \"\"\"\n    Class for extracting beat and downbeat positions (in seconds) from an audio signal.\n\n    Args:\n        checkpoint_path (str): Path to the model checkpoint file. It can be a local path, a URL, or the short name of a published checkpoint (downloaded from CHECKPOINT_URL). Default is \"final0\", which will load the model trained on all data except GTZAN with seed 0.\n        device (str): Device to use for inference. Default is \"cpu\".\n        float16 (bool): Whether to use half precision floating point arithmetic. Default is False.\n        dbn (bool): Whether to use the madmom DBN for post-processing. 
Default is False.\n    \"\"\"\n\n    def __init__(\n        self, checkpoint_path=\"final0\", device=\"cpu\", float16=False, dbn=False\n    ):\n        super().__init__(checkpoint_path, device, float16)\n        self.frames2beats = Postprocessor(type=\"dbn\" if dbn else \"minimal\")\n\n    def __call__(self, signal, sr):\n        beat_logits, downbeat_logits = super().__call__(signal, sr)\n        return self.frames2beats(beat_logits, downbeat_logits)\n\n\nclass File2Beats(Audio2Beats):\n    \"\"\"\n    Class for extracting beat and downbeat positions (in seconds) from an audio file.\n    Takes the same constructor arguments as Audio2Beats.\n    \"\"\"\n\n    def __call__(self, audio_path):\n        signal, sr = load_audio(audio_path)\n        return super().__call__(signal, sr)\n\n\nclass File2File(File2Beats):\n    \"\"\"\n    Class for extracting beat and downbeat positions (in seconds) from an audio file\n    and saving them to a TSV file at `output_path`.\n    \"\"\"\n\n    def __call__(self, audio_path, output_path):\n        beats, downbeats = super().__call__(audio_path)\n        save_beat_tsv(beats, downbeats, output_path)\n"
  },
  {
    "path": "beat_this/model/__init__.py",
    "content": ""
  },
  {
    "path": "beat_this/model/beat_tracker.py",
    "content": "\"\"\"\nModel definitions for the Beat This! beat tracker.\n\"\"\"\n\nimport contextlib\nfrom collections import OrderedDict\n\nimport torch\nfrom einops import rearrange\nfrom einops.layers.torch import Rearrange\nfrom rotary_embedding_torch import RotaryEmbedding\nfrom torch import nn\n\nfrom beat_this.model import roformer\nfrom beat_this.utils import replace_state_dict_key\n\n\nclass BeatThis(nn.Module):\n    \"\"\"\n    A neural network model for beat tracking. It is composed of three main components:\n    - a frontend that processes the input spectrogram,\n    - a series of transformer blocks that process the output of the frontend,\n    - a head that produces the final beat and downbeat predictions.\n\n    Args:\n        spect_dim (int): The dimension of the input spectrogram (default: 128).\n        transformer_dim (int): The dimension of the main transformer blocks (default: 512).\n        ff_mult (int): The multiplier for the feed-forward dimension in the transformer blocks (default: 4).\n        n_layers (int): The number of transformer blocks (default: 6).\n        head_dim (int): The dimension of each attention head for the partial transformers in the frontend and the transformer blocks (default: 32).\n        stem_dim (int): The out dimension of the stem convolutional layer (default: 32).\n        dropout (dict): A dictionary specifying the dropout rates for different parts of the model\n            (default: {\"frontend\": 0.1, \"transformer\": 0.2}).\n        sum_head (bool): Whether to use a SumHead for the final predictions (default: True) or plain independent projections.\n        partial_transformers (bool): Whether to include partial frequency- and time-transformers in the frontend (default: True)\n    \"\"\"\n\n    def __init__(\n        self,\n        spect_dim: int = 128,\n        transformer_dim: int = 512,\n        ff_mult: int = 4,\n        n_layers: int = 6,\n        head_dim: int = 32,\n        stem_dim: int = 32,\n      
  dropout: dict = {\"frontend\": 0.1, \"transformer\": 0.2},\n        sum_head: bool = True,\n        partial_transformers: bool = True,\n    ):\n        super().__init__()\n        # shared rotary embedding for frontend blocks and transformer blocks\n        rotary_embed = RotaryEmbedding(head_dim)\n\n        # create the frontend\n        # - stem\n        stem = self.make_stem(spect_dim, stem_dim)\n        spect_dim //= 4  # frequencies were convolved with stride 4\n        # - three frontend blocks\n        frontend_blocks = []\n        dim = stem_dim\n        for _ in range(3):\n            frontend_blocks.append(\n                self.make_frontend_block(\n                    dim,\n                    dim * 2,\n                    partial_transformers,\n                    head_dim,\n                    rotary_embed,\n                    dropout[\"frontend\"],\n                )\n            )\n            dim *= 2\n            spect_dim //= 2  # frequencies were convolved with stride 2\n        frontend_blocks = nn.Sequential(*frontend_blocks)\n        # - linear projection to transformer dimensionality\n        concat = Rearrange(\"b c f t -> b t (c f)\")\n        linear = nn.Linear(dim * spect_dim, transformer_dim)\n        self.frontend = nn.Sequential(\n            OrderedDict(stem=stem, blocks=frontend_blocks, concat=concat, linear=linear)\n        )\n\n        # create the transformer blocks\n        assert (\n            transformer_dim % head_dim == 0\n        ), \"transformer_dim must be divisible by head_dim\"\n        n_heads = transformer_dim // head_dim\n        self.transformer_blocks = roformer.Transformer(\n            dim=transformer_dim,\n            depth=n_layers,\n            heads=n_heads,\n            attn_dropout=dropout[\"transformer\"],\n            ff_dropout=dropout[\"transformer\"],\n            rotary_embed=rotary_embed,\n            ff_mult=ff_mult,\n            dim_head=head_dim,\n            norm_output=True,\n        )\n\n   
     # create the output heads\n        if sum_head:\n            self.task_heads = SumHead(transformer_dim)\n        else:\n            self.task_heads = Head(transformer_dim)\n\n        # init all weights\n        self.apply(self._init_weights)\n\n    @staticmethod\n    def make_stem(spect_dim: int, stem_dim: int) -> nn.Module:\n        return nn.Sequential(\n            OrderedDict(\n                rearrange_tf=Rearrange(\"b t f -> b f t\"),\n                bn1d=nn.BatchNorm1d(spect_dim),\n                add_channel=Rearrange(\"b f t -> b 1 f t\"),\n                conv2d=nn.Conv2d(\n                    in_channels=1,\n                    out_channels=stem_dim,\n                    kernel_size=(4, 3),\n                    stride=(4, 1),\n                    padding=(0, 1),\n                    bias=False,\n                ),\n                bn2d=nn.BatchNorm2d(stem_dim),\n                activation=nn.GELU(),\n            )\n        )\n\n    @staticmethod\n    def make_frontend_block(\n        in_dim: int,\n        out_dim: int,\n        partial_transformers: bool = True,\n        head_dim: int | None = 32,\n        rotary_embed: RotaryEmbedding | None = None,\n        dropout: float = 0.1,\n    ) -> nn.Module:\n        if partial_transformers and (head_dim is None or rotary_embed is None):\n            raise ValueError(\n                \"Must specify head_dim and rotary_embed for using partial_transformers\"\n            )\n        return nn.Sequential(\n            OrderedDict(\n                partial=(\n                    PartialFTTransformer(\n                        dim=in_dim,\n                        dim_head=head_dim,\n                        n_head=in_dim // head_dim,\n                        rotary_embed=rotary_embed,\n                        dropout=dropout,\n                    )\n                    if partial_transformers\n                    else nn.Identity()\n                ),\n                # conv block\n                
conv2d=nn.Conv2d(\n                    in_channels=in_dim,\n                    out_channels=out_dim,\n                    kernel_size=(2, 3),\n                    stride=(2, 1),\n                    padding=(0, 1),\n                    bias=False,\n                ),\n                # out_channels : 64, 128, 256\n                # freqs : 16, 8, 4 (due to the stride=2)\n                norm=nn.BatchNorm2d(out_dim),\n                activation=nn.GELU(),\n            )\n        )\n\n    @staticmethod\n    def _init_weights(module: nn.Module):\n        if isinstance(module, (nn.Linear, nn.Conv1d)):\n            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)\n            if module.bias is not None:\n                torch.nn.init.zeros_(module.bias)\n        elif isinstance(module, nn.Conv2d):\n            torch.nn.init.kaiming_normal_(\n                module.weight, mode=\"fan_out\", nonlinearity=\"relu\"\n            )\n            if module.bias is not None:\n                torch.nn.init.zeros_(module.bias)\n        elif isinstance(module, nn.Embedding):\n            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)\n            if module.padding_idx is not None:\n                with torch.no_grad():\n                    module.weight[module.padding_idx].fill_(0)\n\n    def forward(self, x):\n        x = self.frontend(x)\n        x = self.transformer_blocks(x)\n        x = self.task_heads(x)\n        return x\n\n    def _load_from_state_dict(self, state_dict, prefix, *args, **kwargs):\n        # remove _orig_mod prefixes for compiled models\n        state_dict = replace_state_dict_key(state_dict, \"_orig_mod.\", \"\")\n        super()._load_from_state_dict(state_dict, prefix, *args, **kwargs)\n\n    def state_dict(self, *args, **kwargs):\n        state_dict = super().state_dict(*args, **kwargs)\n        # remove _orig_mod prefixes for compiled models\n        state_dict = replace_state_dict_key(state_dict, \"_orig_mod.\", \"\")\n        return 
state_dict\n\n\nclass PartialRoformer(nn.Module):\n    \"\"\"\n    Takes a (batch, channels, freqs, time) input, applies self-attention and\n    a feed-forward block either only across frequencies or only across time.\n    Returns a tensor of the same shape as the input.\n    \"\"\"\n\n    def __init__(\n        self,\n        dim: int,\n        dim_head: int,\n        n_head: int,\n        direction: str,\n        rotary_embed: RotaryEmbedding,\n        dropout: float,\n    ):\n        super().__init__()\n\n        assert dim % dim_head == 0, \"dim must be divisible by dim_head\"\n        assert dim // dim_head == n_head, \"n_head must be equal to dim // dim_head\"\n        self.direction = direction[0].lower()\n        if self.direction not in \"ft\":\n            raise ValueError(f\"direction must be F or T, got {direction}\")\n        self.attn = roformer.Attention(\n            dim,\n            heads=n_head,\n            dim_head=dim_head,\n            dropout=dropout,\n            rotary_embed=rotary_embed,\n        )\n        self.ff = roformer.FeedForward(dim, dropout=dropout)\n\n    def forward(self, x):\n        b = len(x)\n        if self.direction == \"f\":\n            pattern = \"(b t) f c\"\n        elif self.direction == \"t\":\n            pattern = \"(b f) t c\"\n        x = rearrange(x, f\"b c f t -> {pattern}\")\n        x = x + self.attn(x)\n        x = x + self.ff(x)\n        x = rearrange(x, f\"{pattern} -> b c f t\", b=b)\n        return x\n\n\nclass PartialFTTransformer(nn.Module):\n    \"\"\"\n    Takes a (batch, channels, freqs, time) input, applies self-attention and\n    a feed-forward block once across frequencies and once across time. Same\n    as applying two PartialRoformer() in sequence, but encapsulated in a single\n    module. 
Returns a tensor of the same shape as the input.\n    \"\"\"\n\n    def __init__(\n        self,\n        dim: int,\n        dim_head: int,\n        n_head: int,\n        rotary_embed: RotaryEmbedding,\n        dropout: float,\n    ):\n        super().__init__()\n\n        assert dim % dim_head == 0, \"dim must be divisible by dim_head\"\n        assert dim // dim_head == n_head, \"n_head must be equal to dim // dim_head\"\n        # frequency directed partial transformer\n        self.attnF = roformer.Attention(\n            dim,\n            heads=n_head,\n            dim_head=dim_head,\n            dropout=dropout,\n            rotary_embed=rotary_embed,\n        )\n        self.ffF = roformer.FeedForward(dim, dropout=dropout)\n        # time directed partial transformer\n        self.attnT = roformer.Attention(\n            dim,\n            heads=n_head,\n            dim_head=dim_head,\n            dropout=dropout,\n            rotary_embed=rotary_embed,\n        )\n        self.ffT = roformer.FeedForward(dim, dropout=dropout)\n\n    def forward(self, x):\n        b = len(x)\n        # frequency directed partial transformer\n        x = rearrange(x, \"b c f t -> (b t) f c\")\n        x = x + self.attnF(x)\n        x = x + self.ffF(x)\n        # time directed partial transformer\n        x = rearrange(x, \"(b t) f c ->(b f) t c\", b=b)\n        x = x + self.attnT(x)\n        x = x + self.ffT(x)\n        x = rearrange(x, \"(b f) t c -> b c f t\", b=b)\n        return x\n\n\nclass SumHead(nn.Module):\n    \"\"\"\n    A PyTorch module that produces the final beat and downbeat prediction logits.\n    The beats are a sum of all beats and all downbeats predictions, to reduce the prediction\n    of downbeats which are not beats.\n    \"\"\"\n\n    def __init__(self, input_dim):\n        super().__init__()\n        self.beat_downbeat_lin = nn.Linear(input_dim, 2)\n\n    def forward(self, x):\n        beat_downbeat = self.beat_downbeat_lin(x)\n        # separate beat 
from downbeat\n        beat, downbeat = rearrange(beat_downbeat, \"b t c -> c b t\", c=2)\n        # add the downbeat logits to the beat logits,\n        # with float16 autocast disabled to avoid numerical issues causing NaNs\n        if hasattr(\n            torch.amp, \"is_autocast_available\"\n        ) and not torch.amp.is_autocast_available(beat.device.type):\n            # but do not try disabling autocast if the device does not support it\n            disable_autocast = contextlib.nullcontext()\n        else:\n            disable_autocast = torch.autocast(beat.device.type, enabled=False)\n        with disable_autocast:\n            beat = beat.float() + downbeat.float()\n        return {\"beat\": beat, \"downbeat\": downbeat}\n\n\nclass Head(nn.Module):\n    \"\"\"\n    A PyTorch module that produces the final beat and downbeat prediction logits as independent outputs of a linear layer.\n    \"\"\"\n\n    def __init__(self, input_dim):\n        super().__init__()\n        self.beat_downbeat_lin = nn.Linear(input_dim, 2)\n\n    def forward(self, x):\n        beat_downbeat = self.beat_downbeat_lin(x)\n        # separate beat from downbeat\n        beat, downbeat = rearrange(beat_downbeat, \"b t c -> c b t\", c=2)\n        return {\"beat\": beat, \"downbeat\": downbeat}\n"
  },
  {
    "path": "beat_this/model/loss.py",
    "content": "\"\"\"\nLoss definitions for the Beat This! beat tracker.\n\"\"\"\n\nimport torch\nimport torch.nn.functional as F\n\n\nclass MaskedBCELoss(torch.nn.Module):\n    \"\"\"\n    Plain binary cross-entropy loss. Expects predictions to be given as logits,\n    and accepts an optional mask with zeros indicating the entries to ignore.\n\n    Args:\n        pos_weight (float): Weight for positive examples compared to negative\n            examples (default: 1)\n    \"\"\"\n\n    def __init__(self, pos_weight: float = 1):\n        super().__init__()\n        self.register_buffer(\n            \"pos_weight\",\n            torch.tensor(pos_weight, dtype=torch.get_default_dtype()),\n            persistent=False,\n        )\n\n    def forward(\n        self,\n        preds: torch.Tensor,\n        targets: torch.Tensor,\n        mask: torch.Tensor | None = None,\n    ):\n        return F.binary_cross_entropy_with_logits(\n            preds, targets, weight=mask, pos_weight=self.pos_weight\n        )\n\n\nclass ShiftTolerantBCELoss(torch.nn.Module):\n    \"\"\"\n    BCE loss variant for sequence labeling that tolerates small shifts between\n    predictions and targets. This is accomplished by max-pooling the\n    predictions with a given tolerance and a stride of 1, so the gradient for a\n    positive label affects the largest prediction in a window around it.\n    Expects predictions to be given as logits, and accepts an optional mask\n    with zeros indicating the entries to ignore. 
Note that the edges of the\n    sequence will not receive a gradient, as it is assumed to be unknown\n    whether there is a nearby positive annotation.\n\n    Args:\n        pos_weight (float): Weight for positive examples compared to negative\n            examples (default: 1)\n        tolerance (int): Tolerated shift in time steps in each direction\n            (default: 3)\n    \"\"\"\n\n    def __init__(self, pos_weight: float = 1, tolerance: int = 3):\n        super().__init__()\n        self.register_buffer(\n            \"pos_weight\",\n            torch.tensor(pos_weight, dtype=torch.get_default_dtype()),\n            persistent=False,\n        )\n        self.tolerance = tolerance\n\n    def spread(self, x: torch.Tensor, factor: int = 1):\n        if self.tolerance == 0:\n            return x\n        return F.max_pool1d(x, 1 + 2 * factor * self.tolerance, 1)\n\n    def crop(self, x: torch.Tensor, factor: int = 1):\n        return x[..., factor * self.tolerance : -factor * self.tolerance or None]\n\n    def forward(\n        self,\n        preds: torch.Tensor,\n        targets: torch.Tensor,\n        mask: torch.Tensor | None = None,\n    ):\n        # spread preds and crop targets to match\n        spreaded_preds = self.crop(self.spread(preds))\n        cropped_targets = self.crop(targets, factor=2)\n        # ignore around the positive targets\n        look_at = cropped_targets + (1 - self.spread(targets, factor=2))\n        if mask is not None:  # consider padding and no-downbeat mask\n            look_at = look_at * self.crop(mask, factor=2)\n        # compute loss\n        return F.binary_cross_entropy_with_logits(\n            spreaded_preds,\n            cropped_targets,\n            weight=look_at,\n            pos_weight=self.pos_weight,\n        )\n\n\nclass SplittedShiftTolerantBCELoss(torch.nn.Module):\n    \"\"\"\n    Alternative implementation of ShiftTolerantBCELoss that splits the loss for\n    positive and negative targets. 
This is mainly provided because it may be\n    easier to understand and to compare with the Beat This! paper. Note that for\n    non-binary targets (e.g., with label smoothing), this implementation\n    matches the equation in the paper (Section 3.3), while ShiftTolerantBCELoss\n    deviates from it. For binary targets, the results are identical.\n\n    Args:\n        pos_weight (float): Weight for positive examples compared to negative\n            examples (default: 1)\n        tolerance (int): Tolerated shift in time steps in each direction\n            (default: 3)\n    \"\"\"\n\n    def __init__(self, pos_weight: float = 1, tolerance: int = 3):\n        super().__init__()\n        self.tolerance = tolerance\n        self.spread_preds = tolerance\n        self.spread_targets = 2 * tolerance  # targets are always spread twice as much\n        self.register_buffer(\n            \"pos_weight\",\n            torch.tensor(pos_weight, dtype=torch.get_default_dtype()),\n            persistent=False,\n        )\n\n    def spread(self, x: torch.Tensor, amount: int):\n        if amount:\n            return F.max_pool1d(x, 1 + 2 * amount, 1)\n        else:\n            return x\n\n    def crop(self, x: torch.Tensor, desired_length: int):\n        amount = (x.shape[-1] - desired_length) // 2\n        if amount > 0:\n            return x[..., amount:-amount]\n        elif amount == 0:\n            return x\n        else:\n            raise ValueError(\"Desired length must not exceed input length\")\n\n    def forward(self, preds: torch.Tensor, targets: torch.Tensor, mask: torch.Tensor):\n        output_length = targets.size(-1) - 2 * self.spread_targets\n        # compute loss for positive targets: spread the preds\n        preds = self.spread(preds, self.spread_preds)\n        # crop preds, targets and mask to ignore edge effects of the max-pooling\n        cropped_preds = self.crop(preds, output_length)\n        cropped_targets = self.crop(targets, output_length)\n        cropped_mask = self.crop(mask, output_length)\n        loss_positive = F.binary_cross_entropy_with_logits(\n            cropped_preds,\n            cropped_targets,\n            weight=cropped_targets * cropped_mask,\n            pos_weight=self.pos_weight,\n        )\n\n        # compute loss for negative targets: spread the targets as well (preds were already spread above)\n        targets = self.spread(targets, self.spread_targets)\n        cropped_targets = self.crop(targets, output_length)\n        loss_negative = F.binary_cross_entropy_with_logits(\n            cropped_preds,\n            cropped_targets,\n            weight=(1 - cropped_targets) * cropped_mask,\n            pos_weight=self.pos_weight,  # ensures identical results to the other implementation\n        )\n        # sum the two losses\n        return loss_positive + loss_negative\n"
  },
  {
    "path": "beat_this/model/pl_module.py",
    "content": "\"\"\"\nPytorch Lightning module, wraps a BeatThis model along with losses, metrics and\noptimizers for training.\n\"\"\"\n\nfrom concurrent.futures import ThreadPoolExecutor\nfrom typing import Any\n\nimport mir_eval\nimport numpy as np\nimport torch\nfrom pytorch_lightning import LightningModule\n\nimport beat_this.model.loss\nfrom beat_this.inference import split_predict_aggregate\nfrom beat_this.model.beat_tracker import BeatThis\nfrom beat_this.model.postprocessor import Postprocessor\nfrom beat_this.utils import replace_state_dict_key\n\n\nclass PLBeatThis(LightningModule):\n    def __init__(\n        self,\n        spect_dim=128,\n        fps=50,\n        transformer_dim=512,\n        ff_mult=4,\n        n_layers=6,\n        stem_dim=32,\n        dropout={\"frontend\": 0.1, \"transformer\": 0.2},\n        lr=0.0008,\n        weight_decay=0.01,\n        pos_weights={\"beat\": 1, \"downbeat\": 1},\n        head_dim=32,\n        loss_type=\"shift_tolerant_weighted_bce\",\n        warmup_steps=1000,\n        max_epochs=100,\n        use_dbn=False,\n        eval_trim_beats=5,\n        sum_head=True,\n        partial_transformers=True,\n    ):\n        super().__init__()\n        self.save_hyperparameters()\n        self.lr = lr\n        self.weight_decay = weight_decay\n        self.fps = fps\n        # create model\n        self.model = BeatThis(\n            spect_dim=spect_dim,\n            transformer_dim=transformer_dim,\n            ff_mult=ff_mult,\n            stem_dim=stem_dim,\n            n_layers=n_layers,\n            head_dim=head_dim,\n            dropout=dropout,\n            sum_head=sum_head,\n            partial_transformers=partial_transformers,\n        )\n        self.warmup_steps = warmup_steps\n        self.max_epochs = max_epochs\n        # set up the losses\n        self.pos_weights = pos_weights\n        if loss_type == \"shift_tolerant_weighted_bce\":\n            self.beat_loss = 
beat_this.model.loss.ShiftTolerantBCELoss(\n                pos_weight=pos_weights[\"beat\"]\n            )\n            self.downbeat_loss = beat_this.model.loss.ShiftTolerantBCELoss(\n                pos_weight=pos_weights[\"downbeat\"]\n            )\n        elif loss_type == \"weighted_bce\":\n            self.beat_loss = beat_this.model.loss.MaskedBCELoss(\n                pos_weight=pos_weights[\"beat\"]\n            )\n            self.downbeat_loss = beat_this.model.loss.MaskedBCELoss(\n                pos_weight=pos_weights[\"downbeat\"]\n            )\n        elif loss_type == \"bce\":\n            self.beat_loss = beat_this.model.loss.MaskedBCELoss()\n            self.downbeat_loss = beat_this.model.loss.MaskedBCELoss()\n        elif loss_type == \"splitted_shift_tolerant_weighted_bce\":\n            self.beat_loss = beat_this.model.loss.SplittedShiftTolerantBCELoss(\n                pos_weight=pos_weights[\"beat\"]\n            )\n            self.downbeat_loss = beat_this.model.loss.SplittedShiftTolerantBCELoss(\n                pos_weight=pos_weights[\"downbeat\"]\n            )\n        else:\n            raise ValueError(\n                \"loss_type must be one of 'shift_tolerant_weighted_bce', 'weighted_bce', \"\n                \"'bce', 'splitted_shift_tolerant_weighted_bce'\"\n            )\n\n        self.postprocessor = Postprocessor(\n            type=\"dbn\" if use_dbn else \"minimal\", fps=fps\n        )\n        self.eval_trim_beats = eval_trim_beats\n        self.metrics = Metrics(eval_trim_beats=eval_trim_beats)\n\n    def _compute_loss(self, batch, model_prediction):\n        beat_mask = batch[\"padding_mask\"]\n        beat_loss = self.beat_loss(\n            model_prediction[\"beat\"], batch[\"truth_beat\"].float(), beat_mask\n        )\n        # downbeat mask considers padding and also pieces which don't have downbeat annotations\n        downbeat_mask = beat_mask * batch[\"downbeat_mask\"][:, None]\n        downbeat_loss = self.downbeat_loss(\n            model_prediction[\"downbeat\"], 
batch[\"truth_downbeat\"].float(), downbeat_mask\n        )\n        # sum the losses and return them in a dictionary for logging\n        return {\n            \"beat\": beat_loss,\n            \"downbeat\": downbeat_loss,\n            \"total\": beat_loss + downbeat_loss,\n        }\n\n    def _compute_metrics(self, batch, postp_beat, postp_downbeat, step=\"val\"):\n        \"\"\"Compute beat and downbeat metrics for a batch of postprocessed predictions.\"\"\"\n        # compute for beat\n        metrics_beat = self._compute_metrics_target(\n            batch, postp_beat, target=\"beat\", step=step\n        )\n        # compute for downbeat\n        metrics_downbeat = self._compute_metrics_target(\n            batch, postp_downbeat, target=\"downbeat\", step=step\n        )\n\n        # merge the dictionaries\n        metrics = {**metrics_beat, **metrics_downbeat}\n\n        return metrics\n\n    def _compute_metrics_target(self, batch, postp_target, target, step):\n\n        def compute_item(postp_pred, truth_orig_target):\n            # take the ground truth from the original version, so there are no quantization errors\n            piece_truth_time = np.frombuffer(truth_orig_target)\n            # run evaluation\n            metrics = self.metrics(piece_truth_time, postp_pred, step=step)\n\n            return metrics\n\n        # if the input was not batched, postp_target is an array instead of a tuple of arrays\n        # make it a tuple for consistency\n        if not isinstance(postp_target, tuple):\n            postp_target = (postp_target,)\n\n        with ThreadPoolExecutor() as executor:\n            piecewise_metrics = list(\n                executor.map(\n                    compute_item,\n                    postp_target,\n                    batch[f\"truth_orig_{target}\"],\n                )\n            )\n\n        # average each metric across the pieces in the batch\n        batch_metric = {\n            key + f\"_{target}\": np.mean([x[key] for x in piecewise_metrics])\n            for key in piecewise_metrics[0].keys()\n        }\n\n        return batch_metric\n
\n    def log_losses(self, losses, batch_size, step=\"train\"):\n        # log for separate targets\n        for target in \"beat\", \"downbeat\":\n            self.log(\n                f\"{step}_loss_{target}\",\n                losses[target].item(),\n                prog_bar=False,\n                on_step=False,\n                on_epoch=True,\n                batch_size=batch_size,\n                sync_dist=True,\n            )\n        # log total loss\n        self.log(\n            f\"{step}_loss\",\n            losses[\"total\"].item(),\n            prog_bar=True,\n            on_step=False,\n            on_epoch=True,\n            batch_size=batch_size,\n            sync_dist=True,\n        )\n\n    def log_metrics(self, metrics, batch_size, step=\"val\"):\n        for key, value in metrics.items():\n            self.log(\n                f\"{step}_{key}\",\n                value,\n                prog_bar=key.startswith(\"F-measure\"),\n                on_step=False,\n                on_epoch=True,\n                batch_size=batch_size,\n                sync_dist=True,\n            )\n\n    def training_step(self, batch, batch_idx):\n        # run the model\n        model_prediction = self.model(batch[\"spect\"])\n        # compute loss\n        losses = self._compute_loss(batch, model_prediction)\n        self.log_losses(losses, len(batch[\"spect\"]), \"train\")\n        return losses[\"total\"]\n\n    def validation_step(self, batch, batch_idx):\n        # run the model\n        model_prediction = self.model(batch[\"spect\"])\n        # compute loss\n        losses = self._compute_loss(batch, model_prediction)\n        # postprocess the predictions\n        postp_beat, postp_downbeat = self.postprocessor(\n            model_prediction[\"beat\"],\n            model_prediction[\"downbeat\"],\n            batch[\"padding_mask\"],\n        )\n        # compute the metrics\n
        metrics = self._compute_metrics(batch, postp_beat, postp_downbeat, step=\"val\")\n        # log\n        self.log_losses(losses, len(batch[\"spect\"]), \"val\")\n        self.log_metrics(metrics, batch[\"spect\"].shape[0], \"val\")\n\n    def test_step(self, batch, batch_idx):\n        metrics, model_prediction, _, _ = self.predict_step(batch, batch_idx)\n        losses = self._compute_loss(batch, model_prediction)\n        # log\n        self.log_losses(losses, len(batch[\"spect\"]), \"test\")\n        self.log_metrics(metrics, batch[\"spect\"].shape[0], \"test\")\n\n    def predict_step(\n        self,\n        batch: Any,\n        batch_idx: int,\n        dataloader_idx: int = 0,\n        chunk_size: int = 1500,\n        overlap_mode: str = \"keep_first\",\n    ) -> Any:\n        \"\"\"\n        Compute predictions and metrics for a batch (a dictionary with a \"spect\" key).\n        The spectrogram is split into chunks of `chunk_size` frames, which should\n        correspond to the sequence length the model was trained with.\n        Overlaps between chunks can be handled in two ways:\n        by keeping the predictions of the excerpt coming first (overlap_mode='keep_first'), or\n        by keeping the predictions of the excerpt coming last (overlap_mode='keep_last').\n        Overlaps occur because the last excerpt is shifted backwards\n        when it would otherwise extend past the end of the piece.\n        \"\"\"\n        if batch[\"spect\"].shape[0] != 1:\n            raise ValueError(\n                \"When predicting full pieces, only `batch_size=1` is supported\"\n            )\n        if torch.any(~batch[\"padding_mask\"]):\n            raise ValueError(\n                \"When predicting full pieces, the Dataset must not pad inputs\"\n            )\n        # compute border size according to the loss type\n        if hasattr(\n            self.beat_loss, \"tolerance\"\n        ):  # discard the edges that are affected by the max-pooling 
in the loss\n            border_size = 2 * self.beat_loss.tolerance\n        else:\n            border_size = 0\n        model_prediction = split_predict_aggregate(\n            batch[\"spect\"][0], chunk_size, border_size, overlap_mode, self.model\n        )\n        # add the batch dimension back in the prediction for consistency\n        model_prediction = {\n            key: value.unsqueeze(0) for key, value in model_prediction.items()\n        }\n        # postprocess the predictions\n        postp_beat, postp_downbeat = self.postprocessor(\n            model_prediction[\"beat\"], model_prediction[\"downbeat\"], None\n        )\n        # compute the metrics\n        metrics = self._compute_metrics(batch, postp_beat, postp_downbeat, step=\"test\")\n        return metrics, model_prediction, batch[\"dataset\"], batch[\"spect_path\"]\n\n    def configure_optimizers(self):\n        optimizer = torch.optim.AdamW\n        # only decay 2+-dimensional tensors, to exclude biases and norms\n        # (filtering on dimensionality idea taken from Karpathy's nanoGPT)\n        params = [\n            {\n                \"params\": (\n                    p for p in self.parameters() if p.requires_grad and p.ndim >= 2\n                ),\n                \"weight_decay\": self.weight_decay,\n            },\n            {\n                \"params\": (\n                    p for p in self.parameters() if p.requires_grad and p.ndim <= 1\n                ),\n                \"weight_decay\": 0,\n            },\n        ]\n\n        optimizer = optimizer(params, lr=self.lr)\n\n        self.lr_scheduler = CosineWarmupScheduler(\n            optimizer, self.warmup_steps, self.trainer.estimated_stepping_batches\n        )\n\n        result = dict(optimizer=optimizer)\n        result[\"lr_scheduler\"] = {\"scheduler\": self.lr_scheduler, \"interval\": \"step\"}\n        return result\n\n    def _load_from_state_dict(self, state_dict, prefix, *args, **kwargs):\n        # remove 
_orig_mod prefixes for compiled models\n        state_dict = replace_state_dict_key(state_dict, \"_orig_mod.\", \"\")\n        super()._load_from_state_dict(state_dict, prefix, *args, **kwargs)\n\n    def state_dict(self, *args, **kwargs):\n        state_dict = super().state_dict(*args, **kwargs)\n        # remove _orig_mod prefixes for compiled models\n        state_dict = replace_state_dict_key(state_dict, \"_orig_mod.\", \"\")\n        return state_dict\n\n\nclass Metrics:\n    def __init__(self, eval_trim_beats: int) -> None:\n        self.min_beat_time = eval_trim_beats\n\n    def __call__(self, truth, preds, step) -> Any:\n        truth = mir_eval.beat.trim_beats(truth, min_beat_time=self.min_beat_time)\n        preds = mir_eval.beat.trim_beats(preds, min_beat_time=self.min_beat_time)\n        if (\n            step == \"val\"\n        ):  # limit the metrics that are computed during validation to speed up training\n            fmeasure = mir_eval.beat.f_measure(truth, preds)\n            cemgil = mir_eval.beat.cemgil(truth, preds)\n            return {\"F-measure\": fmeasure, \"Cemgil\": cemgil}\n        elif step == \"test\":  # compute all metrics during testing\n            CMLc, CMLt, AMLc, AMLt = mir_eval.beat.continuity(truth, preds)\n            fmeasure = mir_eval.beat.f_measure(truth, preds)\n            cemgil = mir_eval.beat.cemgil(truth, preds)\n            return {\"F-measure\": fmeasure, \"Cemgil\": cemgil, \"CMLt\": CMLt, \"AMLt\": AMLt}\n        else:\n            raise ValueError(\"step must be either val or test\")\n\n\nclass CosineWarmupScheduler(torch.optim.lr_scheduler._LRScheduler):\n    \"\"\"\n    Cosine annealing over `max_iters` steps with `warmup` linear warmup steps.\n    Optionally re-raises the learning rate for the final `raise_last` fraction\n    of total training time to `raise_to` of the full learning rate, again with\n    a linear warmup (useful for stochastic weight averaging).\n    \"\"\"\n\n    def __init__(self, 
optimizer, warmup, max_iters, raise_last=0, raise_to=0.5):\n        self.warmup = warmup\n        self.max_num_iters = int((1 - raise_last) * max_iters)\n        self.raise_to = raise_to\n        super().__init__(optimizer)\n\n    def get_lr(self):\n        lr_factor = self.get_lr_factor(step=self.last_epoch)\n        return [base_lr * lr_factor for base_lr in self.base_lrs]\n\n    def get_lr_factor(self, step):\n        if step < self.max_num_iters:\n            progress = step / self.max_num_iters\n            lr_factor = 0.5 * (1 + np.cos(np.pi * progress))\n            if step <= self.warmup:\n                lr_factor *= step / self.warmup\n        else:\n            progress = (step - self.max_num_iters) / self.warmup\n            lr_factor = self.raise_to * min(progress, 1)\n        return lr_factor\n"
  },
  {
    "path": "beat_this/model/postprocessor.py",
    "content": "from concurrent.futures import ThreadPoolExecutor\n\nimport numpy as np\nimport torch\nimport torch.nn.functional as F\nfrom einops import rearrange\n\n\nclass Postprocessor:\n    \"\"\"Postprocessor for the beat and downbeat predictions of the model.\n    The postprocessor takes the framewise model predictions (beat and downbeat logits) and the padding mask,\n    and returns the postprocessed beats and downbeats as arrays of times in seconds.\n    The inputs can be 1D tensors (a single piece) or 2D tensors (a batch of pieces).\n    For a batch, each output is a tuple with one array of times per piece.\n    Two types of postprocessing are implemented:\n        - minimal: a simple postprocessing that picks local maxima of the framewise predictions\n        above 0.5 probability, and merges adjacent peaks.\n        - dbn: a postprocessing based on the Dynamic Bayesian Network proposed by Böck et al.\n    Args:\n        type (str): the type of postprocessing to apply. Either \"minimal\" or \"dbn\". Default is \"minimal\".\n        fps (int): the frames per second of the model framewise predictions. Default is 50.\n    \"\"\"\n\n    def __init__(self, type: str = \"minimal\", fps: int = 50):\n        assert type in [\"minimal\", \"dbn\"]\n        self.type = type\n        self.fps = fps\n        if type == \"dbn\":\n            from madmom.features.downbeats import DBNDownBeatTrackingProcessor\n\n            self.dbn = DBNDownBeatTrackingProcessor(\n                beats_per_bar=[3, 4],\n                min_bpm=55.0,\n                max_bpm=215.0,\n                fps=self.fps,\n                transition_lambda=100,\n            )\n\n    def __call__(\n        self,\n        beat: torch.Tensor,\n        downbeat: torch.Tensor,\n        padding_mask: torch.Tensor | None = None,\n    ) -> tuple[np.ndarray, np.ndarray]:\n        \"\"\"\n        Apply postprocessing to the input beat and downbeat tensors. 
Works with batched and unbatched inputs.\n        The output is an array of times in seconds, or a tuple of such arrays (one per piece) if the input is batched.\n\n        Args:\n            beat (torch.Tensor): Framewise beat logits of shape (T,) or (B, T).\n            downbeat (torch.Tensor): Framewise downbeat logits of the same shape as `beat`.\n            padding_mask (torch.Tensor, optional): Boolean mask of the same shape as `beat`,\n                True for valid (non-padded) frames. Defaults to None (all frames valid).\n\n        Returns:\n            np.ndarray or tuple of np.ndarray: The postprocessed beat times in seconds.\n            np.ndarray or tuple of np.ndarray: The postprocessed downbeat times in seconds.\n        \"\"\"\n        batched = beat.ndim != 1\n        if padding_mask is None:\n            padding_mask = torch.ones_like(beat, dtype=torch.bool)\n\n        # if beat and downbeat are 1D tensors, add a batch dimension\n        if not batched:\n            beat = beat.unsqueeze(0)\n            downbeat = downbeat.unsqueeze(0)\n            padding_mask = padding_mask.unsqueeze(0)\n\n        if self.type == \"minimal\":\n            postp_beat, postp_downbeat = self.postp_minimal(\n                beat, downbeat, padding_mask\n            )\n        elif self.type == \"dbn\":\n            postp_beat, postp_downbeat = self.postp_dbn(beat, downbeat, padding_mask)\n        else:\n            raise ValueError(\"Invalid postprocessing type\")\n\n        # remove the batch dimension if it was added\n        if not batched:\n            postp_beat = postp_beat[0]\n            postp_downbeat = postp_downbeat[0]\n\n        # return the beat and downbeat times in seconds\n        return postp_beat, postp_downbeat\n\n    def postp_minimal(self, beat, downbeat, padding_mask):\n        # concatenate beat and downbeat in the same tensor of shape (B, T, 2)\n        packed_pred = rearrange(\n            [beat, downbeat], \"c b t -> b t c\", b=beat.shape[0], t=beat.shape[1], c=2\n        )\n        # set padded elements to -1000 (= probability zero even in float64) so they don't influence the maxpool\n        pred_logits = 
packed_pred.masked_fill(~padding_mask.unsqueeze(-1), -1000)\n        # reshape to (2*B, T) to apply max pooling\n        pred_logits = rearrange(pred_logits, \"b t c -> (c b) t\")\n        # pick maxima within +/- 70ms\n        pred_peaks = pred_logits.masked_fill(\n            pred_logits != F.max_pool1d(pred_logits, 7, 1, 3), -1000\n        )\n        # keep maxima with over 0.5 probability (logit > 0)\n        pred_peaks = pred_peaks > 0\n        #  rearrange back to two tensors of shape (B, T)\n        beat_peaks, downbeat_peaks = rearrange(\n            pred_peaks, \"(c b) t -> c b t\", b=beat.shape[0], t=beat.shape[1], c=2\n        )\n        # run the piecewise operations\n        with ThreadPoolExecutor() as executor:\n            postp_beat, postp_downbeat = zip(\n                *executor.map(\n                    self._postp_minimal_item, beat_peaks, downbeat_peaks, padding_mask\n                )\n            )\n        return postp_beat, postp_downbeat\n\n    def _postp_minimal_item(self, padded_beat_peaks, padded_downbeat_peaks, mask):\n        \"\"\"Function to compute the operations that must be computed piece by piece, and cannot be done in batch.\"\"\"\n        # unpad the predictions by truncating the padding positions\n        beat_peaks = padded_beat_peaks[mask]\n        downbeat_peaks = padded_downbeat_peaks[mask]\n        # pass from a boolean array to a list of times in frames.\n        beat_frame = torch.nonzero(beat_peaks).cpu().numpy()[:, 0]\n        downbeat_frame = torch.nonzero(downbeat_peaks).cpu().numpy()[:, 0]\n        # remove adjacent peaks\n        beat_frame = deduplicate_peaks(beat_frame, width=1)\n        downbeat_frame = deduplicate_peaks(downbeat_frame, width=1)\n        # convert from frame to seconds\n        beat_time = beat_frame / self.fps\n        downbeat_time = downbeat_frame / self.fps\n        # move the downbeat to the nearest beat\n        if (\n            len(beat_time) > 0\n        ):  # skip if there are no 
beats, like in the first training steps\n            for i, d_time in enumerate(downbeat_time):\n                beat_idx = np.argmin(np.abs(beat_time - d_time))\n                downbeat_time[i] = beat_time[beat_idx]\n        # remove duplicate downbeat times (if some db were moved to the same position)\n        downbeat_time = np.unique(downbeat_time)\n        return beat_time, downbeat_time\n\n    def postp_dbn(self, beat, downbeat, padding_mask):\n        beat_prob = beat.double().sigmoid()\n        downbeat_prob = downbeat.double().sigmoid()\n        # limit lower and upper bound, since 0 and 1 create problems in the DBN\n        epsilon = 1e-5\n        beat_prob = beat_prob * (1 - epsilon) + epsilon / 2\n        downbeat_prob = downbeat_prob * (1 - epsilon) + epsilon / 2\n        with ThreadPoolExecutor() as executor:\n            postp_beat, postp_downbeat = zip(\n                *executor.map(\n                    self._postp_dbn_item, beat_prob, downbeat_prob, padding_mask\n                )\n            )\n        return postp_beat, postp_downbeat\n\n    def _postp_dbn_item(self, padded_beat_prob, padded_downbeat_prob, mask):\n        \"\"\"Function to compute the operations that must be computed piece by piece, and cannot be done in batch.\"\"\"\n        # unpad the predictions by truncating the padding positions\n        beat_prob = padded_beat_prob[mask]\n        downbeat_prob = padded_downbeat_prob[mask]\n        # build an artificial multiclass prediction, as suggested by Böck et al.\n        # again we limit the lower bound to avoid problems with the DBN\n        epsilon = 1e-5\n        combined_act = np.vstack(\n            (\n                np.maximum(\n                    beat_prob.cpu().numpy() - downbeat_prob.cpu().numpy(), epsilon / 2\n                ),\n                downbeat_prob.cpu().numpy(),\n            )\n        ).T\n        # run the DBN\n        dbn_out = self.dbn(combined_act)\n        postp_beat = dbn_out[:, 0]\n        
postp_downbeat = dbn_out[dbn_out[:, 1] == 1][:, 0]\n        return postp_beat, postp_downbeat\n\n\ndef deduplicate_peaks(peaks, width=1) -> np.ndarray:\n    \"\"\"\n    Replaces groups of adjacent peak frame indices that are each not more\n    than `width` frames apart by the average of the frame indices.\n    \"\"\"\n    result = []\n    peaks = map(int, peaks)  # ensure we get ordinary Python int objects\n    try:\n        p = next(peaks)\n    except StopIteration:\n        return np.array(result)\n    c = 1\n    for p2 in peaks:\n        if p2 - p <= width:\n            c += 1\n            p += (p2 - p) / c  # update mean\n        else:\n            result.append(p)\n            p = p2\n            c = 1\n    result.append(p)\n    return np.array(result)\n"
  },
  {
    "path": "beat_this/model/roformer.py",
    "content": "\"\"\"\nTransformer with rotary position embedding, adapted from Phil Wang's repository\nat https://github.com/lucidrains/BS-RoFormer (under MIT License).\n\"\"\"\n\nimport torch\nimport torch.nn.functional as F\nfrom einops import rearrange\nfrom torch import nn\nfrom torch.nn import Module, ModuleList\n\n# helper functions\n\n\ndef exists(val):\n    return val is not None\n\n\n# norm\n\n\nclass RMSNorm(Module):\n    def __init__(self, size, dim=-1):\n        super().__init__()\n        self.scale = size**0.5\n        if dim >= 0:\n            raise ValueError(f\"dim must be negative, got {dim}\")\n        self.gamma = nn.Parameter(torch.ones((size,) + (1,) * (abs(dim) - 1)))\n        self.dim = dim\n\n    def forward(self, x):\n        return F.normalize(x, dim=self.dim) * self.scale * self.gamma\n\n\n# feedforward\n\n\nclass FeedForward(Module):\n    def __init__(\n        self,\n        dim,\n        mult=4,\n        dropout=0.0,\n        dim_out=None,\n    ):\n        super().__init__()\n        if dim_out is None:\n            dim_out = dim\n        dim_inner = int(dim * mult)\n        self.activation = nn.GELU()\n        self.net = nn.Sequential(\n            RMSNorm(dim),\n            nn.Linear(dim, dim_inner),\n            self.activation,\n            nn.Dropout(dropout),\n            nn.Linear(dim_inner, dim_out),\n            nn.Dropout(dropout),\n        )\n\n    def forward(self, x):\n        return self.net(x)\n\n\n# attention\n\n\nclass Attend(nn.Module):\n    def __init__(self, dropout=0.0, scale=None):\n        super().__init__()\n        self.dropout = dropout\n        self.scale = scale\n\n    def forward(self, q, k, v):\n        if exists(self.scale):\n            default_scale = q.shape[-1] ** -0.5\n            q = q * (self.scale / default_scale)\n\n        return F.scaled_dot_product_attention(\n            q, k, v, dropout_p=self.dropout if self.training else 0.0\n        )\n\n\nclass Attention(Module):\n    def __init__(\n  
      self,\n        dim,\n        heads=8,\n        dim_head=64,\n        dropout=0.0,\n        rotary_embed=None,\n        gating=True,\n    ):\n        super().__init__()\n        self.heads = heads\n        self.scale = dim_head**-0.5\n        dim_inner = heads * dim_head\n\n        self.rotary_embed = rotary_embed\n\n        self.attend = Attend(dropout=dropout)\n\n        self.norm = RMSNorm(dim)\n        self.to_qkv = nn.Linear(dim, dim_inner * 3, bias=False)\n\n        if gating:\n            self.to_gates = nn.Linear(dim, heads)\n        else:\n            self.to_gates = None\n\n        self.to_out = nn.Sequential(\n            nn.Linear(dim_inner, dim, bias=False), nn.Dropout(dropout)\n        )\n\n    def forward(self, x):\n        x = self.norm(x)\n\n        q, k, v = rearrange(\n            self.to_qkv(x), \"b n (qkv h d) -> qkv b h n d\", qkv=3, h=self.heads\n        )\n\n        if exists(self.rotary_embed):\n            q = self.rotary_embed.rotate_queries_or_keys(q)\n            k = self.rotary_embed.rotate_queries_or_keys(k)\n\n        out = self.attend(q, k, v)\n\n        if exists(self.to_gates):\n            gates = self.to_gates(x)\n            out = out * rearrange(gates, \"b n h -> b h n 1\").sigmoid()\n\n        out = rearrange(out, \"b h n d -> b n (h d)\")\n        return self.to_out(out)\n\n\n# Roformer\n\n\nclass Transformer(Module):\n    def __init__(\n        self,\n        *,\n        dim,\n        depth,\n        dim_head=32,\n        heads=16,\n        attn_dropout=0.1,\n        ff_dropout=0.1,\n        ff_mult=4,\n        norm_output=True,\n        rotary_embed=None,\n        gating=True,\n    ):\n        super().__init__()\n        self.layers = ModuleList([])\n\n        for _ in range(depth):\n            ff = FeedForward(dim=dim, mult=ff_mult, dropout=ff_dropout)\n            self.layers.append(\n                ModuleList(\n                    [\n                        Attention(\n                            dim=dim,\n       
                     dim_head=dim_head,\n                            heads=heads,\n                            dropout=attn_dropout,\n                            rotary_embed=rotary_embed,\n                            gating=gating,\n                        ),\n                        ff,\n                    ]\n                )\n            )\n\n        self.norm = RMSNorm(dim) if norm_output else nn.Identity()\n\n    def forward(self, x):\n        for attn, ff in self.layers:\n            x = attn(x) + x\n            x = ff(x) + x\n        x = self.norm(x)\n        return x\n"
  },
  {
    "path": "beat_this/preprocessing.py",
    "content": "import numpy as np\nimport torch\nimport torchaudio\n\n\ndef load_audio(path, dtype=\"float64\"):\n    try:\n        waveform, samplerate = torchaudio.load(path, channels_first=False)\n        waveform = np.asanyarray(waveform.squeeze().numpy(), dtype=dtype)\n        return waveform, samplerate\n    except Exception:\n        # in case torchaudio fails, try soundfile\n        try:\n            import soundfile as sf\n\n            return sf.read(path, dtype=dtype)\n        except Exception:\n            # some files are not readable by soundfile, try madmom\n            try:\n                import madmom\n\n                return madmom.io.load_audio_file(str(path), dtype=dtype)\n            except Exception:\n                raise RuntimeError(f'Could not load audio from \"{path}\".')\n\n\nclass LogMelSpect(torch.nn.Module):\n    def __init__(\n        self,\n        sample_rate=22050,\n        n_fft=1024,\n        hop_length=441,\n        f_min=30,\n        f_max=11000,\n        n_mels=128,\n        mel_scale=\"slaney\",\n        normalized=\"frame_length\",\n        power=1,\n        log_multiplier=1000,\n        device=\"cpu\",\n    ):\n        super().__init__()\n        self.spect_class = torchaudio.transforms.MelSpectrogram(\n            sample_rate=sample_rate,\n            n_fft=n_fft,\n            hop_length=hop_length,\n            f_min=f_min,\n            f_max=f_max,\n            n_mels=n_mels,\n            mel_scale=mel_scale,\n            normalized=normalized,\n            power=power,\n        ).to(device)\n        self.log_multiplier = log_multiplier\n\n    def forward(self, x):\n        \"\"\"Input is a waveform as a one-dimensional tensor of shape (T,),\n        output is a 2D log-mel spectrogram of shape (frames, n_mels).\"\"\"\n        return torch.log1p(self.log_multiplier * self.spect_class(x).T)\n"
  },
  {
    "path": "beat_this/utils.py",
    "content": "from itertools import chain\nfrom pathlib import Path\n\nimport numpy as np\n\n\ndef index_to_framewise(index, length):\n    \"\"\"Convert an index to a framewise sequence\"\"\"\n    sequence = np.zeros(length, dtype=bool)\n    sequence[index] = True\n    return sequence\n\n\ndef filename_to_augmentation(filename):\n    \"\"\"Convert a filename to an augmentation factor.\"\"\"\n    parts = Path(filename).stem.split(\"_\")\n    augmentations = {}\n    for part in parts[1:]:\n        if part.startswith(\"ps\"):\n            augmentations[\"shift\"] = int(part[2:])\n        elif part.startswith(\"ts\"):\n            augmentations[\"stretch\"] = int(part[2:])\n    return augmentations\n\n\ndef infer_beat_numbers(beats: np.ndarray, downbeats: np.ndarray) -> np.ndarray:\n    \"\"\"\n    From beat and downbeat times, infer a number for each beat such that each downbeat\n    is associated with a 1 and beats in between are counted upwards.\n    The function requires that all downbeats are also listed as beats.\n\n    Args:\n        beats (numpy.ndarray): Array of beat positions in seconds (including downbeats).\n        downbeats (numpy.ndarray): Array of downbeat positions in seconds.\n\n    Returns:\n        numbers (numpy.ndarray): Array of integer beat numbers.\n    \"\"\"\n    # check if all downbeats are beats\n    if not np.all(np.isin(downbeats, beats)):\n        raise ValueError(\"Not all downbeats are beats.\")\n\n    # handle pickup measure, by considering the beat count of the first full measure\n    if len(downbeats) >= 2:\n        # find the number of beats between the first two downbeats\n        first_downbeat, second_downbeat = np.searchsorted(beats, downbeats[:2])\n        beats_in_first_measure = second_downbeat - first_downbeat\n        # find the number of beats before the first downbeat\n        pickup_beats = first_downbeat\n        # derive where to start counting\n        if pickup_beats < beats_in_first_measure:\n            
start_counter = beats_in_first_measure - pickup_beats\n        else:\n            print(\n                \"WARNING: The pickup measure has at least as many beats as the first full measure. The beat count will start from 2 without trying to estimate the length of the pickup measure.\"\n            )\n            start_counter = 1\n    else:\n        print(\n            \"WARNING: There are fewer than two downbeats in the predictions. Something may be wrong. The beat count will start from 2 without trying to estimate the length of the pickup measure.\"\n        )\n        start_counter = 1\n\n    # assemble the beat numbers\n    numbers = []\n    counter = start_counter\n    downbeats = chain(downbeats, [-1])\n    next_downbeat = next(downbeats)\n    for beat in beats:\n        if beat == next_downbeat:\n            counter = 1\n            next_downbeat = next(downbeats)\n        else:\n            counter += 1\n        numbers.append(counter)\n    return np.asarray(numbers)\n\n\ndef save_beat_tsv(beats: np.ndarray, downbeats: np.ndarray, outpath: str) -> None:\n    \"\"\"\n    Save beat information to a tab-separated file in the standard .beats format:\n    each line has a time in seconds, a tab, and a beat number (1 = downbeat).\n    The function requires that all downbeats are also listed as beats.\n\n    Args:\n        beats (numpy.ndarray): Array of beat positions in seconds (including downbeats).\n        downbeats (numpy.ndarray): Array of downbeat positions in seconds.\n        outpath (str): Path to the output TSV file.\n\n    Returns:\n        None\n    \"\"\"\n    # infer beat numbers\n    numbers = infer_beat_numbers(beats, downbeats)\n\n    # write the beat file\n    Path(outpath).parent.mkdir(parents=True, exist_ok=True)\n    try:\n        with open(outpath, \"w\") as f:\n            f.writelines(f\"{beat}\\t{number}\\n\" for beat, number in zip(beats, numbers))\n    except KeyboardInterrupt:\n        # outpath may be a str, so wrap it in Path before deleting the half-written file\n        
Path(outpath).unlink(missing_ok=True)\n        raise\n\n\ndef 
replace_state_dict_key(state_dict: dict, old: str, new: str):\n    \"\"\"Replaces `old` in all keys of `state_dict` with `new`.\"\"\"\n    keys = list(state_dict.keys())  # take snapshot of the keys\n    for key in keys:\n        if old in key:\n            state_dict[key.replace(old, new)] = state_dict.pop(key)\n    return state_dict\n"
  },
  {
    "path": "beat_this_example.ipynb",
    "content": "{\n  \"nbformat\": 4,\n  \"nbformat_minor\": 0,\n  \"metadata\": {\n    \"colab\": {\n      \"provenance\": [],\n      \"authorship_tag\": \"ABX9TyOW4OkTmphTrvw2IQLr+kxP\",\n      \"include_colab_link\": true\n    },\n    \"kernelspec\": {\n      \"name\": \"python3\",\n      \"display_name\": \"Python 3\"\n    },\n    \"language_info\": {\n      \"name\": \"python\"\n    }\n  },\n  \"cells\": [\n    {\n      \"cell_type\": \"markdown\",\n      \"metadata\": {\n        \"id\": \"view-in-github\",\n        \"colab_type\": \"text\"\n      },\n      \"source\": [\n        \"<a href=\\\"https://colab.research.google.com/github/CPJKU/beat_this/blob/main/beat_this_example.ipynb\\\" target=\\\"_parent\\\"><img src=\\\"https://colab.research.google.com/assets/colab-badge.svg\\\" alt=\\\"Open In Colab\\\"/></a>\"\n      ]\n    },\n    {\n      \"cell_type\": \"markdown\",\n      \"source\": [\n        \"# Beat This! inference example\\n\",\n        \"\\n\",\n        \"We first need to install and load the package.\"\n      ],\n      \"metadata\": {\n        \"id\": \"87X_GXfoGwmj\"\n      }\n    },\n    {\n      \"cell_type\": \"code\",\n      \"execution_count\": null,\n      \"metadata\": {\n        \"id\": \"sxhsMCKdLOLO\",\n        \"collapsed\": true\n      },\n      \"outputs\": [],\n      \"source\": [\n        \"# install the beat_this package\\n\",\n        \"!pip install https://github.com/CPJKU/beat_this/archive/main.zip\\n\",\n        \"# on Google Colab, this one is faster:\\n\",\n        \"#!pip install --no-deps rotary-embedding-torch https://github.com/CPJKU/beat_this/archive/main.zip\\n\",\n        \"\\n\",\n        \"# load the Python class for beat tracking\\n\",\n        \"from beat_this.inference import File2Beats\\n\",\n        \"from beat_this.inference import File2File\"\n      ]\n    },\n    {\n      \"cell_type\": \"markdown\",\n      \"source\": [\n        \"## Run on demo file\\n\",\n        \"\\n\",\n        \"Now that all the 
dependencies have been installed and imported, let's run our system.\\n\",\n        \"\\n\",\n        \"In the next cell we:\\n\",\n        \"- define the audio file we want to use as input. For now we use the example provided in the beat_this repo, but this can be changed (see the instructions below);\\n\",\n        \"- create a File2Beats instance, which produces a list of beats and downbeats given an audio file;\\n\",\n        \"- apply it to the audio file;\\n\",\n        \"- print the positions in seconds of the first 20 beats and the first 20 downbeats.\\n\"\n      ],\n      \"metadata\": {\n        \"id\": \"_0oYbH6P6Ji7\"\n      }\n    },\n    {\n      \"cell_type\": \"code\",\n      \"source\": [\n        \"!wget -c \\\"https://github.com/CPJKU/beat_this/raw/main/tests/It%20Don't%20Mean%20A%20Thing%20-%20Kings%20of%20Swing.mp3\\\"\\n\",\n        \"audio_path = \\\"/content/It Don't Mean A Thing - Kings of Swing.mp3\\\"\\n\",\n        \"\\n\",\n        \"file2beats = File2Beats(checkpoint_path=\\\"final0\\\", dbn=False)\\n\",\n        \"beats, downbeats = file2beats(audio_path)\\n\",\n        \"\\n\",\n        \"print(\\\"First 20 beats\\\", beats[:20])\\n\",\n        \"print(\\\"First 20 downbeats\\\", downbeats[:20])\"\n      ],\n      \"metadata\": {\n        \"id\": \"DHT6v-a-TbZx\"\n      },\n      \"execution_count\": null,\n      \"outputs\": []\n    },\n    {\n      \"cell_type\": \"markdown\",\n      \"source\": [\n        \"We can sonify the beats and downbeats as clicks on top of the audio file.\"\n      ],\n      \"metadata\": {\n        \"id\": \"lRjJFiexDGdn\"\n      }\n    },\n    {\n      \"cell_type\": \"code\",\n      \"source\": [\n        \"import IPython.display as ipd\\n\",\n        \"import librosa\\n\",\n        \"import numpy as np\\n\",\n        \"import soundfile as sf\\n\",\n        \"\\n\",\n        \"audio, sr = sf.read(audio_path)\\n\",\n        \"# make it mono if stereo\\n\",\n        \"if len(audio.shape) > 1:\\n\",\n        \"  
audio = np.mean(audio, axis=1)\\n\",\n        \"\\n\",\n        \"# sonify the beats and downbeats\\n\",\n        \"# remove the beats that are also downbeats for a nicer sonification\\n\",\n        \"beats = [b for b in beats if b not in downbeats]\\n\",\n        \"audio_beat = librosa.clicks(times=beats, sr=sr, click_freq=1000, length=len(audio))\\n\",\n        \"audio_downbeat = librosa.clicks(times=downbeats, sr=sr, click_freq=1500, length=len(audio))\\n\",\n        \"\\n\",\n        \"ipd.display(ipd.Audio(audio + audio_beat + audio_downbeat, rate=sr))\"\n      ],\n      \"metadata\": {\n        \"id\": \"otG0NS_uCXSo\"\n      },\n      \"execution_count\": null,\n      \"outputs\": []\n    },\n    {\n      \"cell_type\": \"markdown\",\n      \"source\": [\n        \"## Run on your own file\\n\",\n        \"\\n\",\n        \"To run the model on your own audio files, follow these steps:\\n\",\n        \"1. Click on the folder icon in the left vertical menu.\\n\",\n        \"2. Click on the \\\"Upload to session storage\\\" icon with the upward-pointing arrow.\\n\",\n        \"\\n\",\n        \"    This will add an audio file to the current Colab runtime (this may take some time, and you may need to refresh the file manager using the dedicated button to see the new file). You can copy the audio path by clicking on the three dots next to the file, then \\\"copy path\\\".\\n\",\n        \"\\n\",\n        \"    For example, if you upload a file called `my_song.mp3`, the path will be `/content/my_song.mp3`.\\n\",\n        \"\\n\",\n        \"3. 
Change the `audio_path` in the cell above to the path of your uploaded audio file.\"\n      ],\n      \"metadata\": {\n        \"id\": \"hn83Sn1pWmt5\"\n      }\n    },\n    {\n      \"cell_type\": \"markdown\",\n      \"source\": [\n        \"You can also save the beats and downbeats as a tab-separated `.beats` file, which you can download and import into Sonic Visualiser.\\n\",\n        \"\\n\",\n        \"To do this, use the File2File class as shown below:\"\n      ],\n      \"metadata\": {\n        \"id\": \"kP2gyplIEcWT\"\n      }\n    },\n    {\n      \"cell_type\": \"code\",\n      \"source\": [\n        \"file2file = File2File(checkpoint_path=\\\"final0\\\", dbn=False)\\n\",\n        \"file2file(audio_path, output_path=\\\"output.beats\\\")\"\n      ],\n      \"metadata\": {\n        \"id\": \"kTQK-d4JEbL7\"\n      },\n      \"execution_count\": null,\n      \"outputs\": []\n    },\n    {\n      \"cell_type\": \"markdown\",\n      \"source\": [\n        \"As you can see, the system runs in reasonable time even on a CPU.\\n\",\n        \"\\n\",\n        \"For even faster inference, you can start a GPU session in Colab!\"\n      ],\n      \"metadata\": {\n        \"id\": \"1Y1d-DvXFtVz\"\n      }\n    },\n    {\n      \"cell_type\": \"markdown\",\n      \"source\": [\n        \"## Batch processing multiple files\\n\",\n        \"\\n\",\n        \"To process several of your own audio files at once, upload them as described above, then run the `beat_this` command-line tool on the `/content/` folder:\"\n      ],\n      \"metadata\": {\n        \"id\": \"vpoM0RvQdAMF\"\n      }\n    },\n    {\n      \"cell_type\": \"code\",\n      \"source\": [\n        \"!beat_this --model final0 /content/\"\n      ],\n      \"metadata\": {\n        \"id\": \"qNOLbBplc_Nq\"\n      },\n      \"execution_count\": null,\n      \"outputs\": []\n    },\n    {\n      \"cell_type\": \"markdown\",\n      \"source\": [\n        \"It will produce a `.beats` file for every audio file, which you can then download.\"\n      
],\n      \"metadata\": {\n        \"id\": \"_xNY_9DEdSEt\"\n      }\n    }\n  ]\n}"
  },
  {
    "path": "hubconf.py",
    "content": "dependencies = [\n    \"torch\",\n    \"torchaudio\",\n    \"numpy\",\n    \"rotary_embedding_torch\",\n    \"einops\",\n    \"soxr\",\n]\n\nfrom beat_this.inference import (\n    load_model as beat_this,\n    BeatThis,\n    Spect2Frames,\n    Audio2Frames,\n    Audio2Beats,\n    File2Beats,\n    File2File,\n)\n"
  },
  {
    "path": "launch_scripts/clean_checkpoints.py",
    "content": "import argparse\nfrom pathlib import Path\n\nimport torch\n\n\ndef main(args):\n    # check if output path exists\n    if Path(args.output_path).exists():\n        print(f\"Output path {args.output_path} already exists. Exiting.\")\n        return\n\n    # load the Lightning checkpoint; full unpickling (weights_only=False) is needed because\n    # the checkpoint also stores hyperparameters, so only use this on trusted checkpoints\n    checkpoint = torch.load(args.input_path, map_location=\"cpu\", weights_only=False)\n\n    # keep only the state dict, the hyperparameters and the Lightning version to save space\n    checkpoint = {\n        k: v\n        for k, v in checkpoint.items()\n        if k\n        in [\n            \"state_dict\",\n            \"datamodule_hyper_parameters\",\n            \"hyper_parameters\",\n            \"pytorch-lightning_version\",\n        ]\n    }\n\n    # remove the \"data_dir\" key from \"datamodule_hyper_parameters\" because it is a\n    # PosixPath object, which cannot be unpickled on Windows.\n    if \"data_dir\" in checkpoint[\"datamodule_hyper_parameters\"]:\n        del checkpoint[\"datamodule_hyper_parameters\"][\"data_dir\"]\n\n    # save the cleaned checkpoint\n    torch.save(checkpoint, args.output_path)\n\n\nif __name__ == \"__main__\":\n    parser = argparse.ArgumentParser()\n    parser.add_argument(\"--input-path\", type=str, required=True)\n    parser.add_argument(\"--output-path\", type=str, required=True)\n\n    args = parser.parse_args()\n\n    main(args)\n"
  },
  {
    "path": "launch_scripts/compute_paper_metrics.py",
    "content": "#!/usr/bin/env python3\nimport argparse\nfrom pathlib import Path\n\nimport numpy as np\nfrom pytorch_lightning import Trainer, seed_everything\n\nfrom beat_this.dataset import BeatDataModule\nfrom beat_this.inference import load_checkpoint\nfrom beat_this.model.pl_module import PLBeatThis\nfrom beat_this.utils import infer_beat_numbers\n\n# for reproducibility\nseed_everything(0, workers=True)\n\n\ndef main(args):\n    if len(args.models) == 1:\n        print(\"Single model prediction for\", args.models[0])\n        # single model prediction\n        checkpoint_path = args.models[0]\n        checkpoint = load_checkpoint(checkpoint_path)\n\n        # create datamodule\n        datamodule = datamodule_setup(checkpoint, args.num_workers, args.datasplit)\n        # create model and trainer\n        model, trainer = plmodel_setup(\n            checkpoint, args.eval_trim_beats, args.dbn, args.gpu\n        )\n        # predict\n        metrics, dataset, preds, piece = compute_predictions(\n            model,\n            trainer,\n            datamodule.predict_dataloader(),\n            return_preds=args.dump_predictions,\n        )\n\n        # compute averaged metrics\n        averaged_metrics = {k: np.mean(v) for k, v in metrics.items()}\n        # compute metrics averaged by dataset\n        dataset_metrics = {\n            k: {d: np.mean(v[dataset == d]) for d in np.unique(dataset)}\n            for k, v in metrics.items()\n        }\n        # print overall and per-dataset metrics\n        print(\"Metrics\")\n        for k, v in averaged_metrics.items():\n            print(f\"{k}: {v}\")\n        print(\"Dataset metrics\")\n        for k, v in dataset_metrics.items():\n            print(k)\n            for d, value in v.items():\n                print(f\"{d}: {value}\")\n            print(\"------\")\n        # dump predictions\n        if args.dump_predictions:\n            write_predictions(args.dump_predictions, preds, piece)\n    else:  # multiple models\n        if 
args.aggregation_type == \"mean-std\":\n            if args.dump_predictions:\n                print(\n                    \"Cannot dump predictions with aggregation type mean-std.\"\n                )\n                return\n            # compute result variability across model seeds on the same dataset\n            # create datamodule only once, as we assume it is the same for all models\n            checkpoint = load_checkpoint(args.models[0])\n            datamodule = datamodule_setup(checkpoint, args.num_workers, args.datasplit)\n            # evaluate each model\n            all_metrics = []\n            for checkpoint_path in args.models:\n                checkpoint = load_checkpoint(checkpoint_path)\n                model, trainer = plmodel_setup(\n                    checkpoint, args.eval_trim_beats, args.dbn, args.gpu\n                )\n\n                metrics, dataset, preds, piece = compute_predictions(\n                    model, trainer, datamodule.predict_dataloader()\n                )\n                # compute averaged metrics for one model\n                averaged_metrics = {k: np.mean(v) for k, v in metrics.items()}\n                all_metrics.append(averaged_metrics)\n            # compute mean and standard deviation of the per-model averages\n            all_metrics_mean = {\n                k: np.mean([m[k] for m in all_metrics]) for k in all_metrics[0]\n            }\n            all_metrics_std = {\n                k: np.std([m[k] for m in all_metrics]) for k in all_metrics[0]\n            }\n            all_metrics_stats = {\n                k: (all_metrics_mean[k], all_metrics_std[k]) for k in all_metrics[0]\n            }\n            # print all metrics\n            print(\"Metrics\")\n            for k, v in all_metrics_stats.items():\n                # round to 3 decimal places\n                print(f\"{k}: {round(v[0], 3)} +- {round(v[1], 3)}\")\n        elif 
args.aggregation_type == \"k-fold\":\n            # compute results in the k-fold setting, where each fold (model) is evaluated on its own data split\n            all_piece_metrics = []\n            all_piece_dataset = []\n            all_piece_preds = []\n            all_piece = []\n            # create a datamodule for each model\n            for i_model, checkpoint_path in enumerate(args.models):\n                print(f\"Model {i_model+1}/{len(args.models)}\")\n                checkpoint = load_checkpoint(checkpoint_path)\n                datamodule = datamodule_setup(\n                    checkpoint, args.num_workers, args.datasplit\n                )\n                # create model and trainer\n                model, trainer = plmodel_setup(\n                    checkpoint, args.eval_trim_beats, args.dbn, args.gpu\n                )\n                # predict\n                metrics, dataset, preds, piece = compute_predictions(\n                    model,\n                    trainer,\n                    datamodule.predict_dataloader(),\n                    return_preds=args.dump_predictions,\n                )\n                all_piece_metrics.append(metrics)\n                all_piece_dataset.append(dataset)\n                # preds is None unless --dump-predictions is given\n                if preds is not None:\n                    all_piece_preds.extend(preds)\n                all_piece.append(piece)\n            # aggregate across folds\n            all_piece_metrics = {\n                k: np.concatenate([m[k] for m in all_piece_metrics])\n                for k in all_piece_metrics[0]\n            }\n            all_piece_dataset = np.concatenate(all_piece_dataset)\n            all_piece = np.concatenate(all_piece)\n            # sanity check: every piece must appear in exactly one fold\n            assert len(all_piece) == len(\n                np.unique(all_piece)\n            ), \"There are repeated pieces in the folds\"\n            dataset_metrics = {\n                k: {\n       
             d: np.mean(v[all_piece_dataset == d])\n                    
for d in np.unique(all_piece_dataset)\n                }\n                for k, v in all_piece_metrics.items()\n            }\n            # print metrics per dataset\n            print(\"Dataset metrics\")\n            for k, v in dataset_metrics.items():\n                print(k)\n                for d, value in v.items():\n                    print(f\"{d}: {round(value, 3)}\")\n                print(\"------\")\n            # dump predictions\n            if args.dump_predictions:\n                write_predictions(args.dump_predictions, all_piece_preds, all_piece)\n        else:\n            raise ValueError(f\"Unknown aggregation type {args.aggregation_type}\")\n\n\ndef datamodule_setup(checkpoint, num_workers, datasplit):\n    # create the datamodule\n    print(\"Creating datamodule\")\n    data_dir = Path(__file__).parent.parent.relative_to(Path.cwd()) / \"data\"\n    datamodule_hparams = checkpoint[\"datamodule_hyper_parameters\"]\n    # update the hparams with the ones from the arguments\n    if num_workers is not None:\n        datamodule_hparams[\"num_workers\"] = num_workers\n    datamodule_hparams[\"predict_datasplit\"] = datasplit\n    datamodule_hparams[\"data_dir\"] = data_dir\n    datamodule = BeatDataModule(**datamodule_hparams)\n    datamodule.setup(stage=\"predict\")\n    return datamodule\n\n\ndef plmodel_setup(checkpoint, eval_trim_beats, dbn, gpu):\n    \"\"\"\n    Set up the PyTorch Lightning model and trainer for evaluation.\n\n    Args:\n        checkpoint (dict): The dict containing the checkpoint to load.\n        eval_trim_beats (float or None): The number of seconds to skip at the start of each piece during evaluation. If None, the setting is taken from the pretrained model.\n        dbn (bool or None): Whether to use the Dynamic Bayesian Network (DBN) module during evaluation. 
If None, the default behavior from the pretrained model is used.\n        gpu (int): The index of the GPU device to use for inference; a negative value selects the CPU.\n\n    Returns:\n        tuple: A tuple containing the initialized PyTorch Lightning model and trainer.\n\n    \"\"\"\n    if eval_trim_beats is not None:\n        checkpoint[\"hyper_parameters\"][\"eval_trim_beats\"] = eval_trim_beats\n    if dbn is not None:\n        checkpoint[\"hyper_parameters\"][\"use_dbn\"] = dbn\n\n    model = PLBeatThis(**checkpoint[\"hyper_parameters\"])\n    model.load_state_dict(checkpoint[\"state_dict\"])\n    # set correct device and accelerator\n    if gpu >= 0:\n        devices = [gpu]\n        accelerator = \"gpu\"\n    else:\n        devices = 1\n        accelerator = \"cpu\"\n    # create trainer\n    trainer = Trainer(\n        accelerator=accelerator,\n        devices=devices,\n        logger=None,\n        deterministic=True,\n        precision=\"16-mixed\",\n    )\n    return model, trainer\n\n\ndef compute_predictions(model, trainer, predict_dataloader, return_preds=False):\n    print(\"Computing predictions ...\")\n    out = trainer.predict(model, predict_dataloader)\n    metrics = [o[0] for o in out]\n    if return_preds:\n        preds = [model.postprocessor(o[1][\"beat\"][0], o[1][\"downbeat\"][0]) for o in out]\n    else:\n        preds = None\n    dataset = np.asarray([o[2][0] for o in out])\n    piece = np.asarray([o[3][0] for o in out])\n    # convert metrics from list of per-batch dictionaries to a single dictionary with np arrays as values\n    metrics = {k: np.asarray([m[k] for m in metrics]) for k in metrics[0]}\n    return metrics, dataset, preds, piece\n\n\ndef write_predictions(fn, preds, piece):\n    np.savez(\n        fn,\n        **{\n            name: np.vstack([beats, infer_beat_numbers(beats, downbeats)]).T\n            for name, (beats, downbeats) in zip(piece, preds)\n        },\n    )\n\n\nif __name__ == \"__main__\":\n    parser = argparse.ArgumentParser(\n        
description=\"Computes predictions for a given model and dataset, \"\n        \"prints metrics, and optionally dumps predictions to a given file.\"\n    )\n    parser.add_argument(\n        \"--models\",\n        type=str,\n        nargs=\"+\",\n        required=True,\n        help=\"Local checkpoint files to use\",\n    )\n    parser.add_argument(\n        \"--datasplit\",\n        type=str,\n        choices=(\"train\", \"val\", \"test\"),\n        default=\"val\",\n        help=\"Data split to use (default: %(default)s)\",\n    )\n    parser.add_argument(\n        \"--gpu\",\n        type=int,\n        default=0,\n        help=\"GPU index to use; a negative value selects the CPU (default: %(default)s)\",\n    )\n    parser.add_argument(\n        \"--num_workers\",\n        type=int,\n        default=8,\n        help=\"Number of data loading workers (default: %(default)s)\",\n    )\n    parser.add_argument(\n        \"--eval_trim_beats\",\n        metavar=\"SECONDS\",\n        type=float,\n        default=None,\n        help=\"Override the number of seconds to skip at the start of each piece \"\n        \"during evaluation (default: as stored in model)\",\n    )\n    parser.add_argument(\n        \"--dbn\",\n        default=None,\n        action=argparse.BooleanOptionalAction,\n        help=\"Override whether to use the madmom DBN for postprocessing (default: as stored in model)\",\n    )\n    parser.add_argument(\n        \"--aggregation-type\",\n        type=str,\n        choices=(\"mean-std\", \"k-fold\"),\n        default=\"mean-std\",\n        help=\"Type of aggregation to use for multiple models; ignored if only one model is given\",\n    )\n    parser.add_argument(\n        \"--dump-predictions\",\n        metavar=\"FILENAME\",\n        type=str,\n        default=None,\n        help=\"File to write predictions to, in .npz format (optional)\",\n    )\n\n    args = parser.parse_args()\n\n    main(args)\n"
  },
  {
    "path": "launch_scripts/preprocess_audio.py",
    "content": "#!/usr/bin/env python3\nimport argparse\nimport concurrent.futures\nimport os\nfrom pathlib import Path\nfrom zipfile import ZipFile\n\nimport numpy as np\nimport pandas as pd\nimport soxr\nimport torch\nimport torchaudio\nfrom pedalboard import Pedalboard, PitchShift, time_stretch\nfrom tqdm import tqdm\n\nfrom beat_this.dataset.augment import precomputed_augmentation_filenames\nfrom beat_this.preprocessing import LogMelSpect, load_audio\n\nos.environ[\"CUDA_VISIBLE_DEVICES\"] = \"-1\"\n\nBASEPATH = Path(__file__).parent.parent.relative_to(Path.cwd())\n\n\ndef save_audio(path, waveform, samplerate, resample_from=None):\n    if resample_from and resample_from != samplerate:\n        waveform = soxr.resample(waveform, in_rate=resample_from, out_rate=samplerate)\n    try:\n        waveform = torch.as_tensor(np.asarray(waveform, dtype=np.float64))\n        torchaudio.save(\n            path, torch.atleast_2d(waveform), samplerate, bits_per_sample=16\n        )\n    except KeyboardInterrupt:\n        path.unlink()  # avoid half-written files\n        raise\n\n\ndef save_spectrogram(path, spectrogram, dtype=np.float16):\n    try:\n        np.save(path, np.asarray(spectrogram, dtype=dtype))\n    except KeyboardInterrupt:\n        path.unlink()  # avoid half-written files\n        raise\n\n\nclass SpectCreation:\n    def __init__(self, pitch_shift, time_stretch, audio_sr, mel_args, verbose=False):\n        \"\"\"\n        Initialize the SpectCreation class. 
This assumes that the audio files have been preprocessed with all the requested augmentations and are stored in the `mono_tracks` directory, using the file naming defined by AudioPreprocessing.\n\n        Args:\n            pitch_shift (tuple or None): A tuple specifying the minimum and maximum (inclusive) pitch shift values considered from the available audio files.\n                                        If None, pitch shifting augmentation files will not be considered.\n            time_stretch (tuple or None): A tuple (maximum percentage, stride) specifying the symmetric range of time-stretch percentages to consider from the available audio files.\n                                        If None, time stretching augmentation files will not be considered.\n            audio_sr (int): The sample rate of the audio.\n            mel_args (dict): A dictionary of arguments to be passed to the LogMelSpect class.\n            verbose (bool, optional): Whether to print verbose information. Defaults to False.\n        \"\"\"\n        super(SpectCreation, self).__init__()\n        # define the directories\n        self.audio_dir = BASEPATH / \"data\" / \"audio\"\n        self.mono_tracks_dir = self.audio_dir / \"mono_tracks\"\n        self.spectrograms_dir = self.audio_dir / \"spectrograms\"\n        self.annotations_dir = BASEPATH / \"data\" / \"annotations\"\n\n        if verbose:\n            print(\"Audio dir: \", self.audio_dir.absolute())\n            print(\"Mono tracks dir: \", self.mono_tracks_dir.absolute())\n            print(\"Spectrograms dir: \", self.spectrograms_dir.absolute())\n            print(\"Annotations dir: \", self.annotations_dir.absolute())\n        self.verbose = verbose\n        # remember the audio metadata\n        self.audio_sr = audio_sr\n        # create the log-mel spectrogram transform\n        self.logspect_class = LogMelSpect(audio_sr, **mel_args)\n        # define the augmentations\n        self.augmentations = {}\n        if pitch_shift is not None:\n            
self.augmentations[\"pitch\"] = {\"min\": pitch_shift[0], \"max\": pitch_shift[1]}\n        if time_stretch is not None:\n            self.augmentations[\"tempo\"] = {\n                \"min\": -time_stretch[0],\n                \"max\": time_stretch[0],\n                \"stride\": time_stretch[1],\n            }\n        # compute the names to consider according to the augmentations\n        self.filenames = precomputed_augmentation_filenames(self.augmentations, \"wav\")\n\n    def create_spects(self):\n        print(\"Creating spectrograms ...\")\n        processed = 0\n        with concurrent.futures.ThreadPoolExecutor() as executor:\n            futures = []\n            for dataset_dir in self.mono_tracks_dir.iterdir():\n                for piece_dir in dataset_dir.iterdir():\n                    futures.append(\n                        executor.submit(\n                            self.create_spect_piece,\n                            piece_dir,\n                            Path(dataset_dir.name)\n                            / \"annotations\"\n                            / \"beats\"\n                            / f\"{piece_dir.name}.beats\",\n                            dataset_dir.name,\n                        )\n                    )\n            for future in tqdm(\n                concurrent.futures.as_completed(futures), total=len(futures)\n            ):\n                if future.result():\n                    processed += 1\n        print(f\"Created {processed} spectrograms in {self.spectrograms_dir}\")\n\n    def create_spect_piece(self, preprocessed_audio_folder, beat_path, dataset_name):\n        \"\"\"\n        Create spectrogram for a single audio piece.\n\n        This method creates a spectrogram for a single audio piece located in the `preprocessed_audio_folder`.\n        The beat annotations for the audio piece are loaded from the `beat_path` file.\n        The created spectrogram is saved in the `spectrograms_dir` directory.\n\n        
Args:\n            preprocessed_audio_folder (Path): The path to the preprocessed audio folder.\n            beat_path (Path): The path to the beat annotations file, relative to the annotations directory.\n            dataset_name (str): The name of the dataset.\n\n        Returns:\n            bool or None: True if all spectrograms of the piece exist or were created, None if the beat annotation is missing.\n        \"\"\"\n        # skip pieces without beat annotations\n        if not (self.annotations_dir / beat_path).exists():\n            print(\n                f\"beat annotation {beat_path} not found for {preprocessed_audio_folder}\"\n            )\n            return None\n        for filename in self.filenames:\n            audio_path = preprocessed_audio_folder / filename\n            spect_path = (\n                self.spectrograms_dir\n                / dataset_name\n                / preprocessed_audio_folder.name\n                / f\"{Path(filename).stem}.npy\"\n            )\n            if spect_path.exists():\n                if self.verbose:\n                    print(f\"Skipping {spect_path} because it exists\")\n            else:\n                if self.verbose:\n                    print(f\"Computing {spect_path}\")\n                waveform, sr = load_audio(audio_path)\n                assert (\n                    sr == self.audio_sr\n                ), f\"Sample rate mismatch: {sr} != {self.audio_sr}\"\n                # compute the mel spectrogram and scale the values with log(1 + 1000 * x)\n                spect = self.logspect_class(torch.tensor(waveform, dtype=torch.float32))\n                # save the spectrogram as numpy array\n                spect_path.parent.mkdir(parents=True, exist_ok=True)\n                save_spectrogram(spect_path, spect.numpy())\n        return True\n\n\nclass AudioPreprocessing(object):\n    def __init__(\n        self,\n        orig_audio_paths,\n        out_sr=22050,\n        aug_sr=44100,\n        ext=\"wav\",\n        
pitch_shift=(-5, 6),\n        time_stretch=(20, 4),\n        verbose=False,\n    ):\n        
\"\"\"\n        Class for converting audio files to mono, resampling, and applying augmentations.\n        Only use this if you want to start from new audio files; otherwise, use the spectrograms provided in the repo.\n\n        Args:\n            orig_audio_paths (Path): The path to a CSV file listing, for each dataset, the directory with its original audio files.\n            out_sr (int, optional): The output sample rate. Defaults to 22050.\n            aug_sr (int, optional): The sample rate for the augmentations. Defaults to 44100.\n            ext (str, optional): The extension of the audio files. Defaults to 'wav'.\n            pitch_shift (tuple, optional): A tuple specifying the minimum and maximum (inclusive) pitch shift values considered. Defaults to (-5, 6).\n            time_stretch (tuple, optional): A tuple (maximum percentage, stride) specifying the symmetric, inclusive range of time-stretch percentages considered. Defaults to (20, 4).\n            verbose (bool, optional): Whether to print verbose information. Defaults to False.\n        \"\"\"\n        super(AudioPreprocessing, self).__init__()\n        self.audio_dir = BASEPATH / \"data\" / \"audio\"\n        self.annotation_dir = BASEPATH / \"data\" / \"annotations\"\n        # load the audio directories from the CSV file, which has rows of the form: dataset_name, audio_dir\n        self.audio_dirs = {\n            row[0]: row[1] for row in pd.read_csv(orig_audio_paths, header=None).values\n        }\n        # check if the annotations exist, otherwise tell how to obtain them\n        if not self.annotation_dir.exists():\n            raise RuntimeError(\n                f\"{self.annotation_dir} missing; see README.md for \"\n                \"instructions on how to obtain the annotations.\"\n            )\n\n        print(f\"Annotations ready in {self.annotation_dir}\")\n\n        self.out_sr = out_sr\n        self.aug_sr = aug_sr\n        self.ext = ext\n        self.pitch_shift = pitch_shift\n        if 
time_stretch:\n            # interpret tuple as (maximum 
percentage, stride)\n            time_stretch = range(\n                -time_stretch[0],\n                time_stretch[0] + 1,\n                time_stretch[1] if len(time_stretch) > 1 else 1,\n            )\n        self.time_stretch = time_stretch\n        self.verbose = verbose\n\n    def preprocess_audio(self):\n        print(\"Preprocessing audio files ...\")\n        processed = 0\n        with concurrent.futures.ThreadPoolExecutor() as executor:\n            futures = []\n            for dataset_name, audio_dir in self.audio_dirs.items():\n                for audio_path in Path(audio_dir).iterdir():\n                    if audio_path.stem[:12] in (\"gtzan_speech\", \"gtzan_music_\"):\n                        continue\n                    futures.append(\n                        executor.submit(\n                            self.process_audio_file, dataset_name, audio_path\n                        )\n                    )\n            for future in tqdm(\n                concurrent.futures.as_completed(futures), total=len(futures)\n            ):\n                if future.result():\n                    processed += 1\n        print(\"Processed\", processed, \"audio files\")\n\n    def process_audio_file(self, dataset_name, audio_path):\n        annotation_dir = Path(self.annotation_dir, dataset_name, \"annotations\")\n        # load annotations\n        beat_path = Path(annotation_dir, \"beats\", audio_path.stem + \".beats\")\n        if not beat_path.exists():\n            print(\n                f\"beat annotation {beat_path} not found for {audio_path}\",\n            )\n            return False\n        # create a folder with the name of the track\n        folder_path = Path(self.audio_dir, \"mono_tracks\", dataset_name, audio_path.stem)\n        # derive the name of the unaugmented file\n        mono_path = folder_path / f\"track.{self.ext}\"\n        # derive the name of all augmented files\n        augmentations = {\n            \"pitch\": {\"min\": 
self.pitch_shift[0], \"max\": self.pitch_shift[1]},\n            \"tempo\": {\n                # self.time_stretch was converted to range(-max, max + 1, stride) in __init__\n                \"min\": self.time_stretch[0],\n                \"max\": self.time_stretch[-1],\n                \"stride\": self.time_stretch.step,\n            },\n        }\n        augmentations_path = precomputed_augmentation_filenames(augmentations, self.ext)\n        # stop here if all files exist\n        if mono_path.exists() and all(\n            (folder_path / aug).exists() for aug in augmentations_path\n        ):\n            if self.verbose:\n                print(f\"All files in {folder_path} exist, skipping\")\n            return True\n\n        # load audio\n        try:\n            waveform, sr = load_audio(audio_path)\n        except Exception as e:\n            print(\"Problem with loading waveform\", audio_path, e)\n            return False\n        folder_path.mkdir(parents=True, exist_ok=True)\n        if (\n            waveform.ndim == 1\n            and sr == self.out_sr\n            and audio_path.suffix == f\".{self.ext}\"\n        ):\n            # shortcut: copy the original file byte by byte to the mono path location\n            # (unlike a shell cp, this is safe for file names containing quotes)\n            mono_path.write_bytes(audio_path.read_bytes())\n        else:\n            # we need to do some conversions for the unaugmented file\n            if waveform.ndim != 1:\n                waveform = np.mean(waveform, axis=1)\n            if not mono_path.exists():\n                if sr != self.out_sr:\n                    waveform_out = soxr.resample(\n                        waveform, in_rate=sr, out_rate=self.out_sr\n                    )\n                else:\n                    waveform_out = waveform\n                # save mono file\n                save_audio(mono_path, waveform_out, self.out_sr)\n        if (self.pitch_shift or self.time_stretch) and (sr != self.aug_sr):\n            waveform = soxr.resample(waveform, in_rate=sr, out_rate=self.aug_sr)\n\n        # handle the requested augmentations\n        # pedalboard requires float32, 
convert\n        waveform = np.asarray(waveform, dtype=np.float32)\n        shifts = (\n            range(self.pitch_shift[0], self.pitch_shift[1] + 1)\n            if self.pitch_shift\n            else [0]\n        )\n        stretches = self.time_stretch if self.time_stretch else [0]\n        for shift in shifts:  # pitch augmentation\n            augment_audio_file(\n                folder_path,\n                waveform,\n                aug_type=\"shift\",\n                amount=shift,\n                aug_sr=self.aug_sr,\n                out_sr=self.out_sr,\n                ext=self.ext,\n                verbose=self.verbose,\n            )\n        for stretch in stretches:  # tempo augmentation\n            augment_audio_file(\n                folder_path,\n                waveform,\n                aug_type=\"stretch\",\n                amount=stretch,\n                aug_sr=self.aug_sr,\n                out_sr=self.out_sr,\n                ext=self.ext,\n                verbose=self.verbose,\n            )\n\n        return True\n\n\ndef augment_audio_file(\n    folder_path, waveform, aug_type, amount, aug_sr, out_sr, ext, verbose\n):\n    # figure out the file name\n    if aug_type == \"stretch\":\n        stretch = amount\n        shift = 0\n    elif aug_type == \"shift\":\n        shift = amount\n        stretch = 0\n    else:\n        raise ValueError(f\"Unknown augmentation mode {aug_type}\")\n    suffix = \"\"\n    if shift != 0:\n        suffix = suffix + f\"_ps{shift}\"\n    if stretch != 0:\n        suffix = suffix + f\"_ts{stretch}\"\n    out_path = Path(folder_path, f\"track{suffix}.{ext}\")\n    # skip if it exists\n    if out_path.exists():\n        if verbose:\n            print(f\"{out_path} exists, skipping\")\n        return\n    # otherwise compute it and write it out\n    # time stretch or pitch shift alone\n    if aug_type == \"shift\":\n        if verbose:\n            print(f\"computing {out_path} with {shift=}\")\n        # pitch 
shift alone\n        board = Pedalboard(\n            [\n                PitchShift(semitones=shift),\n            ]\n        )\n        # apply pedalboard\n        augmented = board(waveform, aug_sr)\n    else:  # type == stretch\n        if verbose:\n            print(f\"computing {out_path} with {stretch=}\")\n        augmented = time_stretch(\n            waveform,\n            aug_sr,\n            stretch_factor=1 + stretch / 100,\n            pitch_shift_in_semitones=0.0,\n        ).squeeze()\n    # save to file\n    if verbose:\n        print(f\"writing {out_path}\")\n    save_audio(out_path, augmented, out_sr, resample_from=aug_sr)\n\n\ndef create_npz(spect_dir, npz_file, augmentations, verbose):\n    \"\"\"Assemble spectrograms from a directory into an .npz file.\"\"\"\n    if npz_file.exists():\n        if verbose:\n            print(f\"{npz_file} already exists, skipping\")\n        return\n    with ZipFile(npz_file, \"w\") as z:\n        for subdir in tqdm(sorted(spect_dir.iterdir()), leave=False):\n            if subdir.is_dir():\n                for fn in precomputed_augmentation_filenames(augmentations):\n                    z.write(subdir / fn, subdir.name + \"/\" + fn)\n\n\ndef ints(value):\n    \"\"\"Parse a string containing a colon-separated tuple of integers.\"\"\"\n    return value and tuple(map(int, value.split(\":\")))\n\n\ndef main(orig_audio_paths, pitch_shift, time_stretch, verbose):\n    # preprocess audio\n    dp = AudioPreprocessing(\n        orig_audio_paths=orig_audio_paths,\n        out_sr=22050,\n        aug_sr=44100,\n        pitch_shift=pitch_shift,\n        time_stretch=time_stretch,\n        verbose=verbose,\n    )\n    dp.preprocess_audio()\n\n    # compute spectrograms\n    mel_args = dict(\n        n_fft=1024,\n        hop_length=441,\n        f_min=30,\n        f_max=11000,\n        n_mels=128,\n        mel_scale=\"slaney\",\n        normalized=\"frame_length\",\n        power=1,\n    )\n    sc = SpectCreation(\n        
pitch_shift=pitch_shift,\n        time_stretch=time_stretch,\n        audio_sr=22050,\n        mel_args=mel_args,\n        verbose=verbose,\n    )\n    sc.create_spects()\n\n    # assemble into NPZ files\n    print(\"Creating .npz spectrogram bundles...\")\n    spect_dirs = [child for child in sc.spectrograms_dir.iterdir() if child.is_dir()]\n    for spect_dir in tqdm(spect_dirs):\n        create_npz(\n            spect_dir,\n            spect_dir.with_suffix(\".npz\"),\n            {} if spect_dir.name == \"gtzan\" else sc.augmentations,\n            verbose,\n        )\n\n\nif __name__ == \"__main__\":\n\n    parser = argparse.ArgumentParser()\n    parser.add_argument(\n        \"--orig_audio_paths\",\n        type=str,\n        help=\"path to the file with the original audio paths for each dataset (default: %(default)s)\",\n        default=\"data/audio_paths.csv\",\n    )\n    parser.add_argument(\n        \"--pitch_shift\",\n        metavar=\"LOW:HIGH\",\n        type=str,\n        default=\"-5:6\",\n        help=\"pitch shift in semitones (default: %(default)s)\",\n    )\n    parser.add_argument(\n        \"--time_stretch\",\n        metavar=\"MAX:STRIDE\",\n        type=str,\n        default=\"20:4\",\n        help=\"time stretch in percentage and stride (default: %(default)s)\",\n    )\n    parser.add_argument(\"--verbose\", action=\"store_true\", help=\"verbose output\")\n    args = parser.parse_args()\n\n    main(\n        args.orig_audio_paths,\n        ints(args.pitch_shift),\n        ints(args.time_stretch),\n        args.verbose,\n    )\n"
  },
  {
    "path": "launch_scripts/train.py",
    "content": "import argparse\nfrom pathlib import Path\n\nimport torch\nfrom pytorch_lightning import Trainer, seed_everything\nfrom pytorch_lightning.callbacks import LearningRateMonitor, ModelCheckpoint\nfrom pytorch_lightning.loggers import WandbLogger\n\nfrom beat_this.dataset import BeatDataModule\nfrom beat_this.model.pl_module import PLBeatThis\n\n\ndef main(args):\n    # for repeatability\n    seed_everything(args.seed, workers=True)\n\n    print(\"Starting a new run with the following parameters:\")\n    print(args)\n\n    params_str = f\"{'noval ' if not args.val else ''}{'hung ' if args.hung_data else ''}{'fold' + str(args.fold) + ' ' if args.fold is not None else ''}{args.loss}-h{args.transformer_dim}-aug{args.tempo_augmentation}{args.pitch_augmentation}{args.mask_augmentation}{' nosumH ' if not args.sum_head else ''}{' nopartialT ' if not args.partial_transformers else ''}\"\n    if args.logger == \"wandb\":\n        if args.resume_checkpoint and args.resume_id:\n            wandb_args = dict(id=args.resume_id, resume=\"must\")\n        else:\n            wandb_args = {}\n        logger = WandbLogger(\n            project=\"beat_this\", name=f\"{args.name} {params_str}\".strip(), **wandb_args\n        )\n    else:\n        logger = None\n\n    if args.force_flash_attention:\n        print(\"Forcing the use of the flash attention.\")\n        torch.backends.cuda.enable_flash_sdp(True)\n        torch.backends.cuda.enable_mem_efficient_sdp(False)\n        torch.backends.cuda.enable_math_sdp(False)\n\n    data_dir = Path(__file__).parent.parent.relative_to(Path.cwd()) / \"data\"\n    checkpoint_dir = (\n        Path(__file__).parent.parent.relative_to(Path.cwd()) / \"checkpoints\"\n    )\n    augmentations = {}\n    if args.tempo_augmentation:\n        augmentations[\"tempo\"] = {\"min\": -20, \"max\": 20, \"stride\": 4}\n    if args.pitch_augmentation:\n        augmentations[\"pitch\"] = {\"min\": -5, \"max\": 6}\n    if args.mask_augmentation:\n       
 # kind, min_count, max_count, min_len, max_len, min_parts, max_parts\n        augmentations[\"mask\"] = {\n            \"kind\": \"permute\",\n            \"min_count\": 1,\n            \"max_count\": 6,\n            \"min_len\": 0.1,\n            \"max_len\": 2,\n            \"min_parts\": 5,\n            \"max_parts\": 9,\n        }\n\n    datamodule = BeatDataModule(\n        data_dir,\n        batch_size=args.batch_size,\n        train_length=args.train_length,\n        spect_fps=args.fps,\n        num_workers=args.num_workers,\n        test_dataset=\"gtzan\",\n        length_based_oversampling_factor=args.length_based_oversampling_factor,\n        augmentations=augmentations,\n        hung_data=args.hung_data,\n        no_val=not args.val,\n        fold=args.fold,\n    )\n    datamodule.setup(stage=\"fit\")\n\n    # compute positive weights\n    pos_weights = datamodule.get_train_positive_weights(widen_target_mask=3)\n    print(\"Using positive weights: \", pos_weights)\n    dropout = {\n        \"frontend\": args.frontend_dropout,\n        \"transformer\": args.transformer_dropout,\n    }\n    pl_model = PLBeatThis(\n        spect_dim=128,\n        fps=50,\n        transformer_dim=args.transformer_dim,\n        ff_mult=4,\n        n_layers=args.n_layers,\n        stem_dim=32,\n        dropout=dropout,\n        lr=args.lr,\n        weight_decay=args.weight_decay,\n        pos_weights=pos_weights,\n        head_dim=32,\n        loss_type=args.loss,\n        warmup_steps=args.warmup_steps,\n        max_epochs=args.max_epochs,\n        use_dbn=args.dbn,\n        eval_trim_beats=args.eval_trim_beats,\n        sum_head=args.sum_head,\n        partial_transformers=args.partial_transformers,\n    )\n    for part in args.compile:\n        if hasattr(pl_model.model, part):\n            setattr(pl_model.model, part, torch.compile(getattr(pl_model.model, part)))\n            print(\"Will compile model\", part)\n        else:\n            raise ValueError(\"The model is 
missing the part \" + part + \" to compile\")\n\n    callbacks = [LearningRateMonitor(logging_interval=\"step\")]\n    # save only the last model\n    callbacks.append(\n        ModelCheckpoint(\n            every_n_epochs=1,\n            dirpath=str(checkpoint_dir),\n            filename=f\"{args.name} S{args.seed} {params_str}\".strip(),\n        )\n    )\n\n    trainer = Trainer(\n        max_epochs=args.max_epochs,\n        accelerator=\"auto\",\n        devices=[args.gpu],\n        num_sanity_val_steps=1,\n        logger=logger,\n        callbacks=callbacks,\n        log_every_n_steps=1,\n        precision=\"16-mixed\",\n        accumulate_grad_batches=args.accumulate_grad_batches,\n        check_val_every_n_epoch=args.val_frequency,\n    )\n\n    trainer.fit(pl_model, datamodule, ckpt_path=args.resume_checkpoint)\n    trainer.test(pl_model, datamodule)\n\n\nif __name__ == \"__main__\":\n    parser = argparse.ArgumentParser()\n    parser.add_argument(\"--name\", type=str, default=\"\")\n    parser.add_argument(\"--gpu\", type=int, default=0)\n    parser.add_argument(\n        \"--force-flash-attention\", default=False, action=argparse.BooleanOptionalAction\n    )\n    parser.add_argument(\n        \"--compile\",\n        action=\"store\",\n        nargs=\"*\",\n        type=str,\n        default=[\"frontend\", \"transformer_blocks\", \"task_heads\"],\n        help=\"Which model parts to compile, among frontend, transformer_blocks, task_heads\",\n    )\n    parser.add_argument(\"--n-layers\", type=int, default=6)\n    parser.add_argument(\"--transformer-dim\", type=int, default=512)\n    parser.add_argument(\n        \"--frontend-dropout\",\n        type=float,\n        default=0.1,\n        help=\"dropout rate to apply in the frontend\",\n    )\n    parser.add_argument(\n        \"--transformer-dropout\",\n        type=float,\n        default=0.2,\n        help=\"dropout rate to apply in the main transformer blocks\",\n    )\n    parser.add_argument(\"--lr\", 
type=float, default=0.0008)\n    parser.add_argument(\"--weight-decay\", type=float, default=0.01)\n    parser.add_argument(\"--logger\", type=str, choices=[\"wandb\", \"none\"], default=\"none\")\n    parser.add_argument(\"--num-workers\", type=int, default=8)\n    parser.add_argument(\"--n-heads\", type=int, default=16)\n    parser.add_argument(\"--fps\", type=int, default=50, help=\"The spectrograms fps.\")\n    parser.add_argument(\n        \"--loss\",\n        type=str,\n        default=\"shift_tolerant_weighted_bce\",\n        choices=[\n            \"shift_tolerant_weighted_bce\",\n            \"fast_shift_tolerant_weighted_bce\",\n            \"weighted_bce\",\n            \"bce\",\n        ],\n        help=\"The loss to use\",\n    )\n    parser.add_argument(\n        \"--warmup-steps\", type=int, default=1000, help=\"warmup steps for optimizer\"\n    )\n    parser.add_argument(\n        \"--max-epochs\", type=int, default=100, help=\"max epochs for training\"\n    )\n    parser.add_argument(\n        \"--batch-size\", type=int, default=8, help=\"batch size for training\"\n    )\n    parser.add_argument(\"--accumulate-grad-batches\", type=int, default=8)\n    parser.add_argument(\n        \"--train-length\",\n        type=int,\n        default=1500,\n        help=\"maximum seq length for training in frames\",\n    )\n    parser.add_argument(\n        \"--dbn\",\n        default=False,\n        action=argparse.BooleanOptionalAction,\n        help=\"use madmom postprocessing DBN\",\n    )\n    parser.add_argument(\n        \"--eval-trim-beats\",\n        metavar=\"SECONDS\",\n        type=float,\n        default=5,\n        help=\"Skip the first given seconds per piece in evaluating (default: %(default)s)\",\n    )\n    parser.add_argument(\n        \"--val-frequency\",\n        metavar=\"N\",\n        type=int,\n        default=5,\n        help=\"validate every N epochs (default: %(default)s)\",\n    )\n    parser.add_argument(\n        
\"--tempo-augmentation\",\n        default=True,\n        action=argparse.BooleanOptionalAction,\n        help=\"Use precomputed tempo augmentation\",\n    )\n    parser.add_argument(\n        \"--pitch-augmentation\",\n        default=True,\n        action=argparse.BooleanOptionalAction,\n        help=\"Use precomputed pitch augmentation\",\n    )\n    parser.add_argument(\n        \"--mask-augmentation\",\n        default=True,\n        action=argparse.BooleanOptionalAction,\n        help=\"Use online mask augmentation\",\n    )\n    parser.add_argument(\n        \"--sum-head\",\n        default=True,\n        action=argparse.BooleanOptionalAction,\n        help=\"Use SumHead instead of two separate Linear heads\",\n    )\n    parser.add_argument(\n        \"--partial-transformers\",\n        default=True,\n        action=argparse.BooleanOptionalAction,\n        help=\"Use partial transformers in the frontend\",\n    )\n    parser.add_argument(\n        \"--length-based-oversampling-factor\",\n        type=float,\n        default=0.65,\n        help=\"The factor to oversample the long pieces in the dataset. Set to 0 to only take one excerpt for each piece.\",\n    )\n    parser.add_argument(\n        \"--val\",\n        default=True,\n        action=argparse.BooleanOptionalAction,\n        help=\"Hold out the validation data from training (default). With --no-val, train on all data, including validation data, excluding test data; the validation metrics will still be computed, but they won't carry any meaning.\",\n    )\n    parser.add_argument(\n        \"--hung-data\",\n        default=False,\n        action=argparse.BooleanOptionalAction,\n        help=\"Limit the training to Hung et al. data. 
The validation will still be computed on all datasets.\",\n    )\n    parser.add_argument(\n        \"--fold\",\n        type=int,\n        default=None,\n        help=\"If given, the CV fold number to *not* train on (0-based).\",\n    )\n    parser.add_argument(\n        \"--seed\",\n        type=int,\n        default=0,\n        help=\"Seed for the random number generators.\",\n    )\n    parser.add_argument(\n        \"--resume-checkpoint\",\n        type=str,\n        default=None,\n        help=\"Resume training from a local checkpoint.\",\n    )\n    parser.add_argument(\n        \"--resume-id\",\n        type=str,\n        default=None,\n        help=\"When resuming with --resume-checkpoint, optionally provide the wandb id to continue logging to.\",\n    )\n\n    args = parser.parse_args()\n\n    main(args)\n"
  },
  {
    "path": "pyproject.toml",
    "content": "[build-system]\nrequires = [\"setuptools\"]\nbuild-backend = \"setuptools.build_meta\"\n\n[project]\nname = \"beat-this\"\nversion = \"1.1.0\"\ndescription = \"Beat This! beat tracker\"\nreadme = \"README.md\"\nclassifiers = [\n    \"Intended Audience :: Science/Research\",\n    \"Topic :: Multimedia :: Sound/Audio :: Analysis\",\n    \"Development Status :: 5 - Production/Stable\",\n    \"Programming Language :: Python :: 3\",\n]\nauthors = [\n    {name = \"Francesco Foscarin\", email = \"francesco.foscarin@jku.at\"},\n    {name = \"Jan Schlüter\", email = \"jan.schlueter@jku.at\"},\n]\nrequires-python = \">=3\"\ndependencies = [\n    \"numpy>=1.20\",\n    \"torch>=2\",\n    \"torchaudio\",\n    \"einops\",\n    \"rotary-embedding-torch\",\n    \"soxr\",\n]\nlicense = \"MIT\"\nlicense-files = [\"LICENSE\"]\n\n[tool.setuptools.package-dir]\nbeat_this = \"beat_this\"\n\n[project.urls]\nRepository = \"https://github.com/CPJKU/beat_this\"\nIssues = \"https://github.com/CPJKU/beat_this/issues\"\nChangelog = \"https://github.com/CPJKU/beat_this/blob/main/CHANGELOG.md\"\n\n[project.scripts]\nbeat_this = \"beat_this.cli:main\"\n"
  },
  {
    "path": "requirements.txt",
    "content": "# This is a set of known working versions for inference, documented for a\n# distant future.\n# We recommend following the requirements section in our README.md instead.\neinops==0.8.0\nnumpy==1.26.4\nrotary_embedding_torch==0.6.4\nsoxr==0.3.7\ntorch==2.3.1\ntorchaudio==2.3.1\ntqdm==4.66.4"
  },
  {
    "path": "tests/test_inference.py",
    "content": "from pathlib import Path\n\nimport numpy as np\nimport soundfile as sf\nimport torch\n\nfrom beat_this.inference import Audio2Frames, File2Beats\n\n\ndef test_File2Beat():\n    f2b = File2Beats()\n    audio_path = Path(\"tests/It Don't Mean A Thing - Kings of Swing.mp3\")\n    beat, downbeat = f2b(audio_path)\n    assert isinstance(beat, np.ndarray)\n    assert isinstance(downbeat, np.ndarray)\n\n\ndef test_Audio2Frames():\n    a2f = Audio2Frames()\n    audio_path = Path(\"tests/It Don't Mean A Thing - Kings of Swing.mp3\")\n    # load audio\n    audio, sr = sf.read(audio_path)\n    beat, downbeat = a2f(audio, sr)\n    assert isinstance(beat, torch.Tensor)\n    assert isinstance(downbeat, torch.Tensor)\n"
  }
]