Repository: kkoutini/PaSST
Branch: main
Commit: 2a5c818afcc2
Files: 52
Total size: 341.1 KB

Directory structure:
gitextract_blhd6h9_/
├── .github/
│   └── workflows/
│       └── pylint.yml
├── .gitignore
├── .markdownlint.json
├── LICENSE
├── README.md
├── __init__.py
├── audioset/
│   ├── README.md
│   ├── dataset.py
│   └── prepare_scripts/
│       ├── convert_to_mp3.py
│       └── create_h5pymp3_dataset.py
├── ba3l/
│   ├── __init__.py
│   ├── experiment.py
│   ├── ingredients/
│   │   ├── __init__.py
│   │   ├── datasets.py
│   │   ├── ingredient.py
│   │   ├── models.py
│   │   └── trainer.py
│   ├── module.py
│   ├── plutils/
│   │   ├── __init__.py
│   │   ├── lr_monitor.py
│   │   └── progress_bar.py
│   └── util/
│       ├── __init__.py
│       ├── functions.py
│       └── sacred_logger.py
├── config_updates.py
├── environment.yml
├── environment_old_2021.yml
├── esc50/
│   ├── README.md
│   └── dataset.py
├── ex_audioset.py
├── ex_esc50.py
├── ex_fsd50k.py
├── ex_openmic.py
├── fsd50k/
│   ├── README.md
│   ├── dataset.py
│   └── prepare_scripts/
│       ├── convert_to_mp3.py
│       └── create_h5pymp3_dataset.py
├── helpers/
│   ├── audiodatasets.py
│   ├── mixup.py
│   ├── models_size.py
│   ├── ramp.py
│   ├── swa_callback.py
│   ├── swa_legacy.py
│   └── workersinit.py
├── models/
│   ├── helpers/
│   │   └── vit_helpers.py
│   ├── passt.py
│   └── preprocess.py
├── openmic/
│   ├── README.md
│   ├── dataset.py
│   └── prepare_scripts/
│       └── download_preprocess.py
├── pip_list.txt
└── requirements.txt

================================================
FILE CONTENTS
================================================

================================================
FILE: .github/workflows/pylint.yml
================================================
name: Pylint

on: [push]

jobs:
  build:
    runs-on: ubuntu-latest
    strategy:
      matrix:
        python-version: ["3.8", "3.9", "3.10"]
    steps:
    - uses: actions/checkout@v3
    - name: Set up Python ${{ matrix.python-version }}
      uses: actions/setup-python@v3
      with:
        python-version: ${{ matrix.python-version }}
    - name: Install dependencies
      run: |
        python -m pip install --upgrade pip
        pip install pylint
    - name: Analysing the code with pylint
      run: |
        pylint $(git ls-files '*.py')

================================================
FILE: .gitignore
================================================
# project
/output/

# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class
example.py
timit_data/
LJSpeech-1.1/

# C extensions
*.so
.idea/

# Distribution / packaging
.Python
env/
ide_layouts/
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
*.egg-info/
.installed.cfg
*.egg

# PyInstaller
#  Usually these files are written by a python script from a template
#  before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
.hypothesis/

# Translations
*.mo
*.pot

# Django stuff:
*.log
local_settings.py

# Flask stuff:
instance/
.webassets-cache

# Scrapy stuff:
.scrapy

# Sphinx documentation
docs/_build/

# PyBuilder
target/

# Jupyter Notebook
.ipynb_checkpoints

# pyenv
.python-version

# celery beat schedule file
celerybeat-schedule

# SageMath parsed files
*.sage.py

# dotenv
.env

# virtualenv
.venv
venv/
ENV/

# Spyder project settings
.spyderproject
.spyproject

# Rope project settings
.ropeproject

# mkdocs documentation
/site

# mypy
.mypy_cache/
/openmic/prepare_scripts/openmic-2018-v1.0.0/
/openmic/prepare_scripts/openmic-2018-v1.0.0.tgz
/audioset_hdf5s/mp3/openmic_test.csv_mp3.hdf
/audioset_hdf5s/mp3/openmic_train.csv_mp3.hdf
environment_builds.yml

# Output
lightning_logs/*
audioset_hdf5s/*
.vscode/settings.json
wandb/*
.vscode/launch.json

================================================
FILE: .markdownlint.json
================================================
{
    "MD033": {
        "allowed_elements": [
            "p",
            "img"
        ]
    },
    "MD013": false
}

================================================
FILE: LICENSE
================================================ Apache License Version 2.0, January 2004 http://www.apache.org/licenses/ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 1. Definitions. "License" shall mean the terms and conditions for use, reproduction, and distribution as defined by Sections 1 through 9 of this document. "Licensor" shall mean the copyright owner or entity authorized by the copyright owner that is granting the License. "Legal Entity" shall mean the union of the acting entity and all other entities that control, are controlled by, or are under common control with that entity. For the purposes of this definition, "control" means (i) the power, direct or indirect, to cause the direction or management of such entity, whether by contract or otherwise, or (ii) ownership of fifty percent (50%) or more of the outstanding shares, or (iii) beneficial ownership of such entity. "You" (or "Your") shall mean an individual or Legal Entity exercising permissions granted by this License. "Source" form shall mean the preferred form for making modifications, including but not limited to software source code, documentation source, and configuration files. "Object" form shall mean any form resulting from mechanical transformation or translation of a Source form, including but not limited to compiled object code, generated documentation, and conversions to other media types. "Work" shall mean the work of authorship, whether in Source or Object form, made available under the License, as indicated by a copyright notice that is included in or attached to the work (an example is provided in the Appendix below). "Derivative Works" shall mean any work, whether in Source or Object form, that is based on (or derived from) the Work and for which the editorial revisions, annotations, elaborations, or other modifications represent, as a whole, an original work of authorship. 
For the purposes of this License, Derivative Works shall not include works that remain separable from, or merely link (or bind by name) to the interfaces of, the Work and Derivative Works thereof. "Contribution" shall mean any work of authorship, including the original version of the Work and any modifications or additions to that Work or Derivative Works thereof, that is intentionally submitted to Licensor for inclusion in the Work by the copyright owner or by an individual or Legal Entity authorized to submit on behalf of the copyright owner. For the purposes of this definition, "submitted" means any form of electronic, verbal, or written communication sent to the Licensor or its representatives, including but not limited to communication on electronic mailing lists, source code control systems, and issue tracking systems that are managed by, or on behalf of, the Licensor for the purpose of discussing and improving the Work, but excluding communication that is conspicuously marked or otherwise designated in writing by the copyright owner as "Not a Contribution." "Contributor" shall mean Licensor and any individual or Legal Entity on behalf of whom a Contribution has been received by Licensor and subsequently incorporated within the Work. 2. Grant of Copyright License. Subject to the terms and conditions of this License, each Contributor hereby grants to You a perpetual, worldwide, non-exclusive, no-charge, royalty-free, irrevocable copyright license to reproduce, prepare Derivative Works of, publicly display, publicly perform, sublicense, and distribute the Work and such Derivative Works in Source or Object form. 3. Grant of Patent License. 
Subject to the terms and conditions of this License, each Contributor hereby grants to You a perpetual, worldwide, non-exclusive, no-charge, royalty-free, irrevocable (except as stated in this section) patent license to make, have made, use, offer to sell, sell, import, and otherwise transfer the Work, where such license applies only to those patent claims licensable by such Contributor that are necessarily infringed by their Contribution(s) alone or by combination of their Contribution(s) with the Work to which such Contribution(s) was submitted. If You institute patent litigation against any entity (including a cross-claim or counterclaim in a lawsuit) alleging that the Work or a Contribution incorporated within the Work constitutes direct or contributory patent infringement, then any patent licenses granted to You under this License for that Work shall terminate as of the date such litigation is filed. 4. Redistribution. You may reproduce and distribute copies of the Work or Derivative Works thereof in any medium, with or without modifications, and in Source or Object form, provided that You meet the following conditions: (a) You must give any other recipients of the Work or Derivative Works a copy of this License; and (b) You must cause any modified files to carry prominent notices stating that You changed the files; and (c) You must retain, in the Source form of any Derivative Works that You distribute, all copyright, patent, trademark, and attribution notices from the Source form of the Work, excluding those notices that do not pertain to any part of the Derivative Works; and (d) If the Work includes a "NOTICE" text file as part of its distribution, then any Derivative Works that You distribute must include a readable copy of the attribution notices contained within such NOTICE file, excluding those notices that do not pertain to any part of the Derivative Works, in at least one of the following places: within a NOTICE text file distributed as part of the 
Derivative Works; within the Source form or documentation, if provided along with the Derivative Works; or, within a display generated by the Derivative Works, if and wherever such third-party notices normally appear. The contents of the NOTICE file are for informational purposes only and do not modify the License. You may add Your own attribution notices within Derivative Works that You distribute, alongside or as an addendum to the NOTICE text from the Work, provided that such additional attribution notices cannot be construed as modifying the License. You may add Your own copyright statement to Your modifications and may provide additional or different license terms and conditions for use, reproduction, or distribution of Your modifications, or for any such Derivative Works as a whole, provided Your use, reproduction, and distribution of the Work otherwise complies with the conditions stated in this License. 5. Submission of Contributions. Unless You explicitly state otherwise, any Contribution intentionally submitted for inclusion in the Work by You to the Licensor shall be under the terms and conditions of this License, without any additional terms or conditions. Notwithstanding the above, nothing herein shall supersede or modify the terms of any separate license agreement you may have executed with Licensor regarding such Contributions. 6. Trademarks. This License does not grant permission to use the trade names, trademarks, service marks, or product names of the Licensor, except as required for reasonable and customary use in describing the origin of the Work and reproducing the content of the NOTICE file. 7. Disclaimer of Warranty. 
Unless required by applicable law or agreed to in writing, Licensor provides the Work (and each Contributor provides its Contributions) on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied, including, without limitation, any warranties or conditions of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A PARTICULAR PURPOSE. You are solely responsible for determining the appropriateness of using or redistributing the Work and assume any risks associated with Your exercise of permissions under this License. 8. Limitation of Liability. In no event and under no legal theory, whether in tort (including negligence), contract, or otherwise, unless required by applicable law (such as deliberate and grossly negligent acts) or agreed to in writing, shall any Contributor be liable to You for damages, including any direct, indirect, special, incidental, or consequential damages of any character arising as a result of this License or out of the use or inability to use the Work (including but not limited to damages for loss of goodwill, work stoppage, computer failure or malfunction, or any and all other commercial damages or losses), even if such Contributor has been advised of the possibility of such damages. 9. Accepting Warranty or Additional Liability. While redistributing the Work or Derivative Works thereof, You may choose to offer, and charge a fee for, acceptance of support, warranty, indemnity, or other liability obligations and/or rights consistent with this License. However, in accepting such obligations, You may act only on Your own behalf and on Your sole responsibility, not on behalf of any other Contributor, and only if You agree to indemnify, defend, and hold each Contributor harmless for any liability incurred by, or claims asserted against, such Contributor by reason of your accepting any such warranty or additional liability. END OF TERMS AND CONDITIONS APPENDIX: How to apply the Apache License to your work. 
To apply the Apache License to your work, attach the following boilerplate notice, with the fields enclosed by brackets "[]" replaced with your own identifying information. (Don't include the brackets!) The text should be enclosed in the appropriate comment syntax for the file format. We also recommend that a file or class name and description of purpose be included on the same "printed page" as the copyright notice for easier identification within third-party archives.

Copyright [yyyy] [name of copyright owner]

Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License.

================================================
FILE: README.md
================================================
# PaSST: Efficient Training of Audio Transformers with Patchout

This is the implementation of [Efficient Training of Audio Transformers with Patchout](https://arxiv.org/abs/2110.05069).

Patchout significantly reduces the training time and GPU memory requirements for training transformers on audio spectrograms, while improving their performance.

Patchout works by dropping out some of the input patches during training: either in an unstructured way (randomly, similar to dropout), or by dropping entire time frames or frequency bins of the extracted patches (similar to SpecAugment), which correspond to rows/columns in step 3 of the figure below.
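The two variants can be sketched on a dummy patch grid. This is a minimal NumPy illustration of the idea only, not the repository's implementation in `models/passt.py` (there, patchout operates on the transformer's token sequence); all names below are made up for the example:

```python
import numpy as np

rng = np.random.default_rng(0)

def unstructured_patchout(patches, n_drop):
    """Randomly keep all but `n_drop` of the flattened patch tokens (dropout-like)."""
    n = patches.shape[0]
    keep = np.sort(rng.choice(n, size=n - n_drop, replace=False))
    return patches[keep]

def structured_patchout(grid, drop_f, drop_t):
    """Drop whole frequency rows / time columns of the patch grid (SpecAugment-like)."""
    f, t = grid.shape[:2]
    keep_f = np.sort(rng.choice(f, size=f - drop_f, replace=False))
    keep_t = np.sort(rng.choice(t, size=t - drop_t, replace=False))
    return grid[np.ix_(keep_f, keep_t)]

# dummy grid of 12 frequency x 99 time patches with 768-dim embeddings
grid = rng.standard_normal((12, 99, 768))
flat = grid.reshape(-1, 768)          # 1188 patch tokens

print(unstructured_patchout(flat, n_drop=400).shape)        # (788, 768)
print(structured_patchout(grid, drop_f=4, drop_t=40).shape)  # (8, 59, 768)
```

Dropping patches shortens the token sequence, which is what reduces compute and memory during training.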

## Table of contents

- [Pre-trained models for Inference and embeddings extractions](#pre-trained-models-for-inference-and-embeddings-extractions)
- [Getting the logits from the pretrained models](#getting-the-logits-from-the-pretrained-models)
- [Getting a pre-trained model for fine-tuning](#getting-a-pre-trained-model-for-fine-tuning)
- [Development environment](#development-environment)
- [Setting up the development experiments environment](#setting-up-the-development-experiments-environment)
- [Setting up using the exported conda environment](#setting-up-using-the-exported-conda-environment)
- [Checking the environment](#checking-the-environment)
- [Getting started](#getting-started)
- [General information](#general-information)
- [Configuring the experiment](#configuring-the-experiment)
- [Training on Audioset](#training-on-audioset)
- [Examples with Pre-trained models](#examples-with-pre-trained-models)
- [Examples of fine-tuning on downstream datasets](#examples-of-fine-tuning-on-downstream-datasets)
- [Citation](#citation)
- [Contact](#contact)

## Pre-trained models for Inference and embeddings extractions

If you only want to use the embeddings generated by the pretrained models, use your own fine-tuning framework, or need the models only for inference, you can find a stripped-down version of this repo [here](https://github.com/kkoutini/passt_hear21). The package follows the [HEAR 2021 NeurIPS Challenge](https://neuralaudio.ai/hear2021-results.html) API and can be installed with:

```shell
pip install hear21passt
```

This repo is a complete framework for training the models on Audioset and fine-tuning the pre-trained models on downstream tasks.

### Getting the logits from the pretrained models

```python
from hear21passt.base import get_basic_model, get_model_passt
import torch

# get the PaSST model wrapper, includes Melspectrogram and the default pre-trained transformer
model = get_basic_model(mode="logits")
print(model.mel)  # Extracts mel spectrogram from raw waveforms.
print(model.net)  # the transformer network.

# example inference
model.eval()
model = model.cuda()
with torch.no_grad():
    # audio_wave has the shape of [batch, seconds*32000]; the sampling rate is 32k
    # example audio_wave of batch=3 and 10 seconds
    audio = torch.ones((3, 32000 * 10)) * 0.5
    audio_wave = audio.cuda()
    logits = model(audio_wave)
```

### Getting a pre-trained model for fine-tuning

```python
from hear21passt.base import get_basic_model, get_model_passt
import torch

# get the PaSST model wrapper, includes Melspectrogram and the default pre-trained transformer
model = get_basic_model(mode="logits")
print(model.mel)  # Extracts mel spectrogram from raw waveforms.

# optional: replace the transformer with one that has the required number of classes, e.g. 50
model.net = get_model_passt(arch="passt_s_swa_p16_128_ap476", n_classes=50)
print(model.net)  # the transformer network.

# now model contains mel + the pre-trained transformer, ready to be fine-tuned.
# It still expects input of the shape [batch, seconds*32000]; the sampling rate is 32k
model.train()
model = model.cuda()
```

## Development environment

If you want to use the same environment as in the paper, you can follow the instructions below.

### Setting up the development experiments environment

For training models from scratch or fine-tuning using the same setup as in the paper:

1. If needed, create a new environment with python 3.8 and activate it:

    ```bash
    conda create -n passt python=3.8
    conda activate passt
    ```

1. Install a pytorch build that suits your system. For example:

    ```bash
    conda install pytorch==1.11.0 torchvision==0.12.0 torchaudio==0.11.0 cudatoolkit=11.3 -c pytorch
    ```

1. Install the requirements:

    ```bash
    pip install -r requirements.txt
    ```

### Setting up using the exported conda environment

Alternatively, you can use the exported conda environment `environment.yml` to create the environment.
For setting up, [Mamba](https://github.com/mamba-org/mamba) is recommended, since it works faster than `conda`:

```shell
conda install mamba -n base -c conda-forge
```

Now you can import the environment from `environment.yml`:

```shell
mamba env create -f environment.yml
```

Now you have an environment named `ba3l`.

### Checking the environment

In order to check whether your environment matches the environment we used in our runs, please check the `environment.yml` and `pip_list.txt` files, which were exported using:

```shell
conda env export --no-builds | grep -v "prefix" > environment.yml
pip list > pip_list.txt
```

## Getting started

If you want to use your own setup and only use the models from this repo, you can get the models, train them from scratch, or fine-tune them on your own dataset, as explained above in [Pre-trained models for Inference and embeddings extractions](#pre-trained-models-for-inference-and-embeddings-extractions).

The rest of this section explains how to use this repo for training and fine-tuning the models. For that, you first need to set up the development environment as explained above.

### General information

The repo is built using [sacred](https://sacred.readthedocs.io/en/) for experiment management and configuration, pytorch-lightning for training, and wandb for logging.

Each dataset has a main experiment file, such as `ex_audioset.py` and `ex_openmic.py`, and a dataset folder. The experiment file contains the main training and validation logic. The dataset folder contains the dataset-specific code needed to download, preprocess, and load the dataset for training.

In general, you can probe the experiment file for help; this will print the available commands and basic options:

```shell
python ex_audioset.py help
```

### Configuring the experiment

Each experiment has a set of default configuration options, defined in the experiment file, e.g. `ex_audioset.py`.
You can override any of the configuration options using the [sacred syntax](https://sacred.readthedocs.io/en/stable/command_line.html). You can use the `print_config` command to print the configuration values without training a model:

```shell
python ex_audioset.py print_config
```

You can then use the command line interface to override any of the configuration options ([sacred syntax](https://sacred.readthedocs.io/en/stable/command_line.html)), using `with`, e.g.:

```shell
python ex_audioset.py with trainer.precision=16
```

This will train on Audioset using 16-bit precision. The overall configurations look like this:

```yaml
...
seed = 542198583  # the random seed for this experiment
slurm_job_id = ''
speed_test_batch_size = 100
swa = True
swa_epoch_start = 50
swa_freq = 5
use_mixup = True
warm_up_len = 5
weight_decay = 0.0001
basedataset:
  base_dir = 'audioset_hdf5s/'  # base directory of the dataset, change it or make a link
  eval_hdf5 = 'audioset_hdf5s/mp3/eval_segments_mp3.hdf'
  wavmix = 1
  ....
  roll_conf:
    axis = 1
    shift = None
    shift_range = 50
datasets:
  test:
    batch_size = 20
    dataset = {CMD!}'/basedataset.get_test_set'
    num_workers = 16
    validate = True
  training:
    batch_size = 12
    dataset = {CMD!}'/basedataset.get_full_training_set'
    num_workers = 16
    sampler = {CMD!}'/basedataset.get_ft_weighted_sampler'
    shuffle = None
    train = True
models:
  mel:
    freqm = 48
    timem = 192
    hopsize = 320
    htk = False
    n_fft = 1024
    n_mels = 128
    norm = 1
    sr = 32000
    ...
  net:
    arch = 'passt_s_swa_p16_128_ap476'
    fstride = 10
    in_channels = 1
    input_fdim = 128
    input_tdim = 998
    n_classes = 527
    s_patchout_f = 4
    s_patchout_t = 40
    tstride = 10
    u_patchout = 0
    ...
trainer:
  accelerator = None
  accumulate_grad_batches = 1
  amp_backend = 'native'
  amp_level = 'O2'
  auto_lr_find = False
  auto_scale_batch_size = False
  ...
```

There are many things that can be updated from the command line.
In short:

- All the configuration options under `trainer` belong to the pytorch lightning trainer [api](https://pytorch-lightning.readthedocs.io/en/1.4.1/common/trainer.html#trainer-class-api). For example, to turn off cuda benchmarking, add `trainer.benchmark=False` to the command line.
- `wandb` is the wandb configuration. For example, to change the wandb project, add `wandb.project="test_project"` to the command line.
- `models.net` are the PaSST (or the chosen NN) options. For example, `models.net.u_patchout`, `models.net.s_patchout_f`, and `models.net.s_patchout_t` control the unstructured patchout and the structured patchout over frequency and time. `input_fdim` and `input_tdim` are the input spectrogram dimensions over frequency and time. `models.net.fstride` and `models.net.tstride` are the strides of the input patches over frequency and time; setting these to 16 means no patch overlap.
- `models.mel` are the preprocessing options (mel spectrograms). `mel.sr` is the sampling rate, `mel.hopsize` is the hop size of the STFT window, `mel.n_mels` is the number of mel bins, and `mel.freqm` and `mel.timem` are the frequency and time masking parameters of spec-augment.

There are many pre-defined configuration bundles (called named_configs) in `config_updates.py`. These include different models, setups, etc. You can list these configurations with:

```shell
python ex_audioset.py print_named_configs
```

For example, `passt_s_20sec` is a configuration bundle that sets the model to PaSST-S pre-trained on Audioset and accepting clips of up to 20 seconds.
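The relation between the strides and the resulting number of patch tokens can be worked out directly. The helper below is only an illustration (it is not part of the repo), assuming the default `16x16` patch size:

```python
def n_patches(input_fdim, input_tdim, fstride, tstride, patch_size=16):
    """Illustrative: number of patch_size x patch_size patches extracted from a
    spectrogram of shape (input_fdim, input_tdim) with the given strides."""
    n_f = (input_fdim - patch_size) // fstride + 1
    n_t = (input_tdim - patch_size) // tstride + 1
    return n_f * n_t

# default PaSST-S config: 128 mel bins, 998 time frames, stride 10 (overlapping patches)
print(n_patches(128, 998, 10, 10))   # 12 * 99 = 1188 patches
# stride 16: no patch overlap, far fewer tokens
print(n_patches(128, 998, 16, 16))   # 8 * 62 = 496 patches
```

Smaller strides mean more (overlapping) patches and hence a longer token sequence, which is why the stride variants trade accuracy against compute.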
## Training on Audioset

Download and prepare the dataset as explained on the [audioset page](audioset/).
The base PaSST model can be trained, for example, like this:

```bash
python ex_audioset.py with trainer.precision=16 models.net.arch=passt_deit_bd_p16_384 -p
```

For example, using only unstructured patchout of 400:

```bash
python ex_audioset.py with trainer.precision=16 models.net.arch=passt_deit_bd_p16_384 models.net.u_patchout=400 models.net.s_patchout_f=0 models.net.s_patchout_t=0 -p
```

Multi-gpu training can be enabled by setting the environment variable `DDP`, for example with 2 gpus:

```shell
DDP=2 python ex_audioset.py with trainer.precision=16 models.net.arch=passt_deit_bd_p16_384 -p -c "PaSST base 2 GPU"
```

## Examples with Pre-trained models

Please check the [releases page](https://github.com/kkoutini/PaSST/releases/) to download pre-trained models.
In general, you can get a pretrained model on Audioset using:

```python
from models.passt import get_model
model = get_model(arch="passt_s_swa_p16_128_ap476", pretrained=True, n_classes=527, in_channels=1,
                  fstride=10, tstride=10, input_fdim=128, input_tdim=998,
                  u_patchout=0, s_patchout_t=40, s_patchout_f=4)
```

This will automatically download PaSST pre-trained on Audioset with an mAP of `0.476`. The model was trained with `s_patchout_t=40, s_patchout_f=4`, but you can change these to better fit your task/computational needs.

There are several pretrained models available with different strides (overlap) and with/without using SWA: `passt_s_p16_s16_128_ap468`, `passt_s_swa_p16_s16_128_ap473`, `passt_s_swa_p16_s14_128_ap471`, `passt_s_p16_s14_128_ap469`, `passt_s_swa_p16_s12_128_ap473`, `passt_s_p16_s12_128_ap470`.
For example, in `passt_s_swa_p16_s16_128_ap473`: `p16` means the patch size is `16x16`, `s16` means no overlap (stride=16), `128` is the number of mel bands, and `ap473` refers to the performance of this model on Audioset, mAP=0.473.
In general, you can get a pretrained model using:

```python
from models.passt import get_model
passt = get_model(arch="passt_s_swa_p16_s16_128_ap473", fstride=16, tstride=16)
```

Using the framework, you can evaluate this model with:

```shell
python ex_audioset.py evaluate_only with trainer.precision=16 passt_s_swa_p16_s16_128_ap473 -p
```

Ensembles of these models are provided as well.
A large ensemble giving `mAP=.4956`:

```shell
python ex_audioset.py evaluate_only with trainer.precision=16 ensemble_many
```

An ensemble of 2 models with `stride=14` and `stride=16` giving `mAP=.4858`:

```shell
python ex_audioset.py evaluate_only with trainer.precision=16 ensemble_s16_14
```

Other ensembles, `ensemble_4` and `ensemble_5`, are available as well.

## Examples of fine-tuning on downstream datasets

1. [ESC-50: Dataset for Environmental Sound Classification](esc50/)
2. [OpenMIC-2018 dataset](openmic/)
3. [FSD50K](fsd50k/)

## Citation

The citation of the accepted paper at Interspeech 2022:

```bib
@inproceedings{koutini22passt,
  author       = {Khaled Koutini and Jan Schl{\"{u}}ter and Hamid Eghbal{-}zadeh and Gerhard Widmer},
  title        = {Efficient Training of Audio Transformers with Patchout},
  booktitle    = {Interspeech 2022, 23rd Annual Conference of the International Speech Communication Association, Incheon, Korea, 18-22 September 2022},
  pages        = {2753--2757},
  publisher    = {{ISCA}},
  year         = {2022},
  url          = {https://doi.org/10.21437/Interspeech.2022-227},
  doi          = {10.21437/Interspeech.2022-227},
}
```

## Contact

The repo will be updated; in the meantime, if you have any questions or problems, feel free to open an issue on GitHub, or contact the authors directly.

================================================
FILE: __init__.py
================================================


================================================
FILE: audioset/README.md
================================================
# Experiments on Audioset

Audioset has around 2M segments.
The total size of the dataset as wav files with a `32kHz` sampling rate is around 1.2 TB. In our setup, this results in a huge IO bottleneck that slows down the training process significantly. Therefore, we encode the dataset to mp3, pack the mp3s into HDF5 files, and decode the mp3s on the fly. If you have enough cpu cores (10-16 dataloading workers), you should not notice any slowdowns.

In the `dataset.py` file, we read the samples from the hdf files, decode the mp3, apply waveform augmentations, and return the raw waveform to the model. `AudioSetDataset` is the main class that reads from the hdf files.

## Preparing the dataset

### Downloading Audioset

We used the scripts provided by [PANNS](https://github.com/qiuqiangkong/audioset_tagging_cnn) to download the dataset.

### Converting to mp3

Once the datasets are downloaded, we convert all the files to mp3 using the script `prepare_scripts/convert_to_mp3.py`:

```bash
python convert_to_mp3.py --source pann_download_folder --out mp3_folder
```

This will significantly reduce the size of the dataset and overcome the IO bottleneck in our setup. The trade-off is that more cpu is needed during training to decode the mp3s. We use the [av](https://pypi.org/project/av/) library to decode the mp3s in the data loading workers (check `decode_mp3` in `dataset.py`); this is much faster than calling ffmpeg. As a result, approximately 10 decoding threads should be enough to keep a 2080ti busy. You can test how much time it takes to load and decode one epoch on your system:

```bash
python ex_audioset.py test_loaders_train_speed
```

This step is not necessary if you have a more powerful setup, and `decode_mp3` also supports other ffmpeg codecs.

### Packing into HDF5 files

Finally, you need to pack the mp3 files into a single HDF5 file using `create_h5pymp3_dataset.py`. You just need to set the paths in the script to match your local paths.
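The packing step can be sketched as follows. This is a minimal, hypothetical version of what such a packing script does; the dataset keys `audio_name`, `mp3`, and `target` follow the layout that `dataset.py` reads, but the `pack_mp3s` helper, the toy data, and the output file name are made up for this illustration:

```python
import numpy as np
import h5py

def pack_mp3s(mp3_blobs, names, targets, out_path):
    """Store raw mp3 bytes plus clip names and bit-packed label vectors in one HDF5 file."""
    with h5py.File(out_path, "w") as f:
        dt = h5py.vlen_dtype(np.uint8)  # variable-length byte arrays, one per mp3
        mp3_ds = f.create_dataset("mp3", (len(mp3_blobs),), dtype=dt)
        name_ds = f.create_dataset("audio_name", (len(names),), dtype="S64")
        # pack the 527-dim binary label vectors into bytes;
        # dataset.py recovers them with np.unpackbits(..., count=classes_num)
        packed = np.packbits(np.asarray(targets, dtype=np.uint8), axis=-1)
        f.create_dataset("target", data=packed)
        for i, (blob, name) in enumerate(zip(mp3_blobs, names)):
            mp3_ds[i] = np.frombuffer(blob, dtype=np.uint8)
            name_ds[i] = name.encode()

# toy example: two fake "mp3" byte blobs with 527-class binary targets
blobs = [b"\x00\x01", b"\x02\x03\x04"]
names = ["Y1.mp3", "Y2.mp3"]
targets = np.zeros((2, 527), dtype=np.uint8)
targets[0, 0] = 1
pack_mp3s(blobs, names, targets, "toy_mp3.hdf")

with h5py.File("toy_mp3.hdf", "r") as f:
    restored = np.unpackbits(f["target"][0], axis=-1, count=527)
    print(f["audio_name"][0].decode(), restored.sum())  # prints: Y1.mp3 1
```

Storing the compressed bytes (rather than decoded waveforms) is what keeps the HDF5 files small enough to avoid the IO bottleneck.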
The script goes through the csv files and check if the corresponding mp3 file exists, then it will store it in h5py file. The output of this step should be 3 files `balanced_train_segments_mp3.hdf`, `eval_segments_mp3.hdf` and `unbalanced_train_segments_mp3.hdf`. Each of these files. Make sure the paths match the default config in `dataset.py` ================================================ FILE: audioset/dataset.py ================================================ import io import os import pathlib import random import av import librosa import torchaudio from torch.utils.data import Dataset as TorchDataset, ConcatDataset, DistributedSampler, WeightedRandomSampler import torch from ba3l.ingredients.datasets import Dataset from sacred.config import DynamicIngredient, CMD from scipy.signal import convolve import numpy as np from helpers.audiodatasets import PreprocessDataset import h5py LMODE = os.environ.get("LMODE", False) #$TMPDIR dataset = Dataset('audiodataset') @dataset.config def default_config(): name = 'audioset' # dataset name normalize = False # normalize dataset subsample = False # subsample squares from the dataset roll = True # apply roll augmentation fold = 1 base_dir = "audioset_hdf5s/" # base directory of the dataset, change it or make a link if LMODE: base_dir = "/system/user/publicdata/CP/audioset/audioset_hdf5s/" balanced_train_hdf5 = base_dir + "mp3/balanced_train_segments_mp3.hdf" eval_hdf5 = base_dir + "mp3/eval_segments_mp3.hdf" unbalanced_train_hdf5 = base_dir + "mp3/unbalanced_train_segments_mp3.hdf" if LMODE: balanced_train_hdf5 = balanced_train_hdf5.replace(base_dir, os.environ.get("TMPDIR", base_dir)+"/") unbalanced_train_hdf5 = unbalanced_train_hdf5.replace(base_dir, os.environ.get("TMPDIR", base_dir)+"/") eval_hdf5 = eval_hdf5.replace(base_dir, os.environ.get("TMPDIR", base_dir)+"/") ir_path = base_dir + "irs/" num_of_classes = 527 if LMODE: @dataset.config def LMODE_default_config(): cache_root_path = 
"/system/user/publicdata/CP/DCASE/cached_datasets/" def decode_mp3(mp3_arr): """ decodes an array if uint8 representing an mp3 file :rtype: np.array """ container = av.open(io.BytesIO(mp3_arr.tobytes())) stream = next(s for s in container.streams if s.type == 'audio') # print(stream) a = [] for i, packet in enumerate(container.demux(stream)): for frame in packet.decode(): a.append(frame.to_ndarray().reshape(-1)) waveform = np.concatenate(a) if waveform.dtype != 'float32': raise RuntimeError("Unexpected wave type") return waveform def pad_or_truncate(x, audio_length): """Pad all audio to specific length.""" if len(x) <= audio_length: return np.concatenate((x, np.zeros(audio_length - len(x), dtype=np.float32)), axis=0) else: return x[0: audio_length] irs_arr = None @dataset.command def get_ir_sample(ir_path, _run, ir_augment, cut_irs_offset=None): if not ir_augment: return global irs_arr if irs_arr is None: all_paths = [path for path in pathlib.Path(os.path.expanduser(ir_path)).rglob('*.wav')] all_paths = sorted(all_paths) if cut_irs_offset is not None: all_paths = all_paths[cut_irs_offset:cut_irs_offset + 10] all_paths_name = [str(p).rsplit("/", 1)[-1] for p in all_paths] print("will use these IRs:") for i in range(len(all_paths_name)): print(i, ": ", all_paths_name[i]) _run.info["ir_devices"] = all_paths_name irs_arr = [librosa.load(p, sr=32000)[0] for p in all_paths] return irs_arr[int(np.random.randint(0, len(irs_arr)))] @dataset.command def pydub_augment(waveform, gain_augment=7, ir_augment=0): if ir_augment and torch.rand(1) < ir_augment: ir = get_ir_sample() waveform = convolve(waveform, ir, 'full') if gain_augment: gain = torch.randint(gain_augment * 2, (1,)).item() - gain_augment amp = 10 ** (gain / 20) waveform = waveform * amp return waveform class MixupDataset(TorchDataset): """ Mixing Up wave forms """ def __init__(self, dataset, beta=2, rate=0.5): self.beta = beta self.rate = rate self.dataset = dataset print(f"Mixing up waveforms from dataset of len 
{len(dataset)}") def __getitem__(self, index): if torch.rand(1) < self.rate: x1, f1, y1 = self.dataset[index] idx2 = torch.randint(len(self.dataset), (1,)).item() x2, f2, y2 = self.dataset[idx2] l = np.random.beta(self.beta, self.beta) l = max(l, 1. - l) x1 = x1-x1.mean() x2 = x2-x2.mean() x = (x1 * l + x2 * (1. - l)) x = x - x.mean() return x, f1, (y1 * l + y2 * (1. - l)) return self.dataset[index] def __len__(self): return len(self.dataset) class AudioSetDataset(TorchDataset): def __init__(self, hdf5_file, sample_rate=32000, classes_num=527, clip_length=10, augment=False, in_mem=False): """ Reads the mp3 bytes from the HDF file, decodes them using av and returns a fixed-length audio waveform """ self.sample_rate = sample_rate self.hdf5_file = hdf5_file if in_mem: print("\nPreloading in memory\n") with open(hdf5_file, 'rb') as f: self.hdf5_file = io.BytesIO(f.read()) with h5py.File(hdf5_file, 'r') as f: self.length = len(f['audio_name']) print(f"Dataset from {hdf5_file} with length {self.length}.") self.dataset_file = None # lazy init self.clip_length = clip_length * sample_rate self.classes_num = classes_num self.augment = augment if augment: print(f"Will augment data from {hdf5_file}") def open_hdf5(self): self.dataset_file = h5py.File(self.hdf5_file, 'r') def __len__(self): return self.length def __del__(self): if self.dataset_file is not None: self.dataset_file.close() self.dataset_file = None def __getitem__(self, index): """Load waveform and target of an audio clip. 
Args: index: int, the index of the clip in the HDF5 file Returns: (waveform (1, clip_samples), audio_name (str), target (classes_num,)) """ if self.dataset_file is None: self.open_hdf5() audio_name = self.dataset_file['audio_name'][index].decode() waveform = decode_mp3(self.dataset_file['mp3'][index]) if self.augment: waveform = pydub_augment(waveform) waveform = pad_or_truncate(waveform, self.clip_length) waveform = self.resample(waveform) target = self.dataset_file['target'][index] target = np.unpackbits(target, axis=-1, count=self.classes_num).astype(np.float32) return waveform.reshape(1, -1), audio_name, target def resample(self, waveform): """Resample. Args: waveform: (clip_samples,) Returns: (resampled_clip_samples,) """ if self.sample_rate == 32000: return waveform elif self.sample_rate == 16000: return waveform[0:: 2] elif self.sample_rate == 8000: return waveform[0:: 4] else: raise Exception('Incorrect sample rate!') @dataset.command def get_base_training_set(balanced_train_hdf5): ds = AudioSetDataset(balanced_train_hdf5, augment=True) return ds @dataset.command def get_unbalanced_training_set(unbalanced_train_hdf5): ds = AudioSetDataset(unbalanced_train_hdf5, augment=True) return ds @dataset.command def get_norms_dataset(unbalanced_train_hdf5, balanced_train_hdf5): ds = ConcatDataset( [AudioSetDataset(balanced_train_hdf5, augment=False), AudioSetDataset(unbalanced_train_hdf5, augment=False)]) return ds @dataset.command def get_base_full_training_set(): sets = [get_base_training_set(), get_unbalanced_training_set()] ds = ConcatDataset(sets) return ds @dataset.command def preload_mp3(balanced_train_hdf5, unbalanced_train_hdf5, num_of_classes): for hdf5_file in [balanced_train_hdf5, unbalanced_train_hdf5]: print(f"\n \n will now preload {hdf5_file} \n\n ") with h5py.File(hdf5_file, 'r') as dataset_file: target = dataset_file['mp3'][:] print(len(target)) print(f"\n \n done with {hdf5_file} \n\n ") return target[1000] @dataset.command def 
get_ft_cls_balanced_sample_weights(balanced_train_hdf5, unbalanced_train_hdf5, num_of_classes, sample_weight_offset=100, sample_weight_sum=True): """ :return: float tensor of shape len(full_training_set) representing the weights of each sample. """ # the order of balanced_train_hdf5,unbalanced_train_hdf5 is important. # should match get_full_training_set all_y = [] for hdf5_file in [balanced_train_hdf5, unbalanced_train_hdf5]: with h5py.File(hdf5_file, 'r') as dataset_file: target = dataset_file['target'] target = np.unpackbits(target, axis=-1, count=num_of_classes) all_y.append(target) all_y = np.concatenate(all_y, axis=0) all_y = torch.as_tensor(all_y) per_class = all_y.long().sum(0).float().reshape(1, -1) # frequencies per class per_class = sample_weight_offset + per_class # offset low freq classes if sample_weight_offset > 0: print(f"Warning: sample_weight_offset={sample_weight_offset} min now={per_class.min()}") per_class_weights = 1000. / per_class all_weight = all_y * per_class_weights # print(all_weight.shape) # print(all_weight[1510]) if sample_weight_sum: print("\nsample_weight_sum\n") all_weight = all_weight.sum(dim=1) else: all_weight, _ = all_weight.max(dim=1) # print(all_weight.shape) # print(all_weight[1510]) return all_weight @dataset.command def get_ft_weighted_sampler(samples_weights=CMD(".get_ft_cls_balanced_sample_weights"), epoch_len=100000, sampler_replace=False): num_nodes = int(os.environ.get('num_nodes', 1)) ddp = int(os.environ.get('DDP', 1)) num_nodes = max(ddp, num_nodes) print("num_nodes= ", num_nodes) rank = int(os.environ.get('NODE_RANK', 0)) return DistributedSamplerWrapper(sampler=WeightedRandomSampler(samples_weights, num_samples=epoch_len, replacement=sampler_replace), dataset=range(epoch_len), num_replicas=num_nodes, rank=rank, ) @dataset.command def get_base_test_set(eval_hdf5): ds = AudioSetDataset(eval_hdf5) return ds @dataset.command(prefix='roll_conf') def get_roll_func(axis=1, shift=None, shift_range=50): print("rolling...") 
def roll_func(b): x, i, y = b x = torch.as_tensor(x) sf = shift if shift is None: sf = int(np.random.randint(-shift_range, shift_range + 1)) # random_integers was removed from NumPy; randint's high bound is exclusive return x.roll(sf, axis), i, y return roll_func @dataset.command def get_training_set(normalize, roll, wavmix=False): ds = get_base_training_set() get_ir_sample() if normalize: print("normalized train!") fill_norms() ds = PreprocessDataset(ds, norm_func) if roll: ds = PreprocessDataset(ds, get_roll_func()) if wavmix: ds = MixupDataset(ds) return ds @dataset.command def get_full_training_set(normalize, roll, wavmix=False): ds = get_base_full_training_set() get_ir_sample() if normalize: print("normalized train!") fill_norms() ds = PreprocessDataset(ds, norm_func) if roll: ds = PreprocessDataset(ds, get_roll_func()) if wavmix: ds = MixupDataset(ds) return ds @dataset.command def get_test_set(normalize): ds = get_base_test_set() if normalize: print("normalized test!") fill_norms() ds = PreprocessDataset(ds, norm_func) return ds @dataset.command def print_conf(_config): print("Config of ", dataset.path, id(dataset)) print(_config) print() class DistributedSamplerWrapper(DistributedSampler): def __init__( self, sampler, dataset, num_replicas=None, rank=None, shuffle: bool = True): super(DistributedSamplerWrapper, self).__init__( dataset, num_replicas, rank, shuffle) # source: @awaelchli https://github.com/PyTorchLightning/pytorch-lightning/issues/3238 self.sampler = sampler def __iter__(self): if self.sampler.generator is None: self.sampler.generator = torch.Generator() self.sampler.generator.manual_seed(self.seed + self.epoch) indices = list(self.sampler) if self.epoch == 0: print(f"\n DistributedSamplerWrapper : {indices[:10]} \n\n") indices = indices[self.rank:self.total_size:self.num_replicas] return iter(indices) if __name__ == "__main__": from sacred import Experiment ex = Experiment("test_dataset", ingredients=[dataset]) @ex.automain def default_command(): 
ex.current_run.get_command_function("print_config")() get_base_training_set() ds = get_test_set() print(ds[0]) ds = get_training_set() print(ds[0]) print("get_base_training_set", len(get_base_training_set())) print("get_base_test_set", len(get_base_test_set())) print("get_training_set", len(get_training_set())) print("get_test_set", len(get_test_set())) ================================================ FILE: audioset/prepare_scripts/convert_to_mp3.py ================================================ import argparse import multiprocessing import glob import os # Replace this with the dataset downloaded using PANN scripts source_path = "/share/cp/datasets/full_audioset/audioset201906/audios/" # Replace with the output directory out_path = "./mp3_audio/" all_num = 0 def process_folder(fol="balanced_train_segments"): print("now working on ", fol) os.makedirs(out_path + fol, exist_ok=True) all_files = list(glob.glob(source_path + fol + "/*.wav")) print(f"it has {len(all_files)}") global all_num all_num = len(all_files) cmds = [(i, file, out_path + fol + "/" + os.path.basename(file)[:-3]) for i, file in enumerate(all_files)] print(cmds[0]) with multiprocessing.Pool(processes=20) as pool: pool.starmap(process_one, cmds) def process_one(i, f1, f2): if i % 100 == 0: print(f"{i}/{all_num} \t", f1) os.system(f"ffmpeg -hide_banner -nostats -loglevel error -n -i {f1} -codec:a mp3 -ar 32000 {f2}mp3") if __name__ == '__main__': parser = argparse.ArgumentParser() parser.add_argument('--source', type=str, required=False, default=None, help='Path of folder containing the wave files, expected format to be as downloaded ' 'using the PANNs script, containing balanced_train_segments, eval_segments, ' 'unbalanced_train_segments folders.') parser.add_argument('--out', type=str, required=False, default=None, help='Directory to save out the converted mp3s.') args = parser.parse_args() source_path = args.source or source_path out_path = args.out or out_path folders = 
['balanced_train_segments', 'eval_segments'] + ["unbalanced_train_segments/" + x for x in sorted( os.listdir(source_path + "unbalanced_train_segments/"))] print("I will work on these folders:") print(folders) for fol in folders: process_folder(fol) ================================================ FILE: audioset/prepare_scripts/create_h5pymp3_dataset.py ================================================ # %% import h5py import pandas as pd import numpy as np import csv import os # %% base_dir = "../../audioset_hdf5s/" balanced_csv= base_dir+ "metadata/balanced_train_segments.csv" eval_csv= base_dir+ "metadata/eval_segments.csv" mp3_path = "../../mp3_audio/" # %% def read_metadata(csv_path, classes_num, id_to_ix): """Read metadata of AudioSet from a csv file. source: https://github.com/qiuqiangkong/audioset_tagging_cnn/blob/d2f4b8c18eab44737fcc0de1248ae21eb43f6aa4/utils/utilities.py#L59 Args: csv_path: str Returns: meta_dict: {'audio_name': (audios_num,), 'target': (audios_num, classes_num)} """ with open(csv_path, 'r') as fr: lines = fr.readlines() lines = lines[3:] # Remove heads audios_num = len(lines) targets = np.zeros((audios_num, classes_num), dtype=bool) # np.bool was removed in NumPy 1.24; use the builtin bool audio_names = [] for n, line in enumerate(lines): items = line.split(', ') """items: ['--4gqARaEJE', '0.000', '10.000', '"/m/068hy,/m/07q6cd_,/m/0bt9lr,/m/0jbk"\n']""" audio_name = 'Y{}.mp3'.format(items[0]) # Audios are started with an extra 'Y' when downloading label_ids = items[3].split('"')[1].split(',') audio_names.append(audio_name) # Target for id in label_ids: ix = id_to_ix[id] targets[n, ix] = 1 meta_dict = {'audio_name': np.array(audio_names), 'target': targets} return meta_dict # Load label with open(base_dir+'metadata/class_labels_indices.csv', 'r') as f: reader = csv.reader(f, delimiter=',') lines = list(reader) labels = [] ids = [] # Each label has a unique id such as "/m/068hy" for i1 in range(1, len(lines)): id = lines[i1][1] label = lines[i1][2] ids.append(id) labels.append(label) 
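At load time, `audioset/dataset.py` recovers the multi-hot targets written below with `np.unpackbits(target, axis=-1, count=classes_num)`. A minimal sketch of the `packbits`/`unpackbits` round-trip used for the 527 AudioSet classes (the active class indices here are made up):

```python
import numpy as np

classes_num = 527  # AudioSet ontology size
# a toy multi-hot target: classes 0, 3 and 526 are active
target = np.zeros(classes_num, dtype=bool)
target[[0, 3, 526]] = True

packed = np.packbits(target, axis=-1)  # 66 uint8 bytes instead of 527 bools
restored = np.unpackbits(packed, axis=-1, count=classes_num).astype(np.float32)

assert packed.shape == (66,)  # ceil(527 / 8)
assert np.array_equal(restored.astype(bool), target)
```

Packing shrinks each target from 527 bytes to 66, which adds up when the targets of roughly two million clips are stored alongside the mp3 bytes in the HDF5 files.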
classes_num = len(labels) lb_to_ix = {label : i for i, label in enumerate(labels)} ix_to_lb = {i : label for i, label in enumerate(labels)} id_to_ix = {id : i for i, id in enumerate(ids)} ix_to_id = {i : id for i, id in enumerate(ids)} # %% def check_available(balanced_csv,balanced_audio_path,prefix=None): meta_csv = read_metadata(balanced_csv,classes_num,id_to_ix) audios_num = len(meta_csv['audio_name']) found=0 notfound=0 available_files=[] available_targets=[] if prefix is None: prefix = os.path.basename(balanced_csv)[:-4] for n in range(audios_num): audio_path = meta_csv['audio_name'][n] #print(balanced_audio_path + f"{prefix}/{audio_path}") if os.path.isfile(balanced_audio_path + f"{prefix}/{audio_path}" ): found+=1 available_files.append(meta_csv['audio_name'][n]) available_targets.append(meta_csv['target'][n]) else: notfound+=1 print(f"Found {found} . not found {notfound}") return available_files,available_targets # %% # %% # %% for read_file,prefix in [(balanced_csv,"balanced_train_segments/"), (eval_csv,"eval_segments/"),]: print("now working on ",read_file,prefix) #files, y = torch.load(read_file+".pth") files, y = check_available(read_file, mp3_path) y = np.packbits(y, axis=-1) packed_len = y.shape[1] print(files[0], "classes: ",packed_len, y.dtype) available_size = len(files) f = files[0][:-3]+"mp3" a = np.fromfile(mp3_path+prefix + "/"+f, dtype='uint8') dt = h5py.vlen_dtype(np.dtype('uint8')) save_file = prefix.split("/")[0] with h5py.File(base_dir+ "mp3/" + save_file+"_mp3.hdf", 'w') as hf: audio_name = hf.create_dataset('audio_name', shape=((available_size,)), dtype='S20') waveform = hf.create_dataset('mp3', shape=((available_size,)), dtype=dt) target = hf.create_dataset('target', shape=((available_size, packed_len)), dtype=y.dtype) for i,file in enumerate(files): if i%1000==0: print(f"{i}/{available_size}") f = file[:-3] + "mp3" a = np.fromfile(mp3_path + prefix + f, dtype='uint8') audio_name[i]=f waveform[i] = a target[i] = y[i] print(a.shape) 
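Two small details of the writer above are easy to miss: the `.wav` → `.mp3` extension swap is done by slicing off the last three characters, and `audio_name` is a fixed-width `S20` byte-string dataset, which silently truncates longer names. A numpy-only sketch of both (the file names here are made up):

```python
import numpy as np

# extension swap as done in the script: drop the last 3 chars ("wav"), append "mp3"
file = "Y--4gqARaEJE.wav"
f = file[:-3] + "mp3"
assert f == "Y--4gqARaEJE.mp3"

# 'S20' is a fixed-width byte string: longer names are silently truncated
names = np.array([f], dtype='S20')
assert names[0] == b'Y--4gqARaEJE.mp3'
long_name = "Y" + "x" * 30 + ".mp3"
truncated = np.array([long_name], dtype='S20')
assert len(truncated[0]) == 20  # only the first 20 bytes survive
```

The 'Y' + 11-character YouTube id + '.mp3' names used by AudioSet are 16 bytes, so they fit in `S20`; any longer naming scheme would be cut off without warning.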
print("Done!" , prefix) # %% print("working on unbalanced...") all_x,all_y = None, None for idx in range(41): print("working on ",idx) tmp_csv = base_dir+ f"metadata/unbalanced_partial_csvs/unbalanced_train_segments_part{idx:02}.csv" prefix = f"unbalanced_train_segments/unbalanced_train_segments_part{idx:02}" x,y = check_available(tmp_csv,mp3_path,prefix=prefix) x = np.array([f"{idx:02}/"+one for one in x]) y=np.packbits(y, axis=-1) print("x,y",x.shape, y.shape) if all_x is None: all_x = x all_y = y else: all_x = np.concatenate((all_x,x)) all_y = np.concatenate((all_y,y)) print(f"done {idx}! all x,y",all_x.shape, all_y.shape) print("now working on packing unbalanced") prefix = "unbalanced_train_segments/unbalanced_train_segments_part" files = all_x y = all_y packed_len = y.shape[1] print(files[0], "classes: ",packed_len, y.dtype) available_size = len(files) f = files[0][:-3]+"mp3" a = np.fromfile(mp3_path+prefix + f, dtype='uint8') dt = h5py.vlen_dtype(np.dtype('uint8')) save_file = prefix.split("/")[0] with h5py.File(base_dir+ "mp3/" + save_file+"_mp3.hdf", 'w') as hf: audio_name = hf.create_dataset('audio_name', shape=((available_size,)), dtype='S20') waveform = hf.create_dataset('mp3', shape=((available_size,)), dtype=dt) target = hf.create_dataset('target', shape=((available_size, packed_len)), dtype=y.dtype) for i,file in enumerate(files): if i%1000==0: print(f"{i}/{available_size}") f = file[:-3] + "mp3" a = np.fromfile(mp3_path + prefix + f, dtype='uint8') audio_name[i]=f waveform[i] = a target[i] = y[i] print(a.shape) print("Done!" , prefix) ================================================ FILE: ba3l/__init__.py ================================================ """Package info""" __version__ = "0.0.2" __author__ = "Koutini et al." __license__ = "Apache-2.0" __copyright__ = "Copyright (c) 2019-, %s." % __author__ __homepage__ = "https://github.com//" __docs__ = ( "The researcher friendly pytorch environment. Ba3l= sacred+pytorch-lightening." 
) __author_email__ = "first.last@jku.at" ================================================ FILE: ba3l/experiment.py ================================================ import inspect from importlib import import_module from ba3l.ingredients.datasets import Datasets from ba3l.ingredients.models import Models, Model #from ba3l.trainer import Trainer from ba3l.util.sacred_logger import SacredLogger from sacred import Experiment as Sacred_Experiment, Ingredient from typing import Sequence, Optional, List from sacred.commandline_options import CLIOption from sacred.config import CMD from sacred.host_info import HostInfoGetter from sacred.utils import PathType, optional_kwargs_decorator from pytorch_lightning import loggers as pl_loggers from ba3l.util.functions import get_default_kwargs_dict def ingredients_recursive_apply(ing, fn): fn(ing) for kid in ing.ingredients: ingredients_recursive_apply(kid, fn) def config_recursive_apply(conf, fn): for k,v in conf.items(): if isinstance(v, dict): config_recursive_apply(v,fn) else: fn(k,v) def get_loggers(use_tensorboard_logger=False, use_sacred_logger=False): loggers = [] if use_sacred_logger: loggers.append( SacredLogger(expr)) if use_tensorboard_logger: loggers.append(pl_loggers.TensorBoardLogger(sacred_logger.name)) return loggers class Experiment(Sacred_Experiment): """ Main Ba3l Experiment class overrides sacred experiments. """ def __init__( self, name: Optional[str] = None, ingredients: Sequence[Ingredient] = (), datasets: Optional[Ingredient] = None, models: Optional[Ingredient] = None, interactive: bool = False, base_dir: Optional[PathType] = None, additional_host_info: Optional[List[HostInfoGetter]] = None, additional_cli_options: Optional[Sequence[CLIOption]] = None, save_git_info: bool = True, ): """ Create a new experiment with the given name and optional ingredients. (from Sacred) Parameters ---------- name Optional name of this experiment, defaults to the filename. 
(Required in interactive mode) ingredients : list[sacred.Ingredient], optional A list of ingredients to be used with this experiment. interactive If set to True will allow the experiment to be run in interactive mode (e.g. IPython or Jupyter notebooks). However, this mode is discouraged since it won't allow storing the source-code or reliable reproduction of the runs. base_dir Optional full path to the base directory of this experiment. This will set the scope for automatic source file discovery. additional_host_info Optional dictionary containing as keys the names of the pieces of host info you want to collect, and as values the functions collecting those pieces of information. save_git_info: Optionally save the git commit hash and the git state (clean or dirty) for all source files. This requires the GitPython package. """ if models is None: models = Models.get_instance() self.models = models if datasets is None: datasets = Datasets.get_instance() self.datasets = datasets if ingredients is None: ingredients = [] ingredients = list(ingredients) + [models, datasets] caller_globals = inspect.stack()[1][0].f_globals self.last_default_configuration_position = 0 super().__init__( name=name, ingredients=ingredients, interactive=interactive, base_dir=base_dir, additional_host_info=additional_host_info, additional_cli_options=additional_cli_options, save_git_info=save_git_info, caller_globals=caller_globals ) def get_run_identifier(self): return str(self.current_run.db_identifier) \ + "_" + str(self.current_run._id) def get_dataloaders(self, filter={}): results = {} for ds in self.datasets.get_datasets(filter): results[ds.name] = ds.get_iterator() if len(results) == 1: for k, v in results.items(): return v return results def get_train_dataloaders(self): return self.get_dataloaders(dict(train=True)) def get_val_dataloaders(self): return self.get_dataloaders(dict(validate=True)) def _create_run( self, command_name=None, config_updates=None, named_configs=(), info=None, 
meta_info=None, options=None, dry_run=False, ): if self.current_run is not None: # @todo replace with logger print("Warning: multiple runs are not yet supported") run = super()._create_run( command_name, config_updates, named_configs, info, meta_info, options, dry_run=False, ) # self.current_run=run # def update_current_run(ing): # ing.current_run = run # # ingredients_recursive_apply(self, update_current_run) return run @optional_kwargs_decorator def command( self, function=None, prefix=None, unobserved=False, add_default_args_config=True, static_args={}, **extra_args ): """ Decorator to define a new Command. A command is a function whose parameters are filled automatically by sacred. The command can be given a prefix, to restrict its configuration space to a subtree. (see ``capture`` for more information) A command can be made unobserved (i.e. ignoring all observers) by passing the unobserved=True keyword argument. :param add_default_args_config: whether to add the default arguments of the function to the config automatically. :param function: the function to return a Dataset Object :param prefix: sacred configuration prefix :param unobserved: sacred unobserved :param static_args: static args to be passed to the function; these args need not be serializable and are not stored in the config :param extra_args: explicit arguments to be added to the config; you can use these to override the function default values, for example by wrapping a config value with CMD, then the parameter will be filled by executing the command specified by the CMD string value. 
CMD strings have a special context. :return: """ add_default_args_config = (not unobserved) and add_default_args_config if add_default_args_config: self.add_default_args_config(function, prefix, extra_args, static_args=static_args) captured_f = self.capture(function, prefix=prefix, static_args=static_args) captured_f.unobserved = unobserved self.commands[function.__name__] = captured_f return captured_f def add_default_args_config(self, function, prefix, extra_args={}, static_args={}): """ adds the default parameters of a function to the ingredient config at lowest priority! Default args config is meant to remove the need to declare all the configurations manually. :param f: the function """ # @todo get the doc of the params as well config_candidate = {**get_default_kwargs_dict(function), **extra_args} # remove "static_args" from config for k in static_args: config_candidate.pop(k, None) # respect the prefix for the added default parameters if prefix is not None: for pr in prefix.split('.')[::-1]: config_candidate={pr: config_candidate} self.configurations.insert(self.last_default_configuration_position, self._create_config_dict(config_candidate, None)) self.last_default_configuration_position += 1 ================================================ FILE: ba3l/ingredients/__init__.py ================================================ ================================================ FILE: ba3l/ingredients/datasets.py ================================================ import inspect import os from functools import partial from ba3l.util.functions import get_default_kwargs_dict from sacred.config import CMD from .ingredient import Ingredient from typing import Sequence, Optional, List from sacred.utils import PathType, optional_kwargs_decorator from munch import DefaultFactoryMunch, Munch def raise_(ex): raise ex class Dataset(Ingredient): """ The class that annotates a Dataset of a Ba3l experiment """ DATASET_STRING_PREFIX = "get_dataset" ITER_STRING_PREFIX = 
"get_iterator" def __init__( self, name: str, ingredients: Sequence[Ingredient] = (), interactive: bool = False, base_dir: Optional[PathType] = None, save_git_info: bool = True, ): """ Create a new experiment with the given name and optional ingredients. Parameters ---------- name Optional name of this experiment, defaults to the filename. (Required in interactive mode) ingredients : list[sacred.Ingredient], optional A list of ingredients to be used with this experiment. interactive If set to True will allow the experiment to be run in interactive mode (e.g. IPython or Jupyter notebooks). However, this mode is discouraged since it won't allow storing the source-code or reliable reproduction of the runs. base_dir Optional full path to the base directory of this experiment. This will set the scope for automatic source file discovery. additional_host_info Optional dictionary containing as keys the names of the pieces of host info you want to collect, and as values the functions collecting those pieces of information. save_git_info: Optionally save the git commit hash and the git state (clean or dirty) for all source files. This requires the GitPython package. """ caller_globals = inspect.stack()[1][0].f_globals if name is None: name = "dataset" self.name = name.rsplit(".", 1)[-1] super().__init__( path=name, ingredients=ingredients, interactive=interactive, base_dir=base_dir, _caller_globals=caller_globals, save_git_info=save_git_info, ) self.get_dataset_command = None self.get_dataset_iterator_command = None self.current_run = None self.get_dataset = lambda: raise_( NotImplementedError( "Use dataset.dataset_name.dataset to annotate the " "get_dataset function!." ) ) self.get_iter = lambda: raise_( NotImplementedError( "Use dataset.dataset_name.iter to annotate the " "get_iter function!." ) ) @optional_kwargs_decorator def dataset( self, function=None, prefix=None, unobserved=False, static_args={}, **extra_args ): """ Decorator to define a new Dataset. 
The name of the dataset is used to get an instance of the dataset; it will register a command. Datasets are sacred commands. The command can be given a prefix, to restrict its configuration space to a subtree. (see ``capture`` for more information) A command can be made unobserved (i.e. ignoring all observers) by passing the unobserved=True keyword argument. :param function: the function to return a Dataset Object :param prefix: sacred configuration prefix :param unobserved: sacred unobserved :param static_args: static args to be passed to the function; these args need not be serializable and are not stored in the config :param extra_args: explicit arguments to be added to the config; you can use these to override the function default values, for example by wrapping a config value with CMD, then the parameter will be filled by executing the command specified by the CMD string value. CMD strings have a special context. :return: """ self.add_default_args_config(function, prefix, extra_args, static_args=static_args) captured_f = self.capture(function, prefix=prefix, static_args=static_args) captured_f.unobserved = unobserved self.commands[Dataset.DATASET_STRING_PREFIX] = captured_f self.get_dataset = captured_f self.add_config(dataset=CMD("get_dataset")) return captured_f @optional_kwargs_decorator def iter( self, function=None, prefix=None, unobserved=False, static_args={}, **extra_args ): """ Decorator to define a new Iterator. The name of the iterator is used to get an instance of the iterator; it will register a command. Iterators are sacred commands. The command can be given a prefix, to restrict its configuration space to a subtree. (see ``capture`` for more information) A command can be made unobserved (i.e. ignoring all observers) by passing the unobserved=True keyword argument. 
:param function: the function to return a Dataset Object :param prefix: sacred configuration prefix :param unobserved: sacred unobserved :param static_args: static args to be passed to the function; these args need not be serializable and are not stored in the config :param extra_args: explicit arguments to be added to the config; you can use these to override the function default values, for example by wrapping a config value with CMD, then the parameter will be filled by executing the command specified by the CMD string value. CMD strings have a special context. """ self.add_default_args_config(function, prefix, extra_args, static_args=static_args) captured_f = self.capture(function, prefix=prefix, static_args=static_args) captured_f.unobserved = unobserved self.commands[Dataset.ITER_STRING_PREFIX] = captured_f self.get_iter = captured_f return captured_f # def get_dataset(self): # assert self.current_run is not None, "Can only be called during a run." # return self.commands[Datasets.DATASET_STRING_PREFIX + name]() # # return self.current_run.get_command_function( # # self.path + "." + Datasets.DATASET_STRING_PREFIX + name)() # # def __getattr__(self, k): if k == "iterator": return self.__getattribute__("iter") if k == "get_iterator": return self.__getattribute__("get_iter") return super().__getattribute__(k) # @todo maybe run commands from here after running class Datasets(Ingredient, Munch): """ The class that encapsulates all the datasets in an experiment """ __instance = None @classmethod def get_instance(cls): if Datasets.__instance is None: Datasets.__instance = Datasets() return Datasets.__instance def __init__( self, name: Optional[str] = None, ingredients: Sequence[Ingredient] = (), interactive: bool = False, base_dir: Optional[PathType] = None, save_git_info: bool = True, ): """ Create a new experiment with the given name and optional ingredients. Parameters ---------- name Optional name of this experiment, defaults to the filename. 
(Required in interactive mode) ingredients : list[sacred.Ingredient], optional A list of ingredients to be used with this experiment. interactive If set to True will allow the experiment to be run in interactive mode (e.g. IPython or Jupyter notebooks). However, this mode is discouraged since it won't allow storing the source-code or reliable reproduction of the runs. base_dir Optional full path to the base directory of this experiment. This will set the scope for automatic source file discovery. additional_host_info Optional dictionary containing as keys the names of the pieces of host info you want to collect, and as values the functions collecting those pieces of information. save_git_info: Optionally save the git commit hash and the git state (clean or dirty) for all source files. This requires the GitPython package. """ caller_globals = inspect.stack()[1][0].f_globals if name is None: name = "datasets" Ingredient.__init__( self, path=name, ingredients=ingredients, interactive=interactive, base_dir=base_dir, _caller_globals=caller_globals, save_git_info=save_git_info, ) self.get_datasets_list_command = None self.get_dataset_command = None self.get_dataset_iterator_command = None self.current_run = None self.get_dataset = None # self.command(get_dataset_iterator_command, unobserved=True) def __getattr__(self, k): """ Gets key if it exists, otherwise returns the default value.""" try: return Munch.__getattr__(self, k) except AttributeError: return self.__getitem__(k) def __setattr__(self, k, v): try: # Throws exception if not in prototype chain object.__getattribute__(self, k) except AttributeError: try: self[k] = v except: raise AttributeError(k) else: object.__setattr__(self, k, v) def __getitem__(self, k): """ Gets key if it exists, otherwise returns the default value.""" try: return Munch.__getitem__(self, k) except KeyError: self[k] = Dataset( self.path + "." 
+ k, base_dir=self.base_dir, save_git_info=self.save_git_info, ) assert self self.ingredients.append(self[k]) return self[k] def __hash__(self): return Ingredient.__hash__(self) def get_datasets(self, config_conditions={}, return_datasets_names=False): """ Return all the datasets whose configuration matches config_conditions.""" results = [] for dataset in self.ingredients: all_ok = True for cond_k, cond_v in config_conditions.items(): if ( self.current_run.get_config_path_value(dataset.path + "." + cond_k) != cond_v ): all_ok = False break if all_ok: if return_datasets_names: results.append((dataset)) results.append(dataset) return results ================================================ FILE: ba3l/ingredients/ingredient.py ================================================ import inspect import os from functools import partial from ba3l.util.functions import get_default_kwargs_dict from sacred import Ingredient as sacred_Ingredient from typing import Sequence, Optional, List from sacred.utils import PathType, optional_kwargs_decorator from munch import DefaultFactoryMunch, Munch def raise_(ex): raise ex class Ingredient(sacred_Ingredient): """ The class that annotates a Dateset of Ba3l experiment a Dataset can be """ def __init__( self, path: str, ingredients: Sequence[sacred_Ingredient] = (), interactive: bool = False, _caller_globals: Optional[dict] = None, base_dir: Optional[PathType] = None, save_git_info: bool = True, ): """ The Base Ingredient of all Ba3l ingredients Parameters ---------- path Optional name of this experiment, defaults to the filename. (Required in interactive mode) ingredients : list[sacred.Ingredient], optional A list of ingredients to be used with this experiment. interactive If set to True will allow the experiment to be run in interactive mode (e.g. IPython or Jupyter notebooks). However, this mode is discouraged since it won't allow storing the source-code or reliable reproduction of the runs. 
        base_dir
            Optional full path to the base directory of this experiment. This
            sets the scope for automatic source-file discovery.
        additional_host_info
            Optional dictionary mapping the names of the pieces of host info
            you want to collect to the functions collecting them.
        save_git_info:
            Optionally save the git commit hash and the git state (clean or
            dirty) for all source files. This requires the GitPython package.
        """
        _caller_globals = _caller_globals or inspect.stack()[1][0].f_globals
        if path is None:
            path = "Ingredient"
        super().__init__(
            path=path,
            ingredients=ingredients,
            interactive=interactive,
            base_dir=base_dir,
            _caller_globals=_caller_globals,
            save_git_info=save_git_info,
        )
        self.current_run = None
        self.last_default_configuration_position = 0

    def add_default_args_config(self, function, prefix, extra_args={}, static_args={}):
        """Add the default parameters of a function to the ingredient config
        at the lowest priority.

        The default-args config removes the need to declare all the
        configuration values manually.

        :param function: the function whose default kwargs are added
        """
        # @todo get the doc of the params as well
        config_candidate = {**get_default_kwargs_dict(function), **extra_args}
        # remove "static_args" from the config
        for k in static_args:
            config_candidate.pop(k, None)
        if prefix is not None:
            for pr in prefix.split('.')[::-1]:
                config_candidate = {pr: config_candidate}
        self.configurations.insert(self.last_default_configuration_position,
                                   self._create_config_dict(config_candidate, None))
        self.last_default_configuration_position += 1

    @optional_kwargs_decorator
    def command(
        self, function=None, prefix=None, unobserved=False,
        add_default_args_config=True, static_args={}, **extra_args
    ):
        """Decorator to define a new command.

        A command is a function whose parameters are filled in automatically
        by sacred. The command can be given a prefix to restrict its
        configuration space to a subtree (see ``capture`` for more
        information). A command can be made unobserved (i.e. ignoring all
        observers) by passing the unobserved=True keyword argument.

        :param function: the function to be registered as a command
        :param prefix: sacred configuration prefix
        :param unobserved: sacred unobserved
        :param static_args: static args to be passed to the function; these
            args need not be serializable and are not stored in the config
        :param extra_args: explicit arguments to be added to the config. You
            can use these to override the function's default values, for
            example by wrapping a value with CMD; the parameter will then be
            filled by executing the command named by the CMD string value.
            CMD strings have a special context.
        :return: the captured function
        """
        if add_default_args_config:
            self.add_default_args_config(function, prefix, extra_args, static_args=static_args)
        captured_f = self.capture(function, prefix=prefix, static_args=static_args)
        captured_f.unobserved = unobserved
        self.commands[function.__name__] = captured_f
        return captured_f

================================================
FILE: ba3l/ingredients/models.py
================================================
import inspect
import os
from functools import partial

from ba3l.util.functions import get_default_kwargs_dict
from sacred.config import CMD
from .ingredient import Ingredient
from typing import Sequence, Optional, List
from sacred.utils import PathType, optional_kwargs_decorator
from munch import DefaultFactoryMunch, Munch


class Models(Ingredient):
    """The class that annotates the models of a Ba3l experiment."""

    __instance = None

    @classmethod
    def get_instance(cls):
        if Models.__instance is None:
            Models.__instance = Models()
        return Models.__instance

    def __init__(
        self,
        name: Optional[str] = None,
        ingredients: Sequence[Ingredient] = (),
        interactive: bool = False,
        base_dir: Optional[PathType] = None,
        save_git_info: bool = True,
    ):
        """Create a new experiment with the given name and optional ingredients.
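The `command` decorator above seeds the ingredient config with the decorated function's default keyword arguments and nests them under the dotted `prefix`, exactly as `add_default_args_config` does. A rough stdlib sketch of those two steps (the helper names `default_kwargs`, `nest_under_prefix`, and the sample `get_iter` function are illustrative, not ba3l's API):

```python
import inspect


def default_kwargs(function):
    """Collect the parameters of `function` that carry default values."""
    sig = inspect.signature(function)
    return {
        p.name: p.default
        for p in sig.parameters.values()
        if p.default is not inspect.Parameter.empty
    }


def nest_under_prefix(config, prefix):
    """Wrap `config` in nested dicts, one level per dotted prefix segment."""
    for segment in prefix.split(".")[::-1]:
        config = {segment: config}
    return config


def get_iter(batch_size=32, shuffle=True, num_workers=8):
    ...


# defaults become config candidates, scoped under e.g. "datasets.training"
candidate = nest_under_prefix(default_kwargs(get_iter), "datasets.training")
```

Because these candidates are inserted at the lowest priority, any value the user sets in the sacred config (or on the command line) overrides the function default, while untouched parameters keep the values declared in the function signature.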
        Parameters
        ----------
        name
            Optional name of this experiment, defaults to the filename.
            (Required in interactive mode)
        ingredients : list[sacred.Ingredient], optional
            A list of ingredients to be used with this experiment.
        interactive
            If set to True, allows the experiment to be run in interactive
            mode (e.g. IPython or Jupyter notebooks). However, this mode is
            discouraged since it won't allow storing the source code or
            reliably reproducing the runs.
        base_dir
            Optional full path to the base directory of this experiment. This
            sets the scope for automatic source-file discovery.
        additional_host_info
            Optional dictionary mapping the names of the pieces of host info
            you want to collect to the functions collecting them.
        save_git_info:
            Optionally save the git commit hash and the git state (clean or
            dirty) for all source files. This requires the GitPython package.
        """
        caller_globals = inspect.stack()[1][0].f_globals
        if name is None:
            name = "models"
        super().__init__(
            path=name,
            ingredients=ingredients,
            interactive=interactive,
            base_dir=base_dir,
            _caller_globals=caller_globals,
            save_git_info=save_git_info,
        )
        self.get_models_command = None
        self.current_run = None
        # self.command(print_config, unobserved=True)


def raise_(ex):
    raise ex


class Model(Ingredient):
    """The class that annotates a single model of a Ba3l experiment."""

    MODEL_STRING_PREFIX = "get_instance"

    def __init__(
        self,
        name: str,
        ingredients: Sequence[Ingredient] = (),
        interactive: bool = False,
        base_dir: Optional[PathType] = None,
        save_git_info: bool = True,
    ):
        """Create a new experiment with the given name and optional ingredients.

        Parameters
        ----------
        name
            Optional name of this experiment, defaults to the filename.
            (Required in interactive mode)
        ingredients : list[sacred.Ingredient], optional
            A list of ingredients to be used with this experiment.
        interactive
            If set to True, allows the experiment to be run in interactive
            mode (e.g. IPython or Jupyter notebooks). However, this mode is
            discouraged since it won't allow storing the source code or
            reliably reproducing the runs.
        base_dir
            Optional full path to the base directory of this experiment. This
            sets the scope for automatic source-file discovery.
        additional_host_info
            Optional dictionary mapping the names of the pieces of host info
            you want to collect to the functions collecting them.
        save_git_info:
            Optionally save the git commit hash and the git state (clean or
            dirty) for all source files. This requires the GitPython package.
        """
        caller_globals = inspect.stack()[1][0].f_globals
        if name is None:
            name = "model"
        self.name = name.rsplit(".", 1)[-1]
        super().__init__(
            path=name,
            ingredients=ingredients,
            interactive=interactive,
            base_dir=base_dir,
            _caller_globals=caller_globals,
            save_git_info=save_git_info,
        )
        self.get_instance_command = None
        self.current_run = None
        self.get_instance = lambda: raise_(
            NotImplementedError(
                "Use model.<model_name>.instance to annotate the "
                "get_instance function!"
            )
        )

    @optional_kwargs_decorator
    def instance(
        self, function=None, prefix=None, unobserved=False, static_args={}, **extra_args
    ):
        """Decorator to define a new model.

        The name of the model is used to get an instance of the model; this
        registers a command. The command can be given a prefix to restrict
        its configuration space to a subtree (see ``capture`` for more
        information). A command can be made unobserved (i.e. ignoring all
        observers) by passing the unobserved=True keyword argument.

        :param function: the function that returns the model instance
        :param prefix: sacred configuration prefix
        :param unobserved: sacred unobserved
        :param static_args: static args to be passed to the function; these
            args need not be serializable and are not stored in the config
        :param extra_args: explicit arguments to be added to the config. You
            can use these to override the function's default values, for
            example by wrapping a value with CMD; the parameter will then be
            filled by executing the command named by the CMD string value.
            CMD strings have a special context.
        :return: the captured function
        """
        self.add_default_args_config(function, prefix, extra_args, static_args=static_args)
        captured_f = self.capture(function, prefix=prefix, static_args=static_args)
        captured_f.unobserved = unobserved
        self.commands[Model.MODEL_STRING_PREFIX] = captured_f
        self.get_instance = captured_f
        self.add_config(get_instance=CMD("get_instance"))
        return captured_f

    def __getattr__(self, k):
        if k == "get_instance":
            return self.__getattribute__("get_instance")
        return super().__getattribute__(k)

    # @todo maybe run commands from here after running

================================================
FILE: ba3l/ingredients/trainer.py
================================================
import inspect
import os

from ba3l.util.functions import get_default_kwargs_dict
from .datasets import Datasets, raise_
from .models import Models
from sacred import Ingredient
from typing import Sequence, Optional, List
from sacred.utils import PathType, optional_kwargs_decorator


class Trainer(Ingredient):
    """The class that annotates the main Trainer of a Ba3l experiment."""

    TRAINER_STRING_PREFIX = "get_trainer"

================================================
FILE: ba3l/module.py
================================================
import warnings
from abc import ABC

import pytorch_lightning as pl
import torch.distributed as dist
from munch import DefaultMunch

try:  # loading for pyTorch 1.2+
    from torch.utils.data import IterableDataset
except ImportError:  # loading for pyTorch 1.1
    import torch
    warnings.warn(
        "Your version of pyTorch %s does not support `IterableDataset`,"
        " please upgrade to 1.2+" % torch.__version__,
        ImportWarning,
    )
    EXIST_ITER_DATASET = False
else:
    EXIST_ITER_DATASET = True

try:
    from apex import amp
    APEX_AVAILABLE = True
except ImportError:
    APEX_AVAILABLE = False


class Ba3lModule(pl.LightningModule):
    def __init__(self, experiment):
        super(Ba3lModule, self).__init__()
        self.experiment = experiment
        self.run = experiment.current_run
        self.config = DefaultMunch.fromDict(experiment.current_run.config)
        for key, model in experiment.current_run.config['models'].items():
            setattr(self, key, experiment.current_run.get_command_function(
                "models." + key + "." + model['instance_cmd'])())
        self.save_hyperparameters(self.config)

================================================
FILE: ba3l/plutils/__init__.py
================================================

================================================
FILE: ba3l/plutils/lr_monitor.py
================================================
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
r"""
Learning Rate Monitor
=====================

Monitors and logs the learning rate of lr schedulers during training.
"""
from typing import Dict, List, Optional

from pytorch_lightning.callbacks.base import Callback
from pytorch_lightning.utilities import rank_zero_warn
from pytorch_lightning.utilities.exceptions import MisconfigurationException


class LearningRateMonitor(Callback):
    r"""
    Automatically monitors and logs the learning rate of learning rate
    schedulers during training.

    Args:
        logging_interval: set to ``'epoch'`` or ``'step'`` to log ``lr`` of
            all optimizers at the same interval, set to ``None`` to log at
            individual intervals according to the ``interval`` key of each
            scheduler. Defaults to ``None``.
        log_momentum: option to also log the momentum values of the
            optimizer, if the optimizer has the ``momentum`` or ``betas``
            attribute. Defaults to ``False``.

    Raises:
        MisconfigurationException:
            If ``logging_interval`` is none of ``"step"``, ``"epoch"``, or ``None``.

    Example::

        >>> from pytorch_lightning import Trainer
        >>> from pytorch_lightning.callbacks import LearningRateMonitor
        >>> lr_monitor = LearningRateMonitor(logging_interval='step')
        >>> trainer = Trainer(callbacks=[lr_monitor])

    Logging names are automatically determined based on the optimizer class
    name. In case of multiple optimizers of the same type, they will be named
    ``Adam``, ``Adam-1`` etc. If an optimizer has multiple parameter groups,
    they will be named ``Adam/pg1``, ``Adam/pg2`` etc. To control naming,
    pass a ``name`` keyword in the construction of the learning rate
    schedulers.

    Example::

        def configure_optimizer(self):
            optimizer = torch.optim.Adam(...)
            lr_scheduler = {
                'scheduler': torch.optim.lr_scheduler.LambdaLR(optimizer, ...),
                'name': 'my_logging_name'
            }
            return [optimizer], [lr_scheduler]

    """

    def __init__(self, logging_interval: Optional[str] = None, log_momentum: bool = False):
        if logging_interval not in (None, 'step', 'epoch'):
            raise MisconfigurationException('logging_interval should be `step` or `epoch` or `None`.')

        self.logging_interval = logging_interval
        self.log_momentum = log_momentum
        self.lrs = None
        self.lr_sch_names = []

    def on_train_start(self, trainer, *args, **kwargs):
        """
        Called before training; determines unique names for all lr schedulers
        in the case of multiple schedulers of the same type or multiple
        parameter groups.

        Raises:
            MisconfigurationException:
                If ``Trainer`` has no ``logger``.
        """
        if not trainer.logger:
            raise MisconfigurationException(
                'Cannot use `LearningRateMonitor` callback with `Trainer` that has no logger.'
            )

        if not trainer.lr_schedulers:
            rank_zero_warn(
                'You are using `LearningRateMonitor` callback with models that'
                ' have no learning rate schedulers. Please see documentation'
                ' for `configure_optimizers` method.', RuntimeWarning
            )

        if self.log_momentum:
            def _check_no_key(key):
                return any(
                    key not in sch['scheduler'].optimizer.defaults
                    for sch in trainer.lr_schedulers
                )

            if _check_no_key('momentum') and _check_no_key('betas'):
                rank_zero_warn(
                    "You have set log_momentum=True, but some optimizers do not"
                    " have momentum. This will log a value of 0 for the momentum.",
                    RuntimeWarning
                )

        # Find names for schedulers
        names = self._find_names(trainer.lr_schedulers)

        # Initialize for storing values
        self.lrs = {name: [] for name in names}
        self.last_momentum_values = {name + "-momentum": None for name in names}

    def on_train_batch_start(self, trainer, *args, **kwargs):
        if not self._should_log(trainer):
            return

        if self.logging_interval != 'epoch':
            interval = 'step' if self.logging_interval is None else 'any'
            latest_stat = self._extract_stats(trainer, interval)

            if latest_stat:
                trainer.logger.log_metrics(latest_stat, step=trainer.global_step)

    def on_train_epoch_start(self, trainer, *args, **kwargs):
        if self.logging_interval != 'step':
            interval = 'epoch' if self.logging_interval is None else 'any'
            latest_stat = self._extract_stats(trainer, interval)

            if latest_stat:
                trainer.logger.log_metrics(latest_stat, step=trainer.current_epoch)

    def _extract_stats(self, trainer, interval: str) -> Dict[str, float]:
        latest_stat = {}

        for name, scheduler in zip(self.lr_sch_names, trainer.lr_schedulers):
            if scheduler['interval'] == interval or interval == 'any':
                opt = scheduler['scheduler'].optimizer
                param_groups = opt.param_groups
                use_betas = 'betas' in opt.defaults

                for i, pg in enumerate(param_groups):
                    suffix = f'/pg{i + 1}' if len(param_groups) > 1 else ''
                    lr = self._extract_lr(param_group=pg, name=f'{name}{suffix}')
                    latest_stat.update(lr)
                    momentum = self._extract_momentum(
                        param_group=pg, name=f'{name}-momentum{suffix}', use_betas=use_betas
                    )
                    latest_stat.update(momentum)

        return latest_stat

    def _extract_lr(self, param_group, name: str) -> Dict[str, float]:
        lr = param_group.get('lr')
        self.lrs[name].append(lr)
        return {name: lr}

    def _extract_momentum(self, param_group, name: str, use_betas: bool) -> Dict[str, float]:
        if not self.log_momentum:
            return {}

        momentum = param_group.get('betas')[0] if use_betas else param_group.get('momentum', 0)
        self.last_momentum_values[name] = momentum
        return {name: momentum}

    def _find_names(self, lr_schedulers) -> List[str]:
        # Create unique names in the case we have multiple of the same
        # learning rate scheduler + multiple parameter groups
        names = []
        for scheduler in lr_schedulers:
            sch = scheduler['scheduler']
            if scheduler['name'] is not None:
                name = scheduler['name']
            else:
                opt_name = 'lr-' + sch.optimizer.__class__.__name__
                i, name = 1, opt_name

                # Multiple schedulers of the same type
                while True:
                    if name not in names:
                        break
                    i, name = i + 1, f'{opt_name}-{i}'

            # Multiple param groups for the same scheduler
            param_groups = sch.optimizer.param_groups

            if len(param_groups) != 1:
                for i, pg in enumerate(param_groups):
                    temp = f'{name}/pg{i + 1}'
                    names.append(temp)
            else:
                names.append(name)

            self.lr_sch_names.append(name)

        return names

    @staticmethod
    def _should_log(trainer) -> bool:
        should_log = ((trainer.global_step + 1) % trainer.log_every_n_steps == 0
                      or trainer.should_stop)
        return should_log

================================================
FILE: ba3l/plutils/progress_bar.py
================================================
import importlib
import io
import os
import sys

# check if ipywidgets is installed before importing tqdm.auto
# to ensure it won't fail and a progress bar is displayed
from typing import Optional, Union

from pytorch_lightning.callbacks import Callback, ProgressBarBase, ProgressBar as PlProgressBar
from pytorch_lightning.callbacks.progress import tqdm


class ProgressBar(PlProgressBar):
    r"""
    This is the default progress bar used by Lightning. It prints to `stdout`
    using the :mod:`tqdm` package and shows up to four different bars:

    - **sanity check progress:** the progress during the sanity check run
    - **main progress:** shows training + validation progress combined. It
      also accounts for multiple validation runs during training when
      :paramref:`~pytorch_lightning.trainer.trainer.Trainer.val_check_interval` is used.
    - **validation progress:** only visible during validation; shows total
      progress over all validation datasets.
    - **test progress:** only active when testing; shows total progress over
      all test datasets.

    For infinite datasets, the progress bar never ends.

    If you want to customize the default ``tqdm`` progress bars used by
    Lightning, you can override specific methods of the callback class and
    pass your custom implementation to the
    :class:`~pytorch_lightning.trainer.trainer.Trainer`:

    Example::

        class LitProgressBar(ProgressBar):

            def init_validation_tqdm(self):
                bar = super().init_validation_tqdm()
                bar.set_description('running validation ...')
                return bar

        bar = LitProgressBar()
        trainer = Trainer(callbacks=[bar])

    Args:
        refresh_rate: Determines at which rate (in number of batches) the
            progress bars get updated. Set it to ``0`` to disable the display.
            By default, the :class:`~pytorch_lightning.trainer.trainer.Trainer`
            uses this implementation of the progress bar and sets the refresh
            rate to the value provided to the
            :paramref:`~pytorch_lightning.trainer.trainer.Trainer.progress_bar_refresh_rate`
            argument in the :class:`~pytorch_lightning.trainer.trainer.Trainer`.
        process_position: Set this to a value greater than ``0`` to offset the
            progress bars by this many lines. This is useful when you have
            progress bars defined elsewhere and want to show all of them
            together. This corresponds to
            :paramref:`~pytorch_lightning.trainer.trainer.Trainer.process_position`
            in the :class:`~pytorch_lightning.trainer.trainer.Trainer`.
    """

    def init_validation_tqdm(self) -> tqdm:
        """Override this to customize the tqdm bar for validation."""
        # The main progress bar doesn't exist in `trainer.validate()`
        has_main_bar = self.main_progress_bar is not None
        has_main_bar = False  # note: overrides the check above
        bar = tqdm(
            desc='Validating',
            position=(2 * self.process_position + has_main_bar),
            disable=self.is_disabled,
            leave=False,
            dynamic_ncols=True,
            file=sys.stdout
        )
        return bar

    def on_epoch_start(self, trainer, pl_module):
        ProgressBarBase.on_epoch_start(self, trainer, pl_module)
        self.main_progress_bar = self.init_train_tqdm()
        total_train_batches = self.total_train_batches
        total_val_batches = self.total_val_batches
        if total_train_batches != float('inf'):
            # val can be checked multiple times per epoch
            val_checks_per_epoch = total_train_batches // trainer.val_check_batch
            total_val_batches = total_val_batches * val_checks_per_epoch
        total_val_batches = 0  # note: overrides the computation above
        total_batches = total_train_batches + total_val_batches
        reset(self.main_progress_bar, total_batches)
        self.main_progress_bar.set_description(f'Epoch {trainer.current_epoch}')

    def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx):
        ProgressBarBase.on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx)
        if self._should_update(self.train_batch_idx, self.total_train_batches):
            self._update_bar(self.main_progress_bar)
            self.main_progress_bar.set_postfix(trainer.progress_bar_dict)

    def on_validation_start(self, trainer, pl_module):
        ProgressBarBase.on_validation_start(self, trainer, pl_module)
        if trainer.sanity_checking:
            reset(self.val_progress_bar, sum(trainer.num_sanity_val_batches))
        else:
            if self.main_progress_bar is not None:
                self.main_progress_bar.close()
            self.val_progress_bar = self.init_validation_tqdm()
            reset(self.val_progress_bar, self.total_val_batches)

    def on_validation_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx):
        ProgressBarBase.on_validation_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx)
        if self._should_update(self.val_batch_idx, self.total_val_batches):
            self._update_bar(self.val_progress_bar)
            # self._update_bar(self.main_progress_bar)

    def on_validation_end(self, trainer, pl_module):
        ProgressBarBase.on_validation_end(self, trainer, pl_module)
        if self.main_progress_bar is not None:
            self.main_progress_bar.set_postfix(trainer.progress_bar_dict)
        self.val_progress_bar.close()


def convert_inf(x: Optional[Union[int, float]]) -> Optional[Union[int, float]]:
    """tqdm doesn't support inf values; convert them to None."""
    if x == float('inf'):
        return None
    return x


def reset(bar: tqdm, total: Optional[int] = None) -> None:
    """Resets the tqdm bar to 0 progress with a new total, unless it is disabled."""
    if not bar.disable:
        bar.reset(total=convert_inf(total))

================================================
FILE: ba3l/util/__init__.py
================================================

================================================
FILE: ba3l/util/functions.py
================================================
import inspect
from collections import OrderedDict


def get_default_kwargs_dict(f):
    sig = inspect.signature(f)
    return OrderedDict(
        [
            (p.name, p.default)
            for p in sig.parameters.values()
            if p.default != inspect._empty
        ]
    )

================================================
FILE: ba3l/util/sacred_logger.py
================================================
from logging import getLogger
from warnings import warn

from pytorch_lightning.utilities import rank_zero_only

try:
    from pytorch_lightning.loggers import Logger as LightningLoggerBase
    from pytorch_lightning.loggers.logger import rank_zero_experiment
except ImportError:
    from pytorch_lightning.loggers import LightningLoggerBase
    from pytorch_lightning.loggers.base import rank_zero_experiment

try:
    import sacred
except ImportError:
    raise ImportError("Missing sacred package. Run `pip install sacred`")

logger = getLogger(__name__)


class SacredLogger(LightningLoggerBase):
    def __init__(self, sacred_experiment):
        """Initialize a sacred logger.
        :param sacred.experiment.Experiment sacred_experiment: Required.
            Experiment object with desired observers already appended.

        source: https://github.com/expectopatronum/pytorch-lightning/blob/9fcb238ec03e3f0b0378fd058119f1563a11650c/pytorch_lightning/logging/sacred.py
        """
        super().__init__()
        self.sacred_experiment = sacred_experiment
        self.experiment_name = sacred_experiment.path
        self._run_id = None
        warn('SacredLogger is deprecated', DeprecationWarning, stacklevel=2)

    @property
    def experiment(self):
        return self.sacred_experiment

    @property
    def run_id(self):
        if self._run_id is not None:
            return self._run_id
        self._run_id = self.sacred_experiment.get_run_identifier()
        return self._run_id

    @rank_zero_only
    def log_hyperparams(self, params):
        # probably not needed because it is dealt with by sacred
        pass

    @rank_zero_only
    def log_metrics(self, metrics, step=None):
        for k, v in metrics.items():
            if isinstance(v, str):
                logger.warning(f"Discarding metric with string value {k}={v}")
                continue
            self.experiment.log_scalar(k, v, step)

    @property
    def name(self):
        return self.experiment_name

    @property
    def version(self):
        return self.run_id

    @rank_zero_only
    def save(self):
        # Optional. Any code necessary to save logger data goes here.
        # If you implement this, remember to call `super().save()`
        # at the start of the method (important for aggregation of metrics).
        super().save()

    @rank_zero_only
    def finalize(self, status):
        # Optional. Any code that needs to run after training finishes goes here.
        pass

================================================
FILE: config_updates.py
================================================
from sacred.config_helpers import DynamicIngredient, CMD


def add_configs(ex):
    '''
    This function adds generic configurations for the experiments,
    such as mix-up, architectures, etc.
    @param ex: Ba3l Experiment
    @return:
    '''

    @ex.named_config
    def nomixup():
        'Don\'t apply mix-up (spectrogram level).'
        use_mixup = False
        mixup_alpha = 0.3

    @ex.named_config
    def mixup():
        'Apply mix-up (spectrogram level).'
        use_mixup = True
        mixup_alpha = 0.3

    @ex.named_config
    def mini_train():
        'limit training/validation to 5 batches for debugging.'
        trainer = dict(limit_train_batches=5, limit_val_batches=5)

    @ex.named_config
    def passt():
        'use PaSST model'
        models = {
            "net": DynamicIngredient("models.passt.model_ing")
        }

    @ex.named_config
    def passt_s_20sec():
        'use PaSST model pretrained on Audioset (with SWA) ap=474; time encodings for up to 20 seconds'
        # test with: python ex_audioset.py evaluate_only with passt_s_20sec
        models = {
            "net": DynamicIngredient("models.passt.model_ing", arch="passt_s_f128_20sec_p16_s10_ap474",
                                     fstride=10, tstride=10, input_tdim=2000)
        }
        basedataset = dict(clip_length=20)

    @ex.named_config
    def passt_s_30sec():
        'use PaSST model pretrained on Audioset (with SWA) ap=473; time encodings for up to 30 seconds'
        # test with: python ex_audioset.py evaluate_only with passt_s_30sec
        models = {
            "net": DynamicIngredient("models.passt.model_ing", arch="passt_s_f128_30sec_p16_s10_ap473",
                                     fstride=10, tstride=10, input_tdim=3000)
        }
        basedataset = dict(clip_length=20)

    @ex.named_config
    def passt_s_ap476():
        'use PaSST model pretrained on Audioset (with SWA) ap=476'
        # test with: python ex_audioset.py evaluate_only with passt_s_ap476
        models = {
            "net": DynamicIngredient("models.passt.model_ing", arch="passt_s_swa_p16_128_ap476",
                                     fstride=10, tstride=10)
        }

    @ex.named_config
    def passt_s_ap4763():
        'use PaSST model pretrained on Audioset (with SWA) ap=4763'
        # test with: python ex_audioset.py evaluate_only with passt_s_ap4763
        models = {
            "net": DynamicIngredient("models.passt.model_ing", arch="passt_s_swa_p16_128_ap4763",
                                     fstride=10, tstride=10)
        }

    @ex.named_config
    def passt_s_ap472():
        'use PaSST model pretrained on Audioset (no SWA) ap=472'
        # test with: python ex_audioset.py evaluate_only with passt_s_ap472
        models = {
            "net": DynamicIngredient("models.passt.model_ing", arch="passt_s_p16_128_ap472",
                                     fstride=10, tstride=10)
        }

    @ex.named_config
    def passt_s_p16_s16_128_ap468():
        'use PaSST model pretrained on Audioset (no SWA) ap=468, NO overlap'
        # test with: python ex_audioset.py evaluate_only with passt_s_p16_s16_128_ap468
        models = {
            "net": DynamicIngredient("models.passt.model_ing", arch="passt_s_p16_s16_128_ap468",
                                     fstride=16, tstride=16)
        }

    @ex.named_config
    def passt_s_swa_p16_s16_128_ap473():
        'use PaSST model pretrained on Audioset (SWA) ap=473, NO overlap'
        # test with: python ex_audioset.py evaluate_only with passt_s_swa_p16_s16_128_ap473
        models = {
            "net": DynamicIngredient("models.passt.model_ing", arch="passt_s_swa_p16_s16_128_ap473",
                                     fstride=16, tstride=16)
        }

    @ex.named_config
    def passt_s_swa_p16_s14_128_ap471():
        'use PaSST model pretrained on Audioset stride=14 (SWA) ap=471'
        # test with: python ex_audioset.py evaluate_only with passt_s_swa_p16_s14_128_ap471
        models = {
            "net": DynamicIngredient("models.passt.model_ing", arch="passt_s_swa_p16_s14_128_ap471",
                                     fstride=14, tstride=14)
        }

    @ex.named_config
    def passt_s_p16_s14_128_ap469():
        'use PaSST model pretrained on Audioset stride=14 (no SWA) ap=469'
        # test with: python ex_audioset.py evaluate_only with passt_s_p16_s14_128_ap469
        models = {
            "net": DynamicIngredient("models.passt.model_ing", arch="passt_s_p16_s14_128_ap469",
                                     fstride=14, tstride=14)
        }

    @ex.named_config
    def passt_s_swa_p16_s12_128_ap473():
        'use PaSST model pretrained on Audioset stride=12 (SWA) ap=473'
        # test with: python ex_audioset.py evaluate_only with passt_s_swa_p16_s12_128_ap473
        models = {
            "net": DynamicIngredient("models.passt.model_ing", arch="passt_s_swa_p16_s12_128_ap473",
                                     fstride=12, tstride=12)
        }

    @ex.named_config
    def passt_s_p16_s12_128_ap470():
        'use PaSST model pretrained on Audioset stride=12 (no SWA) ap=470'
        # test with: python ex_audioset.py evaluate_only with passt_s_p16_s12_128_ap470
        models = {
            "net": DynamicIngredient("models.passt.model_ing", arch="passt_s_p16_s12_128_ap470",
                                     fstride=12, tstride=12)
        }

    @ex.named_config
    def ensemble_s10():
        'use ensemble of PaSST models pretrained on Audioset with S10 mAP=.4864'
        # test with: python ex_audioset.py evaluate_only with trainer.precision=16 ensemble_s10
        models = {
            "net": DynamicIngredient("models.passt.model_ing", arch="ensemble_s10",
                                     fstride=None, tstride=None,
                                     instance_cmd="get_ensemble_model",  # call get_ensemble_model instead of get_model
                                     arch_list=[
                                         ("passt_s_swa_p16_128_ap476", 10, 10),
                                         ("passt_s_swa_p16_128_ap4761", 10, 10),
                                         ("passt_s_p16_128_ap472", 10, 10),
                                     ])
        }

    @ex.named_config
    def ensemble_many():
        'use ensemble of PaSST models pretrained on Audioset with different strides mAP=.4956'
        # test with: python ex_audioset.py evaluate_only with trainer.precision=16 ensemble_many
        models = {
            "net": DynamicIngredient("models.passt.model_ing", arch="ensemble_many",
                                     fstride=None, tstride=None,
                                     instance_cmd="get_ensemble_model",  # call get_ensemble_model instead of get_model
                                     arch_list=[
                                         ("passt_s_swa_p16_128_ap476", 10, 10),
                                         ("passt_s_swa_p16_128_ap4761", 10, 10),
                                         ("passt_s_p16_128_ap472", 10, 10),
                                         ("passt_s_p16_s12_128_ap470", 12, 12),
                                         ("passt_s_swa_p16_s12_128_ap473", 12, 12),
                                         ("passt_s_p16_s14_128_ap469", 14, 14),
                                         ("passt_s_swa_p16_s14_128_ap471", 14, 14),
                                         ("passt_s_swa_p16_s16_128_ap473", 16, 16),
                                         ("passt_s_p16_s16_128_ap468", 16, 16),
                                     ])
        }

    @ex.named_config
    def ensemble_4():
        'use ensemble of 4 PaSST models pretrained on Audioset with different strides mAP=.4926'
        # test with: python ex_audioset.py evaluate_only with trainer.precision=16 ensemble_4
        models = {
            "net": DynamicIngredient("models.passt.model_ing", arch="ensemble_many",
                                     fstride=None, tstride=None,
                                     instance_cmd="get_ensemble_model",  # call get_ensemble_model instead of get_model
                                     arch_list=[
                                         ("passt_s_swa_p16_128_ap476", 10, 10),
                                         ("passt_s_swa_p16_s12_128_ap473", 12, 12),
                                         ("passt_s_swa_p16_s14_128_ap471", 14, 14),
                                         ("passt_s_swa_p16_s16_128_ap473", 16, 16),
                                     ])
        }

    @ex.named_config
    def ensemble_5():
        'use ensemble of 5 PaSST models pretrained on Audioset with different strides mAP=.49459'
        # test with: python ex_audioset.py evaluate_only with trainer.precision=16 ensemble_5
        models = {
            "net": DynamicIngredient("models.passt.model_ing", arch="ensemble_many",
                                     fstride=None, tstride=None,
                                     instance_cmd="get_ensemble_model",  # call get_ensemble_model instead of get_model
                                     arch_list=[
                                         ("passt_s_swa_p16_128_ap476", 10, 10),
                                         ("passt_s_swa_p16_128_ap4761", 10, 10),
                                         ("passt_s_swa_p16_s12_128_ap473", 12, 12),
                                         ("passt_s_swa_p16_s14_128_ap471", 14, 14),
                                         ("passt_s_swa_p16_s16_128_ap473", 16, 16),
                                     ])
        }

    @ex.named_config
    def ensemble_s16_14():
        'use ensemble of two PaSST models pretrained on Audioset with strides 16 and 14 mAP=.48579'
        # test with: python ex_audioset.py evaluate_only with trainer.precision=16 ensemble_s16_14
        models = {
            "net": DynamicIngredient("models.passt.model_ing", arch="ensemble_s16",
                                     fstride=None, tstride=None,
                                     instance_cmd="get_ensemble_model",  # call get_ensemble_model instead of get_model
                                     arch_list=[
                                         ("passt_s_swa_p16_s14_128_ap471", 14, 14),
                                         ("passt_s_swa_p16_s16_128_ap473", 16, 16),
                                     ])
        }

    @ex.named_config
    def dynamic_roll():
        # dynamically roll the spectrograms/waveforms
        # updates the dataset config
        basedataset = dict(roll=True,
                           roll_conf=dict(axis=1, shift_range=10000))

    # extra commands

    @ex.command
    def test_loaders_train_speed():
        # test how fast data is being loaded from the data loaders.
itr = ex.datasets.training.get_iter() import time start = time.time() print("hello") for i, b in enumerate(itr): if i % 20 == 0: print(f"{i}/{len(itr)}", end="\r") end = time.time() print("total time:", end - start) start = time.time() print("retry:") for i, b in enumerate(itr): if i % 20 == 0: print(f"{i}/{len(itr)}", end="\r") end = time.time() print("total time:", end - start) ================================================ FILE: environment.yml ================================================ name: ba3l channels: - defaults dependencies: - _libgcc_mutex=0.1 - _openmp_mutex=5.1 - ca-certificates=2022.10.11 - certifi=2022.12.7 - ld_impl_linux-64=2.38 - libffi=3.4.2 - libgcc-ng=11.2.0 - libgomp=11.2.0 - libstdcxx-ng=11.2.0 - ncurses=6.3 - openssl=1.1.1s - pip=22.3.1 - python=3.8.16 - readline=8.2 - setuptools=65.6.3 - sqlite=3.40.1 - tk=8.6.12 - wheel=0.37.1 - xz=5.2.10 - zlib=1.2.13 - pip: - absl-py==1.4.0 - aiohttp==3.8.4 - aiosignal==1.3.1 - appdirs==1.4.4 - async-timeout==4.0.2 - attrs==22.2.0 - audioread==3.0.0 - av==10.0.0 - cachetools==5.3.0 - cffi==1.15.1 - charset-normalizer==3.0.1 - colorama==0.4.6 - decorator==5.1.1 - docopt==0.6.2 - frozenlist==1.3.3 - fsspec==2023.3.0 - future==0.18.3 - gitdb==4.0.10 - gitpython==3.1.31 - google-auth==2.17.0 - google-auth-oauthlib==0.4.6 - grpcio==1.53.0 - h5py==3.8.0 - idna==3.4 - imageio==2.27.0 - importlib-metadata==6.1.0 - joblib==1.2.0 - jsonpickle==3.0.1 - kk-sacred==0.8.4 - lazy-loader==0.2 - librosa==0.10.0.post2 - llvmlite==0.39.1 - markdown==3.4.3 - markupsafe==2.1.2 - msgpack==1.0.5 - multidict==6.0.4 - munch==2.5.0 - numba==0.56.4 - numpy==1.23.5 - nvidia-cublas-cu11==11.10.3.66 - nvidia-cuda-nvrtc-cu11==11.7.99 - nvidia-cuda-runtime-cu11==11.7.99 - nvidia-cudnn-cu11==8.5.0.96 - oauthlib==3.2.2 - packaging==23.0 - pandas==1.5.3 - pillow==9.4.0 - pooch==1.6.0 - protobuf==4.22.1 - py-cpuinfo==9.0.0 - pyasn1==0.4.8 - pyasn1-modules==0.2.8 - pycparser==2.21 - python-dateutil==2.8.2 -
pytorch-lightning==1.3.0.dev0 - pytz==2023.3 - pyyaml==6.0 - requests==2.28.2 - requests-oauthlib==1.3.1 - rsa==4.9 - scikit-learn==1.2.2 - scipy==1.10.1 - six==1.16.0 - smmap==5.0.0 - soundfile==0.12.1 - soxr==0.3.4 - tensorboard==2.12.0 - tensorboard-data-server==0.7.0 - tensorboard-plugin-wit==1.8.1 - test-tube==0.7.5 - threadpoolctl==3.1.0 - timm==0.4.12 - torch==1.13.1 - torchaudio==0.13.1 - torchmetrics==0.2.0 - torchvision==0.14.1 - tqdm==4.65.0 - typing-extensions==4.4.0 - urllib3==1.26.14 - werkzeug==2.2.3 - wrapt==1.15.0 - yarl==1.8.2 - zipp==3.15.0 ================================================ FILE: environment_old_2021.yml ================================================ name: ba3l channels: - conda-forge - defaults dependencies: - _libgcc_mutex=0.1 - _openmp_mutex=4.5 - _pytorch_select=0.1 - appdirs=1.4.4 - audioread=2.1.9 - blas=1.0 - brotlipy=0.7.0 - bzip2=1.0.8 - c-ares=1.17.1 - ca-certificates=2020.12.5 - cached-property=1.5.2 - cached_property=1.5.2 - certifi=2020.12.5 - cffi=1.14.5 - chardet=4.0.0 - colorama=0.4.4 - cryptography=3.4.6 - cycler=0.10.0 - decorator=4.4.2 - docopt=0.6.2 - ffmpeg=4.3.1 - freetype=2.10.4 - gettext=0.19.8.1 - gitdb=4.0.5 - gitpython=3.1.14 - gmp=6.2.1 - gnutls=3.6.13 - h5py=3.1.0 - hdf5=1.10.6 - idna=2.10 - importlib-metadata=3.7.3 - importlib_metadata=3.7.3 - intel-openmp=2020.2 - joblib=1.0.1 - jpeg=9d - jsonpickle=1.4.1 - kiwisolver=1.3.1 - krb5=1.17.2 - lame=3.100 - lcms2=2.12 - ld_impl_linux-64=2.35.1 - libblas=3.9.0 - libcblas=3.9.0 - libcurl=7.75.0 - libedit=3.1.20191231 - libev=4.33 - libffi=3.3 - libflac=1.3.3 - libgcc-ng=9.3.0 - libgfortran-ng=9.3.0 - libgfortran5=9.3.0 - libgomp=9.3.0 - liblapack=3.9.0 - libllvm10=10.0.1 - libnghttp2=1.43.0 - libogg=1.3.4 - libopenblas=0.3.12 - libopus=1.3.1 - libpng=1.6.37 - librosa=0.8.0 - libsndfile=1.0.31 - libssh2=1.9.0 - libstdcxx-ng=9.3.0 - libtiff=4.2.0 - libvorbis=1.3.7 - libwebp-base=1.2.0 - llvm-openmp=11.1.0 - llvmlite=0.36.0 - lz4-c=1.9.3 - 
matplotlib-base=3.3.4 - mkl=2020.2 - mkl-service=2.3.0 - munch=2.5.0 - ncurses=6.2 - nettle=3.6 - ninja=1.10.2 - numba=0.53.0 - numpy=1.20.1 - olefile=0.46 - openblas=0.3.12 - openh264=2.1.1 - openssl=1.1.1k - packaging=20.9 - pandas=1.2.3 - pillow=8.1.2 - pip=21.0.1 - pooch=1.3.0 - py-cpuinfo=7.0.0 - pycparser=2.20 - pyopenssl=20.0.1 - pyparsing=2.4.7 - pysocks=1.7.1 - pysoundfile=0.10.3.post1 - python=3.7.10 - python-dateutil=2.8.1 - python_abi=3.7 - pytz=2021.1 - readline=8.0 - requests=2.25.1 - resampy=0.2.2 - scikit-learn=0.24.1 - scipy=1.6.1 - setuptools=49.6.0 - six=1.15.0 - smmap=3.0.5 - sqlite=3.34.0 - threadpoolctl=2.1.0 - tk=8.6.10 - tornado=6.1 - typing_extensions=3.7.4.3 - urllib3=1.26.4 - wrapt=1.12.1 - x264=1!161.3030 - xz=5.2.5 - zipp=3.4.1 - zlib=1.2.11 - zstd=1.4.9 - pip: - absl-py==0.12.0 - aiohttp==3.7.4.post0 - async-timeout==3.0.1 - attrs==20.3.0 - av==8.0.3 - black==20.8b1 - cachetools==4.2.1 - click==7.1.2 - einops==0.3.0 - fsspec==0.8.7 - future==0.18.2 - google-auth==1.28.0 - google-auth-oauthlib==0.4.3 - gpuinfo==1.0.0a7 - grpcio==1.36.1 - imageio==2.9.0 - jedi==0.18.0 - libtmux==0.8.5 - markdown==3.3.4 - multidict==5.1.0 - mypy-extensions==0.4.3 - oauthlib==3.1.0 - parso==0.8.2 - pathspec==0.8.1 - prompt-toolkit==3.0.18 - protobuf==3.15.6 - ptpython==3.0.17 - pyasn1==0.4.8 - pyasn1-modules==0.2.8 - pydub==0.25.1 - pygments==2.8.1 - pymongo==3.11.3 - pyyaml==5.3.1 - regex==2021.4.4 - requests-oauthlib==1.3.0 - rsa==4.7.2 - tensorboard==2.4.1 - tensorboard-plugin-wit==1.8.0 - test-tube==0.7.5 - timm==0.4.12 - toml==0.10.2 - torch==1.8.1+cu111 - torchaudio==0.8.1 - torchmetrics==0.2.0 - torchvision==0.6.0 - tqdm==4.59.0 - typed-ast==1.4.3 - wcwidth==0.2.5 - werkzeug==1.0.1 - wheel==0.36.2 - yarl==1.6.3 ================================================ FILE: esc50/README.md ================================================ # Experiments on ESC-50 Environmental Sound Classification [ESC-50](https://github.com/karolpiczak/ESC-50) consist of 2000 
5-second recordings to be classified into 50 semantic classes. ## Setting up the fine-tuning experiments - Download the preprocessed (resampled) dataset [esc50.zip](https://github.com/kkoutini/PaSST/releases/download/v.0.0.6/esc50.zip) from the [releases page](https://github.com/kkoutini/PaSST/releases/tag/v.0.0.6) and unpack the zip file into a directory (the default path is `./audioset_hdf5s/`). The `base_dir` config in the dataset file ([here](https://github.com/kkoutini/PaSST/blob/main/esc50/dataset.py#L35)) should point to the extracted contents of the dataset zip file. - Run the experiments using the common configurations (similar to Audioset): ```shell python3 ex_esc50.py with models.net.s_patchout_t=10 models.net.s_patchout_f=5 basedataset.fold=1 -p ``` ## Pre-trained models Pre-trained models on ESC-50 can be found [here](https://github.com/kkoutini/PaSST/releases/tag/v.0.0.6). To use the pre-trained models for fine-tuning or inference with minimal dependencies, refer to [PaSST-HEAR](https://github.com/kkoutini/passt_hear21). For example, after installing passt_hear21: ```python from hear21passt.base import get_basic_model,get_model_passt import torch # model wrapper, includes Melspectrogram and the default transformer model = get_basic_model(mode="logits") # replace the transformer with one that outputs 50 classes model.net = get_model_passt(arch="passt_s_swa_p16_128_ap476", n_classes=50) # load the pre-trained model state dict state_dict = torch.load('/home/khaled/esc50-passt-s-n-f128-p16-s10-fold1-acc.967.pt') # load the weights into the transformer model.net.load_state_dict(state_dict) # example inference model.eval() model = model.cuda() with torch.no_grad(): # audio_wave has the shape of [batch, seconds*32000]; sampling rate is 32 kHz logits = model(audio_wave) ``` ================================================ FILE: esc50/dataset.py ================================================ import io import os import pathlib import
random import av import librosa import torchaudio from torch.utils.data import Dataset as TorchDataset, ConcatDataset, DistributedSampler, WeightedRandomSampler import torch from ba3l.ingredients.datasets import Dataset import pandas as pd from sacred.config import DynamicIngredient, CMD from scipy.signal import convolve from sklearn import preprocessing import numpy as np import h5py from helpers.audiodatasets import PreprocessDataset LMODE = os.environ.get("LMODE", False) dataset = Dataset('Esc50') @dataset.config def default_config(): name = 'esc50' # dataset name normalize = False # normalize dataset subsample = False # subsample squares from the dataset roll = True # apply roll augmentation fold = 1 base_dir = "audioset_hdf5s/esc50/" # base directory of the dataset as downloaded if LMODE: base_dir = "/system/user/publicdata/CP/audioset/audioset_hdf5s/esc50/" meta_csv = base_dir + "meta/esc50.csv" audio_path = base_dir + "audio_32k/" ir_path = base_dir + "irs/" num_of_classes = 50 def decode_mp3(mp3_arr): """ decodes an array of uint8 representing an mp3 file :rtype: np.array """ container = av.open(io.BytesIO(mp3_arr.tobytes())) stream = next(s for s in container.streams if s.type == 'audio') # print(stream) a = [] for i, packet in enumerate(container.demux(stream)): for frame in packet.decode(): a.append(frame.to_ndarray().reshape(-1)) waveform = np.concatenate(a) if waveform.dtype != 'float32': raise RuntimeError("Unexpected wave type") return waveform def pad_or_truncate(x, audio_length): """Pad all audio to specific length.""" if len(x) <= audio_length: return np.concatenate((x, np.zeros(audio_length - len(x), dtype=np.float32)), axis=0) else: return x[0: audio_length] irs_arr = None @dataset.command def get_ir_sample(ir_path, _run, ir_augment, cut_irs_offset=None): if not ir_augment: return global irs_arr if irs_arr is None: all_paths = [path for path in
pathlib.Path(os.path.expanduser(ir_path)).rglob('*.wav')] all_paths = sorted(all_paths) if cut_irs_offset is not None: all_paths = all_paths[cut_irs_offset:cut_irs_offset + 10] all_paths_name = [str(p).rsplit("/", 1)[-1] for p in all_paths] print("will use these IRs:") for i in range(len(all_paths_name)): print(i, ": ", all_paths_name[i]) _run.info["ir_devices"] = all_paths_name irs_arr = [librosa.load(p, sr=32000)[0] for p in all_paths] return irs_arr[int(np.random.randint(0, len(irs_arr)))] @dataset.command def pydub_augment(waveform, gain_augment=7, ir_augment=0): if ir_augment and torch.rand(1) < ir_augment: ir = get_ir_sample() waveform = convolve(waveform, ir, 'full') if gain_augment: gain = torch.randint(gain_augment * 2, (1,)).item() - gain_augment amp = 10 ** (gain / 20) waveform = waveform * amp return waveform class MixupDataset(TorchDataset): """ Mixing Up wave forms """ def __init__(self, dataset, beta=2, rate=0.5): self.beta = beta self.rate = rate self.dataset = dataset print(f"Mixing up waveforms from dataset of len {len(dataset)}") def __getitem__(self, index): if torch.rand(1) < self.rate: x1, f1, y1 = self.dataset[index] idx2 = torch.randint(len(self.dataset), (1,)).item() x2, f2, y2 = self.dataset[idx2] l = np.random.beta(self.beta, self.beta) l = max(l, 1. - l) x1 = x1 - x1.mean() x2 = x2 - x2.mean() x = (x1 * l + x2 * (1. - l)) x = x - x.mean() return x, f1, (y1 * l + y2 * (1. 
- l)) return self.dataset[index] def __len__(self): return len(self.dataset) class AudioSetDataset(TorchDataset): def __init__(self, meta_csv, audiopath, fold, train=False, sample_rate=32000, classes_num=527, clip_length=5, augment=False): """ Reads audio files with librosa (the HDF/mp3 decoding path is commented out in __getitem__) and returns a fixed-length waveform """ self.sample_rate = sample_rate self.meta_csv = meta_csv self.df = pd.read_csv(meta_csv) if train: # training all except this fold print(f"Dataset training fold {fold} selection out of {len(self.df)}") self.df = self.df[self.df.fold != fold] print(f" for training remains {len(self.df)}") else: print(f"Dataset testing fold {fold} selection out of {len(self.df)}") self.df = self.df[self.df.fold == fold] print(f" for testing remains {len(self.df)}") self.clip_length = clip_length * sample_rate self.sr = sample_rate self.classes_num = classes_num self.augment = augment self.audiopath = audiopath if augment: print(f"Will augment data from {meta_csv}") def __len__(self): return len(self.df) def __getitem__(self, index): """Load waveform and target of an audio clip. Args: index: index of the clip in the meta CSV Returns: (waveform (1, clip_samples), filename, target) """ row = self.df.iloc[index] #waveform = decode_mp3(np.fromfile(self.audiopath + row.filename, dtype='uint8')) waveform, _ = librosa.load(self.audiopath + row.filename, sr=self.sr, mono=True) if self.augment: waveform = pydub_augment(waveform) waveform = pad_or_truncate(waveform, self.clip_length) waveform = self.resample(waveform) target = row.target return waveform.reshape(1, -1), row.filename, target def resample(self, waveform): """Resample.
Args: waveform: (clip_samples,) Returns: (resampled_clip_samples,) """ if self.sample_rate == 32000: return waveform elif self.sample_rate == 16000: return waveform[0:: 2] elif self.sample_rate == 8000: return waveform[0:: 4] else: raise Exception('Incorrect sample rate!') @dataset.command def get_base_training_set(meta_csv, audio_path, fold=1): ds = AudioSetDataset(meta_csv, audio_path, fold, train=True, augment=True) return ds @dataset.command def get_ft_weighted_sampler(samples_weights=CMD(".get_ft_cls_balanced_sample_weights"), epoch_len=100000, sampler_replace=False): num_nodes = int(os.environ.get('num_nodes', 1)) ddp = int(os.environ.get('DDP', 1)) num_nodes = max(ddp, num_nodes) print("num_nodes= ", num_nodes) rank = int(os.environ.get('NODE_RANK', 0)) return DistributedSamplerWrapper(sampler=WeightedRandomSampler(samples_weights, num_samples=epoch_len, replacement=sampler_replace), dataset=range(epoch_len), num_replicas=num_nodes, rank=rank, ) @dataset.command def get_base_test_set(meta_csv, audio_path, fold=1): ds = AudioSetDataset(meta_csv, audio_path, fold, train=False) return ds @dataset.command(prefix='roll_conf') def get_roll_func(axis=1, shift=None, shift_range=50): print("rolling...") def roll_func(b): x, i, y = b x = torch.as_tensor(x) sf = shift if shift is None: sf = int(np.random.randint(-shift_range, shift_range + 1)) return x.roll(sf, axis), i, y return roll_func @dataset.command def get_training_set(normalize, roll, wavmix=False): ds = get_base_training_set() get_ir_sample() if normalize: print("normalized train!") fill_norms() ds = PreprocessDataset(ds, norm_func) if roll: ds = PreprocessDataset(ds, get_roll_func()) if wavmix: ds = MixupDataset(ds) return ds @dataset.command def get_test_set(normalize): ds = get_base_test_set() if normalize: print("normalized test!") fill_norms() ds = PreprocessDataset(ds, norm_func) return ds @dataset.command def print_conf(_config): print("Config of ", dataset.path, id(dataset))
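A note on the `resample` method above: it downsamples by plain stride slicing (no anti-aliasing filter), which only works for integer divisors of the 32 kHz source rate. A minimal standalone sketch of the same idea (`naive_resample` is an illustrative name, not from the repo):

```python
import numpy as np

def naive_resample(waveform: np.ndarray, orig_sr: int, target_sr: int) -> np.ndarray:
    """Decimate by keeping every n-th sample, mirroring the dataset's resample():
    32 kHz -> 16 kHz keeps every 2nd sample, 32 kHz -> 8 kHz every 4th.
    Like the original, no low-pass filter is applied before decimation."""
    if target_sr == orig_sr:
        return waveform
    if orig_sr % target_sr != 0:
        raise ValueError("only integer decimation factors are supported")
    return waveform[:: orig_sr // target_sr]

one_second = np.zeros(32000, dtype=np.float32)
print(len(naive_resample(one_second, 32000, 16000)))  # 16000
```

For real resampling with anti-aliasing, `librosa.load(path, sr=...)` (already used in `__getitem__`) or `torchaudio.transforms.Resample` are the usual choices.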
print(_config) print() class DistributedSamplerWrapper(DistributedSampler): def __init__( self, sampler, dataset, num_replicas=None, rank=None, shuffle: bool = True): super(DistributedSamplerWrapper, self).__init__( dataset, num_replicas, rank, shuffle) # source: @awaelchli https://github.com/PyTorchLightning/pytorch-lightning/issues/3238 self.sampler = sampler def __iter__(self): if self.sampler.generator is None: self.sampler.generator = torch.Generator() self.sampler.generator.manual_seed(self.seed + self.epoch) indices = list(self.sampler) if self.epoch == 0: print(f"\n DistributedSamplerWrapper : {indices[:10]} \n\n") indices = indices[self.rank:self.total_size:self.num_replicas] return iter(indices) if __name__ == "__main__": from sacred import Experiment ex = Experiment("test_dataset", ingredients=[dataset]) @ex.automain def default_command(): ex.current_run.get_command_function("print_config")() get_base_training_set() ds = get_test_set() print(ds[0]) ds = get_training_set() print(ds[0]) print("get_base_training_set", len(get_base_training_set())) print("get_base_test_set", len(get_base_test_set())) print("get_training_set", len(get_training_set())) print("get_test_set", len(get_test_set())) ================================================ FILE: ex_audioset.py ================================================ import os import sys import PIL import pytorch_lightning import torch from pytorch_lightning.callbacks import ModelCheckpoint from sacred.config_helpers import DynamicIngredient, CMD from torch.nn import functional as F import numpy as np import wandb from ba3l.experiment import Experiment from ba3l.module import Ba3lModule from torch.utils.data import DataLoader from config_updates import add_configs from helpers.mixup import my_mixup from helpers.models_size import count_non_zero_params from helpers.ramp import exp_warmup_linear_down, cosine_cycle from helpers.workersinit import worker_init_fn from sklearn import metrics from pytorch_lightning import 
Trainer as plTrainer from pytorch_lightning.loggers import WandbLogger from pytorch_lightning.callbacks import LearningRateMonitor ex = Experiment("audioset") # Example call with all the default config: # python ex_audioset.py with trainer.precision=16 models.net.arch=passt_deit_bd_p16_384 -p -m mongodb_server:27000:audioset21_balanced -c "PaSST base" # with 2 gpus: # DDP=2 python ex_audioset.py with trainer.precision=16 models.net.arch=passt_deit_bd_p16_384 -p -m mongodb_server:27000:audioset21_balanced -c "PaSST base 2 GPU" # capture the config of the trainer with the prefix "trainer", this allows to use sacred to update PL trainer config # now you can use in the cmd trainer.precision=16 for example get_trainer = ex.command(plTrainer, prefix="trainer") # capture the WandbLogger and prefix it with "wandb", this allows to use sacred to update WandbLogger config from the command line get_logger = ex.command(WandbLogger, prefix="wandb") wandb_logger = None # define datasets and loaders get_train_loader = ex.datasets.training.iter(DataLoader, static_args=dict(worker_init_fn=worker_init_fn), train=True, batch_size=12, num_workers=16, shuffle=None, dataset=CMD("/basedataset.get_full_training_set"), sampler=CMD("/basedataset.get_ft_weighted_sampler")) get_validate_loader = ex.datasets.test.iter(DataLoader, static_args=dict(worker_init_fn=worker_init_fn), validate=True, batch_size=20, num_workers=16, dataset=CMD("/basedataset.get_test_set")) @ex.config def default_conf(): cmd = " ".join(sys.argv) # command line arguments saque_cmd = os.environ.get("SAQUE_CMD", "").strip() saque_id = os.environ.get("SAQUE_ID", "").strip() slurm_job_id = os.environ.get("SLURM_JOB_ID", "").strip() if os.environ.get("SLURM_ARRAY_JOB_ID", False): slurm_job_id = os.environ.get("SLURM_ARRAY_JOB_ID", "").strip() + "_" + os.environ.get("SLURM_ARRAY_TASK_ID", "").strip() process_id = os.getpid() models = { "net": DynamicIngredient("models.passt.model_ing", arch="passt_deit_bd_p16_384", 
n_classes=527, s_patchout_t=40, s_patchout_f=4), # network config "mel": DynamicIngredient("models.preprocess.model_ing", instance_cmd="AugmentMelSTFT", n_mels=128, sr=32000, win_length=800, hopsize=320, n_fft=1024, freqm=48, timem=192, htk=False, fmin=0.0, fmax=None, norm=1, fmin_aug_range=10, fmax_aug_range=2000) } basedataset = DynamicIngredient("audioset.dataset.dataset", wavmix=1) wandb = dict(project="passt_audioset2", log_model=True) watch_model = True trainer = dict(max_epochs=130, gpus=1, weights_summary='full', benchmark=True, num_sanity_val_steps=0, precision=16, reload_dataloaders_every_epoch=True) lr = 0.00002 # learning rate use_mixup = True mixup_alpha = 0.3 compile = True # compile the model, requires pytorch >= 2.0 # register extra possible configs add_configs(ex) @ex.command def get_scheduler_lambda(warm_up_len=5, ramp_down_start=50, ramp_down_len=50, last_lr_value=0.01, schedule_mode="exp_lin"): if schedule_mode == "exp_lin": return exp_warmup_linear_down(warm_up_len, ramp_down_len, ramp_down_start, last_lr_value) if schedule_mode == "cos_cyc": return cosine_cycle(warm_up_len, ramp_down_start, last_lr_value) raise RuntimeError( f"schedule_mode={schedule_mode} Unknown for a lambda function.") @ex.command def get_lr_scheduler(optimizer, schedule_mode): if schedule_mode in {"exp_lin", "cos_cyc"}: return torch.optim.lr_scheduler.LambdaLR(optimizer, get_scheduler_lambda()) raise RuntimeError(f"schedule_mode={schedule_mode} Unknown.") @ex.command def get_optimizer(params, lr, adamw=True, weight_decay=0.0001): if adamw: print(f"\nUsing adamw weight_decay={weight_decay}!\n") return torch.optim.AdamW(params, lr=lr, weight_decay=weight_decay) return torch.optim.Adam(params, lr=lr) class M(Ba3lModule): def __init__(self, experiment): self.mel = None self.da_net = None super(M, self).__init__(experiment) self.use_mixup = self.config.use_mixup or False self.mixup_alpha = self.config.mixup_alpha desc, sum_params, sum_non_zero = count_non_zero_params(self.net)
self.experiment.info["start_sum_params"] = sum_params self.experiment.info["start_sum_params_non_zero"] = sum_non_zero # in case we need embeddings for the DA self.net.return_embed = True self.dyn_norm = self.config.dyn_norm self.do_swa = False self.distributed_mode = self.config.trainer.num_nodes > 1 if self.config.compile: # pt 2 magic print("\n\nCompiling the model pytorch 2... \n\n") self.net = torch.compile(self.net) # compile only the net, not the mel #self.mel = torch.compile(self.mel) def forward(self, x): return self.net(x) def mel_forward(self, x): old_shape = x.size() x = x.reshape(-1, old_shape[2]) x = self.mel(x) x = x.reshape(old_shape[0], old_shape[1], x.shape[1], x.shape[2]) if self.dyn_norm: if not hasattr(self, "tr_m") or not hasattr(self, "tr_std"): tr_m, tr_std = get_dynamic_norm(self) self.register_buffer('tr_m', tr_m) self.register_buffer('tr_std', tr_std) x = (x - self.tr_m) / self.tr_std return x def training_step(self, batch, batch_idx): # REQUIRED x, f, y = batch if self.mel: x = self.mel_forward(x) if self.global_step < 5: images = [ wandb.Image( PIL.Image.fromarray((i * 255).astype(np.uint8)).convert("L"), caption="spectrograms", ) for i in x[:, 0, :, :].cpu().numpy()] wandb.log({"spectrograms": images}) # wandb_logger.log_image(key="spectrograms", images=[i for i in x[:,0,:,:].cpu().numpy()]) orig_x = x batch_size = len(y) rn_indices, lam = None, None if self.use_mixup: rn_indices, lam = my_mixup(batch_size, self.mixup_alpha) lam = lam.to(x.device) x = x * lam.reshape(batch_size, 1, 1, 1) + \ x[rn_indices] * (1. - lam.reshape(batch_size, 1, 1, 1)) y_hat, embed = self.forward(x) if self.use_mixup: y_mix = y * lam.reshape(batch_size, 1) + \ y[rn_indices] * (1.
- lam.reshape(batch_size, 1)) samples_loss = F.binary_cross_entropy_with_logits( y_hat, y_mix, reduction="none") loss = samples_loss.mean() samples_loss = samples_loss.detach() else: samples_loss = F.binary_cross_entropy_with_logits( y_hat, y, reduction="none") loss = samples_loss.mean() samples_loss = samples_loss.detach() results = {"loss": loss, } self.log('trainer/lr', self.trainer.optimizers[0].param_groups[0]['lr']) self.log('epoch', self.current_epoch) self.log("training.loss", loss.detach()) return results def training_epoch_end(self, outputs): avg_loss = torch.stack([x['loss'] for x in outputs]).mean() logs = {'train.loss': avg_loss, # 'step': self.current_epoch } self.log_dict(logs, sync_dist=True) def predict(self, batch, batch_idx: int, dataloader_idx: int = None): x, f, y = batch if self.mel: x = self.mel_forward(x) y_hat, _ = self.forward(x) return f, y_hat def validation_step(self, batch, batch_idx): x, f, y = batch if self.mel: x = self.mel_forward(x) if self.global_step < 5: images = [ wandb.Image( PIL.Image.fromarray((i * 255).astype(np.uint8)).convert("L"), caption="validation_spectrograms", ) for i in x[:, 0, :, :].cpu().numpy()] wandb.log({"validation_spectrograms": images}) # wandb_logger.log_image(key="validation_spectrograms", images=[ # i for i in x[:, 0, :, :].cpu().numpy()]) results = {} model_name = [("", self.net)] if self.do_swa: model_name = model_name + [("swa_", self.net_swa)] for net_name, net in model_name: y_hat, _ = net(x) samples_loss = F.binary_cross_entropy_with_logits(y_hat, y) loss = samples_loss.mean() out = torch.sigmoid(y_hat.detach()) # self.log("validation.loss", loss, prog_bar=True, on_epoch=True, on_step=False) results = {**results, net_name + "val_loss": loss, net_name + "out": out, net_name + "target": y.detach()} results = {k: v.cpu() for k, v in results.items()} return results def validation_epoch_end(self, outputs): model_name = [("", self.net)] if self.do_swa: model_name = model_name + [("swa_", self.net_swa)] 
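`my_mixup` (from `helpers/mixup.py`) is not shown in this excerpt; the mixing applied in `training_step` above combines inputs and targets with the same `lam`. A rough numpy sketch, assuming `lam` is drawn from Beta(alpha, alpha) and folded to [0.5, 1] as in `MixupDataset` (the `mixup_batch` name and the folding detail are assumptions, not the repo's exact implementation):

```python
import numpy as np

rng = np.random.default_rng(0)

def mixup_batch(x: np.ndarray, y: np.ndarray, alpha: float = 0.3):
    """Mix every sample with a random partner from the same batch."""
    bs = x.shape[0]
    perm = rng.permutation(bs)                 # random partner indices
    lam = rng.beta(alpha, alpha, size=bs)
    lam = np.maximum(lam, 1.0 - lam)           # fold to [0.5, 1]: original sample stays dominant
    lx = lam.reshape(bs, 1)                    # broadcast over feature dims
    x_mix = x * lx + x[perm] * (1.0 - lx)
    y_mix = y * lx + y[perm] * (1.0 - lx)      # labels mixed with the same weights
    return x_mix, y_mix

x_mix, y_mix = mixup_batch(np.ones((4, 8)), np.eye(4))
```

Because the weights sum to one, mixing one-hot targets yields soft labels whose rows still sum to one.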
for net_name, net in model_name: avg_loss = torch.stack([x[net_name + 'val_loss'] for x in outputs]).mean() out = torch.cat([x[net_name + 'out'] for x in outputs], dim=0) target = torch.cat([x[net_name + 'target'] for x in outputs], dim=0) try: average_precision = metrics.average_precision_score( target.float().numpy(), out.float().numpy(), average=None) except ValueError: average_precision = np.array([np.nan] * 527) try: roc = metrics.roc_auc_score( target.numpy(), out.numpy(), average=None) except ValueError: roc = np.array([np.nan] * 527) logs = {net_name + 'val.loss': torch.as_tensor(avg_loss).cuda(), net_name + 'ap': torch.as_tensor(average_precision.mean()).cuda(), net_name + 'roc': torch.as_tensor(roc.mean()).cuda(),} #'step': torch.as_tensor(self.current_epoch).cuda()} # torch.save(average_precision, # f"ap_perclass_{average_precision.mean()}.pt") # print(average_precision) self.log_dict(logs, sync_dist=True) if self.distributed_mode: allout = self.all_gather(out) alltarget = self.all_gather(target) average_precision = metrics.average_precision_score( alltarget.reshape(-1, alltarget.shape[-1]).cpu().numpy(), allout.reshape(-1, allout.shape[-1]).cpu().numpy(), average=None) if self.trainer.is_global_zero: logs = {net_name + "allap": torch.as_tensor(average_precision.mean()).cuda(), # 'step': torch.as_tensor(self.current_epoch).cuda() } self.log_dict(logs, sync_dist=False) else: self.log_dict( {net_name + "allap": logs[net_name + 'ap'], #'step': logs['step'] } , sync_dist=True) def configure_optimizers(self): # REQUIRED # can return multiple optimizers and learning_rate schedulers # (LBFGS it is automatically supported, no need for closure function) optimizer = get_optimizer(self.parameters()) # torch.optim.Adam(self.parameters(), lr=self.config.lr) return { 'optimizer': optimizer, 'lr_scheduler': get_lr_scheduler(optimizer) } def configure_callbacks(self): return get_extra_checkpoint_callback() + get_extra_swa_callback() + 
[LearningRateMonitor(logging_interval='epoch')] @ex.command def get_dynamic_norm(model, dyn_norm=False): if not dyn_norm: return None, None raise RuntimeError('no dynamic norm supported yet.') @ex.command def get_extra_checkpoint_callback(save_last_n=None): if save_last_n is None: return [] return [ModelCheckpoint(monitor="step", verbose=True, save_top_k=save_last_n, mode='max')] @ex.command def get_extra_swa_callback(swa=True, swa_epoch_start=50, swa_freq=5): if not swa: return [] print("\n Using swa!\n") if pytorch_lightning.__version__ <= "1.5": from helpers.swa_legacy import StochasticWeightAveraging else: from helpers.swa_callback import StochasticWeightAveraging return [StochasticWeightAveraging(swa_epoch_start=swa_epoch_start, swa_freq=swa_freq)] @ex.command def main(_run, _config, _log, _rnd, _seed, watch_model=True): global wandb_logger wandb_logger = get_logger() trainer = get_trainer(logger=wandb_logger) train_loader = get_train_loader() val_loader = get_validate_loader() modul = M(ex) if watch_model: # change log frequency of gradients and parameters (100 steps by default) wandb_logger.watch(modul, log_freq=1000, log="all" ) if pytorch_lightning.__version__ <= "1.5": trainer.fit( modul, train_dataloader=train_loader, val_dataloaders=val_loader, ) else: trainer.fit( modul, train_dataloaders=train_loader, val_dataloaders=val_loader, ) return {"done": True} @ex.command def model_speed_test(_run, _config, _log, _rnd, _seed, speed_test_batch_size=12): ''' Test training speed of a model @param _run: @param _config: @param _log: @param _rnd: @param _seed: @param speed_test_batch_size: the batch size during the test @return: ''' modul = M(ex) modul = modul.cuda() batch_size = speed_test_batch_size print(f"\nBATCH SIZE : {batch_size}\n") test_length = 100 print(f"\ntest_length : {test_length}\n") x = torch.ones([batch_size, 1, 128, 998]).cuda() target = torch.ones([batch_size, 527]).cuda() # one pass net = modul.net # net(x) scaler = torch.cuda.amp.GradScaler()
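One caveat about the `pytorch_lightning.__version__ <= "1.5"` checks above: string comparison is lexicographic, so multi-digit components are misordered (`"1.10.0" <= "1.5"` is `True`, although 1.10 is the newer release). A small sketch of a numeric comparison (`version_tuple` is an illustrative helper, not part of the repo):

```python
def version_tuple(v: str) -> tuple:
    """Parse the leading numeric components of a version string:
    "1.3.0.dev0" -> (1, 3, 0); non-numeric suffixes are dropped."""
    parts = []
    for p in v.split("."):
        if not p.isdigit():
            break
        parts.append(int(p))
    return tuple(parts)

print("1.10.0" <= "1.5")                                # True  (lexicographic, wrong)
print(version_tuple("1.10.0") <= version_tuple("1.5"))  # False (numeric, correct)
```

The `packaging.version.Version` class offers the same behavior with full PEP 440 support.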
torch.backends.cudnn.benchmark = True # net = torch.jit.trace(net,(x,)) net = torch.compile(net) optimizer = torch.optim.SGD(net.parameters(), lr=0.001) print("warmup") import time torch.cuda.synchronize() t1 = time.time() for i in range(10): with torch.cuda.amp.autocast(): y_hat, embed = net(x) loss = F.binary_cross_entropy_with_logits( y_hat, target, reduction="none").mean() scaler.scale(loss).backward() scaler.step(optimizer) scaler.update() torch.cuda.synchronize() t2 = time.time() print('warmup done:', (t2 - t1)) torch.cuda.synchronize() t1 = time.time() print("testing speed") for i in range(test_length): with torch.cuda.amp.autocast(): y_hat, embed = net(x) loss = F.binary_cross_entropy_with_logits( y_hat, target, reduction="none").mean() scaler.scale(loss).backward() scaler.step(optimizer) scaler.update() torch.cuda.synchronize() t2 = time.time() print('test done:', (t2 - t1)) print("average speed: ", (test_length * batch_size) / (t2 - t1), " specs/second") @ex.command def evaluate_only(_run, _config, _log, _rnd, _seed): # force overriding the config, not logged = not recommended trainer = get_trainer(logger=get_logger()) val_loader = get_validate_loader() modul = M(ex) modul.val_dataloader = None #trainer.val_dataloaders = None print(f"\n\nValidation len={len(val_loader)}\n") res = trainer.validate(modul, dataloaders=val_loader) print("\n\n Validation:") print(res) @ex.command def test_loaders(): ''' get one sample from each loader for debugging @return: ''' for i, b in enumerate(ex.datasets.training.get_iter()): print(b) break for i, b in enumerate(ex.datasets.test.get_iter()): print(b) break def set_default_json_pickle(obj): if isinstance(obj, set): return list(obj) raise TypeError @ex.command def preload_mp3(all_y=CMD("/basedataset.preload_mp3")): ''' read the dataset sequentially, useful if you have a network cache @param all_y: the dataset preload command @return: ''' print(all_y.shape) def multiprocessing_run(rank, word_size): print("rank ", rank,
os.getpid()) print("word_size ", word_size) os.environ['NODE_RANK'] = str(rank) os.environ['CUDA_VISIBLE_DEVICES'] = os.environ['CUDA_VISIBLE_DEVICES'].split(",")[ rank] argv = sys.argv if rank != 0: print(f"Unobserved {os.getpid()} with rank {rank}") argv = argv + ["-u"] # only rank 0 is observed if "with" not in argv: argv = argv + ["with"] argv = argv + \ [f"trainer.num_nodes={word_size}", f"trainer.accelerator=ddp"] print(argv) @ex.main def default_command(): return main() ex.run_commandline(argv) if __name__ == '__main__': # setting DDP=2 forks two processes to run on two GPUs # the environment variable "DDP" defines the number of processes to fork # With 2x 2080ti you can train the full model to .47 in around 24 hours # you may need to set NCCL_P2P_DISABLE=1 word_size = os.environ.get("DDP", None) if word_size: import random word_size = int(word_size) print(f"\n\nDDP TRAINING WITH WORD_SIZE={word_size}\n\n") os.environ['MASTER_ADDR'] = '127.0.0.1' # plz no collisions os.environ['MASTER_PORT'] = f"{9999 + random.randint(0, 9999)}" os.environ['PL_IN_DDP_SUBPROCESS'] = '1' for rank in range(word_size): pid = os.fork() if pid == 0: print("Child Forked ") multiprocessing_run(rank, word_size) exit(0) pid, exit_code = os.wait() print(pid, exit_code) exit(0) print("__main__ is running pid", os.getpid(), "in module main: ", __name__) @ex.automain def default_command(): return main() ================================================ FILE: ex_esc50.py ================================================ import os import sys import torch from pytorch_lightning.callbacks import ModelCheckpoint from sacred.config_helpers import DynamicIngredient, CMD from torch.nn import functional as F import numpy as np from ba3l.experiment import Experiment from ba3l.module import Ba3lModule from torch.utils.data import DataLoader from config_updates import add_configs from helpers.mixup import my_mixup from helpers.models_size import count_non_zero_params from helpers.ramp import
exp_warmup_linear_down, cosine_cycle from helpers.workersinit import worker_init_fn from sklearn import metrics from pytorch_lightning import Trainer as plTrainer from pytorch_lightning.loggers import WandbLogger ex = Experiment("passt_esc50") # Example call with all the default config: # python ex_esc50.py with trainer.precision=16 -p -m mongodb_server:27000:audioset21_balanced -c "ESC50 PaSST base" # with 2 gpus: # DDP=2 python ex_esc50.py with trainer.precision=16 -p -m mongodb_server:27000:audioset21_balanced -c "ESC50 PaSST base" # capture the config of the trainer with the prefix "trainer", this allows to use sacred to update PL trainer config get_trainer = ex.command(plTrainer, prefix="trainer") # capture the WandbLogger and prefix it with "wandb", this allows to use sacred to update WandbLogger config from the command line get_logger = ex.command(WandbLogger, prefix="wandb") # define datasets and loaders get_train_loader = ex.datasets.training.iter(DataLoader, static_args=dict(worker_init_fn=worker_init_fn), train=True, batch_size=12, num_workers=16, shuffle=None, dataset=CMD("/basedataset.get_training_set"), ) get_validate_loader = ex.datasets.test.iter(DataLoader, static_args=dict(worker_init_fn=worker_init_fn), validate=True, batch_size=20, num_workers=16, dataset=CMD("/basedataset.get_test_set")) @ex.config def default_conf(): cmd = " ".join(sys.argv) saque_cmd = os.environ.get("SAQUE_CMD", "").strip() saque_id = os.environ.get("SAQUE_ID", "").strip() slurm_job_id = os.environ.get("SLURM_JOB_ID", "").strip() if os.environ.get("SLURM_ARRAY_JOB_ID", False): slurm_job_id = os.environ.get("SLURM_ARRAY_JOB_ID", "").strip() + "_" + os.environ.get("SLURM_ARRAY_TASK_ID", "").strip() process_id = os.getpid() models = { "net": DynamicIngredient("models.passt.model_ing", n_classes=50, s_patchout_t=10, s_patchout_f=3), "mel": DynamicIngredient("models.preprocess.model_ing", instance_cmd="AugmentMelSTFT", n_mels=128, sr=32000, win_length=800, hopsize=320, 
n_fft=1024, freqm=48, timem=80, htk=False, fmin=0.0, fmax=None, norm=1, fmin_aug_range=10, fmax_aug_range=2000) } basedataset = DynamicIngredient("esc50.dataset.dataset") trainer = dict(max_epochs=10, gpus=1, weights_summary='full', benchmark=True, num_sanity_val_steps=0, reload_dataloaders_every_epoch=True) wandb = dict(project="passt_esc50", log_model=True) lr = 0.00001 use_mixup = True mixup_alpha = 0.3 # register extra possible configs add_configs(ex) @ex.command def get_scheduler_lambda(warm_up_len=5, ramp_down_start=50, ramp_down_len=50, last_lr_value=0.01, schedule_mode="exp_lin"): if schedule_mode == "exp_lin": return exp_warmup_linear_down(warm_up_len, ramp_down_len, ramp_down_start, last_lr_value) if schedule_mode == "cos_cyc": return cosine_cycle(warm_up_len, ramp_down_start, last_lr_value) raise RuntimeError(f"schedule_mode={schedule_mode} Unknown for a lambda function.") @ex.command def get_lr_scheduler(optimizer, schedule_mode): if schedule_mode in {"exp_lin", "cos_cyc"}: return torch.optim.lr_scheduler.LambdaLR(optimizer, get_scheduler_lambda()) raise RuntimeError(f"schedule_mode={schedule_mode} Unknown.") @ex.command def get_optimizer(params, lr, adamw=True, weight_decay=0.0001): if adamw: print(f"\nUsing adamw weight_decay={weight_decay}!\n") return torch.optim.AdamW(params, lr=lr, weight_decay=weight_decay) return torch.optim.Adam(params, lr=lr) class M(Ba3lModule): def __init__(self, experiment): self.mel = None self.da_net = None super(M, self).__init__(experiment) self.use_mixup = self.config.use_mixup or False self.mixup_alpha = self.config.mixup_alpha desc, sum_params, sum_non_zero = count_non_zero_params(self.net) self.experiment.info["start_sum_params"] = sum_params self.experiment.info["start_sum_params_non_zero"] = sum_non_zero # in case we need embeddings for the DA self.net.return_embed = True self.dyn_norm = self.config.dyn_norm self.do_swa = False self.distributed_mode = self.config.trainer.num_nodes > 1 def forward(self, x): return
self.net(x) def mel_forward(self, x): old_shape = x.size() x = x.reshape(-1, old_shape[2]) x = self.mel(x) x = x.reshape(old_shape[0], old_shape[1], x.shape[1], x.shape[2]) if self.dyn_norm: if not hasattr(self, "tr_m") or not hasattr(self, "tr_std"): tr_m, tr_std = get_dynamic_norm(self) self.register_buffer('tr_m', tr_m) self.register_buffer('tr_std', tr_std) x = (x - self.tr_m) / self.tr_std return x def training_step(self, batch, batch_idx): # REQUIRED x, f, y = batch if self.mel: x = self.mel_forward(x) orig_x = x batch_size = len(y) rn_indices, lam = None, None if self.use_mixup: rn_indices, lam = my_mixup(batch_size, self.mixup_alpha) lam = lam.to(x.device) x = x * lam.reshape(batch_size, 1, 1, 1) + x[rn_indices] * (1. - lam.reshape(batch_size, 1, 1, 1)) y_hat, embed = self.forward(x) if self.use_mixup: # y_mix = y * lam.reshape(batch_size, 1) + y[rn_indices] * (1. - lam.reshape(batch_size, 1)) samples_loss = (F.cross_entropy(y_hat, y, reduction="none") * lam.reshape(batch_size) + F.cross_entropy(y_hat, y[rn_indices], reduction="none") * (1. 
- lam.reshape(batch_size))) loss = samples_loss.mean() samples_loss = samples_loss.detach() else: samples_loss = F.cross_entropy(y_hat, y, reduction="none") loss = samples_loss.mean() samples_loss = samples_loss.detach() results = {"loss": loss, } return results def training_epoch_end(self, outputs): avg_loss = torch.stack([x['loss'] for x in outputs]).mean() logs = {'train.loss': avg_loss, 'step': self.current_epoch} self.log_dict(logs, sync_dist=True) def predict(self, batch, batch_idx: int, dataloader_idx: int = None): x, f, y = batch if self.mel: x = self.mel_forward(x) y_hat, _ = self.forward(x) return f, y_hat def validation_step(self, batch, batch_idx): x, f, y = batch if self.mel: x = self.mel_forward(x) results = {} model_name = [("", self.net)] if self.do_swa: model_name = model_name + [("swa_", self.net_swa)] for net_name, net in model_name: y_hat, _ = net(x) samples_loss = F.cross_entropy(y_hat, y, reduction="none") loss = samples_loss.mean() _, preds = torch.max(y_hat, dim=1) n_correct_pred_per_sample = (preds == y) n_correct_pred = n_correct_pred_per_sample.sum() # self.log("validation.loss", loss, prog_bar=True, on_epoch=True, on_step=False) results = {**results, net_name + "val_loss": loss, net_name + "n_correct_pred": torch.as_tensor(n_correct_pred), net_name + "n_pred": torch.as_tensor( len(y)) } results = {k: v.cpu() for k, v in results.items()} return results def validation_epoch_end(self, outputs): model_name = [("", self.net)] if self.do_swa: model_name = model_name + [("swa_", self.net_swa)] for net_name, net in model_name: avg_loss = torch.stack([x[net_name + 'val_loss'] for x in outputs]).mean() val_acc = sum([x[net_name + 'n_correct_pred'] for x in outputs]) * 1.0 / sum(x[net_name + 'n_pred'] for x in outputs) logs = {net_name + 'val.loss': torch.as_tensor(avg_loss).cuda(), net_name + 'acc': torch.as_tensor(val_acc).cuda(), 'step': torch.as_tensor(self.current_epoch).cuda()} self.log_dict(logs, sync_dist=True) def
configure_optimizers(self): # REQUIRED # can return multiple optimizers and learning_rate schedulers # (LBFGS it is automatically supported, no need for closure function) optimizer = get_optimizer(self.parameters()) # torch.optim.Adam(self.parameters(), lr=self.config.lr) return { 'optimizer': optimizer, 'lr_scheduler': get_lr_scheduler(optimizer) } def configure_callbacks(self): return get_extra_checkpoint_callback() + get_extra_swa_callback() @ex.command def get_dynamic_norm(model, dyn_norm=False): if not dyn_norm: return None, None raise RuntimeError('no dynamic norm supported yet.') @ex.command def get_extra_checkpoint_callback(save_last_n=None): if save_last_n is None: return [] return [ModelCheckpoint(monitor="step", verbose=True, save_top_k=save_last_n, mode='max')] @ex.command def get_extra_swa_callback(swa=True, swa_epoch_start=2, swa_freq=1): if not swa: return [] print("\n Using swa!\n") from helpers.swa_callback import StochasticWeightAveraging return [StochasticWeightAveraging(swa_epoch_start=swa_epoch_start, swa_freq=swa_freq)] @ex.command def main(_run, _config, _log, _rnd, _seed): trainer = get_trainer() train_loader = get_train_loader() val_loader = get_validate_loader() modul = M(ex) trainer.fit( modul, train_dataloaders=train_loader, val_dataloaders=val_loader, ) return {"done": True} @ex.command def model_speed_test(_run, _config, _log, _rnd, _seed, speed_test_batch_size=100): ''' Test training speed of a model @param _run: @param _config: @param _log: @param _rnd: @param _seed: @param speed_test_batch_size: the batch size during the test @return: ''' modul = M(ex) modul = modul.cuda() batch_size = speed_test_batch_size print(f"\nBATCH SIZE : {batch_size}\n") test_length = 100 print(f"\ntest_length : {test_length}\n") x = torch.ones([batch_size, 1, 128, 998]).cuda() target = torch.ones([batch_size, 527]).cuda() # one passe net = modul.net # net(x) scaler = torch.cuda.amp.GradScaler() torch.backends.cudnn.benchmark = True # net = 
torch.jit.trace(net,(x,)) optimizer = torch.optim.SGD(net.parameters(), lr=0.001) print("warmup") import time torch.cuda.synchronize() t1 = time.time() for i in range(10): with torch.cuda.amp.autocast(): y_hat, embed = net(x) loss = F.binary_cross_entropy_with_logits(y_hat, target, reduction="none").mean() scaler.scale(loss).backward() scaler.step(optimizer) scaler.update() torch.cuda.synchronize() t2 = time.time() print('warmup done:', (t2 - t1)) torch.cuda.synchronize() t1 = time.time() print("testing speed") for i in range(test_length): with torch.cuda.amp.autocast(): y_hat, embed = net(x) loss = F.binary_cross_entropy_with_logits(y_hat, target, reduction="none").mean() scaler.scale(loss).backward() scaler.step(optimizer) scaler.update() torch.cuda.synchronize() t2 = time.time() print('test done:', (t2 - t1)) print("average speed: ", (test_length * batch_size) / (t2 - t1), " specs/second") @ex.command def evaluate_only(_run, _config, _log, _rnd, _seed): # force overriding the config; overrides are not logged (not recommended) trainer = get_trainer() train_loader = get_train_loader() val_loader = get_validate_loader() modul = M(ex) modul.val_dataloader = None #trainer.val_dataloaders = None print(f"\n\nValidation len={len(val_loader)}\n") res = trainer.validate(modul, dataloaders=val_loader) print("\n\n Validation:") print(res) @ex.command def test_loaders(): ''' get one sample from each loader for debugging @return: ''' for i, b in enumerate(ex.datasets.training.get_iter()): print(b) break for i, b in enumerate(ex.datasets.test.get_iter()): print(b) break def set_default_json_pickle(obj): if isinstance(obj, set): return list(obj) raise TypeError def multiprocessing_run(rank, word_size): print("rank ", rank, os.getpid()) print("word_size ", word_size) os.environ['NODE_RANK'] = str(rank) os.environ['CUDA_VISIBLE_DEVICES'] = os.environ['CUDA_VISIBLE_DEVICES'].split(",")[rank] argv = sys.argv if rank != 0: print(f"Unobserved {os.getpid()} with rank {rank}") argv = argv + ["-u"] #
only rank 0 is observed if "with" not in argv: argv = argv + ["with"] argv = argv + [f"trainer.num_nodes={word_size}", f"trainer.accelerator=ddp"] print(argv) @ex.main def default_command(): return main() ex.run_commandline(argv) if __name__ == '__main__': # set DDP=2 forks two processes to run on two GPUs # the environment variable "DDP" define the number of processes to fork # With two 2x 2080ti you can train the full model to .47 in around 24 hours # you may need to set NCCL_P2P_DISABLE=1 word_size = os.environ.get("DDP", None) if word_size: import random word_size = int(word_size) print(f"\n\nDDP TRAINING WITH WORD_SIZE={word_size}\n\n") os.environ['MASTER_ADDR'] = '127.0.0.1' os.environ['MASTER_PORT'] = f"{9999 + random.randint(0, 9999)}" # plz no collisions os.environ['PL_IN_DDP_SUBPROCESS'] = '1' for rank in range(word_size): pid = os.fork() if pid == 0: print("Child Forked ") multiprocessing_run(rank, word_size) exit(0) pid, exit_code = os.wait() print(pid, exit_code) exit(0) print("__main__ is running pid", os.getpid(), "in module main: ", __name__) @ex.automain def default_command(): return main() ================================================ FILE: ex_fsd50k.py ================================================ import os import sys import torch from pytorch_lightning.callbacks import ModelCheckpoint from sacred.config_helpers import DynamicIngredient, CMD from torch.nn import functional as F import numpy as np from ba3l.experiment import Experiment from ba3l.module import Ba3lModule from torch.utils.data import DataLoader from config_updates import add_configs from helpers.mixup import my_mixup from helpers.models_size import count_non_zero_params from helpers.ramp import exp_warmup_linear_down, cosine_cycle from helpers.workersinit import worker_init_fn from sklearn import metrics from pytorch_lightning import Trainer as plTrainer from pytorch_lightning.loggers import WandbLogger ex = Experiment("fsd50k") # capture the config of the trainer with the 
prefix "trainer", this allows to use sacred to update PL trainer config get_trainer = ex.command(plTrainer, prefix="trainer") # capture the WandbLogger and prefix it with "wandb", this allows to use sacred to update WandbLogger config from the command line get_logger = ex.command(WandbLogger, prefix="wandb") # Example call with all the default config: # python ex_audioset.py with trainer.precision=16 models.net.arch=passt_deit_bd_p16_384 -p -m mongodb_server:27000:audioset21_balanced -c "PaSST base" # with 2 gpus: # DDP=2 python ex_audioset.py with trainer.precision=16 models.net.arch=passt_deit_bd_p16_384 -p -m mongodb_server:27000:audioset21_balanced -c "PaSST base 2 GPU" # define datasets and loaders get_train_loader = ex.datasets.training.iter(DataLoader, static_args=dict(worker_init_fn=worker_init_fn), train=True, batch_size=12, num_workers=16, shuffle=True, dataset=CMD("/basedataset.get_training_set"), ) get_validate_loader = ex.datasets.valid.iter(DataLoader, static_args=dict(worker_init_fn=worker_init_fn), validate=True, batch_size=10, num_workers=16, dataset=CMD("/basedataset.get_valid_set")) get_eval_loader = ex.datasets.eval.iter(DataLoader, static_args=dict(worker_init_fn=worker_init_fn), validate=True, batch_size=10, num_workers=16, dataset=CMD("/basedataset.get_eval_set")) @ex.named_config def variable_eval(): basedataset = dict(variable_eval=True) datasets = dict(valid=dict(batch_size=1), eval=dict(batch_size=1)) @ex.config def default_conf(): cmd = " ".join(sys.argv) # command line arguments saque_cmd = os.environ.get("SAQUE_CMD", "").strip() saque_id = os.environ.get("SAQUE_ID", "").strip() slurm_job_id = os.environ.get("SLURM_JOB_ID", "").strip() if os.environ.get("SLURM_ARRAY_JOB_ID", False): slurm_job_id = os.environ.get("SLURM_ARRAY_JOB_ID", "").strip() + "_" + os.environ.get("SLURM_ARRAY_TASK_ID", "").strip() process_id = os.getpid() models = { "net": DynamicIngredient("models.passt.model_ing", n_classes=200, s_patchout_t=10, s_patchout_f=4), 
# network config "mel": DynamicIngredient("models.preprocess.model_ing", instance_cmd="AugmentMelSTFT", n_mels=128, sr=32000, win_length=800, hopsize=320, n_fft=1024, freqm=0, timem=0, htk=False, fmin=0.0, fmax=None, norm=1, fmin_aug_range=10, fmax_aug_range=2000) } # set the default name for wandb logger wandb = dict(project="passt_fsd50k", log_model=True) basedataset = DynamicIngredient("fsd50k.dataset.dataset", wavmix=1) trainer = dict(max_epochs=50, gpus=1, weights_summary='full', benchmark=True, num_sanity_val_steps=0, reload_dataloaders_every_epoch=True) lr = 0.00001 # learning rate use_mixup = True mixup_alpha = 0.3 # register extra possible configs add_configs(ex) @ex.command def get_scheduler_lambda(warm_up_len=5, ramp_down_start=10, ramp_down_len=10, last_lr_value=0.01, schedule_mode="exp_lin"): if schedule_mode == "exp_lin": return exp_warmup_linear_down(warm_up_len, ramp_down_len, ramp_down_start, last_lr_value) if schedule_mode == "cos_cyc": return cosine_cycle(warm_up_len, ramp_down_start, last_lr_value) raise RuntimeError(f"schedule_mode={schedule_mode} Unknown for a lambda function.") @ex.command def get_lr_scheduler(optimizer, schedule_mode): if schedule_mode in {"exp_lin", "cos_cyc"}: return torch.optim.lr_scheduler.LambdaLR(optimizer, get_scheduler_lambda()) raise RuntimeError(f"schedule_mode={schedule_mode} Unknown.") @ex.command def get_optimizer(params, lr, adamw=True, weight_decay=0.0001): if adamw: print(f"\nUsing adamw weight_decay={weight_decay}!\n") return torch.optim.AdamW(params, lr=lr, weight_decay=weight_decay) return torch.optim.Adam(params, lr=lr) class M(Ba3lModule): def __init__(self, experiment): self.mel = None self.da_net = None super(M, self).__init__(experiment) self.use_mixup = self.config.use_mixup or False self.mixup_alpha = self.config.mixup_alpha desc, sum_params, sum_non_zero = count_non_zero_params(self.net) self.experiment.info["start_sum_params"] = sum_params self.experiment.info["start_sum_params_non_zero"] =
sum_non_zero # in case we need embeddings for the DA self.net.return_embed = True self.dyn_norm = self.config.dyn_norm self.do_swa = False self.valid_names = ["valid", "eval"] self.distributed_mode = self.config.trainer.num_nodes > 1 def forward(self, x): return self.net(x) def mel_forward(self, x): old_shape = x.size() x = x.reshape(-1, old_shape[2]) x = self.mel(x) x = x.reshape(old_shape[0], old_shape[1], x.shape[1], x.shape[2]) if self.dyn_norm: if not hasattr(self, "tr_m") or not hasattr(self, "tr_std"): tr_m, tr_std = get_dynamic_norm(self) self.register_buffer('tr_m', tr_m) self.register_buffer('tr_std', tr_std) x = (x - self.tr_m) / self.tr_std return x def training_step(self, batch, batch_idx): # REQUIRED x, f, y = batch if self.mel: x = self.mel_forward(x) orig_x = x batch_size = len(y) rn_indices, lam = None, None if self.use_mixup: rn_indices, lam = my_mixup(batch_size, self.mixup_alpha) lam = lam.to(x.device) x = x * lam.reshape(batch_size, 1, 1, 1) + x[rn_indices] * (1. - lam.reshape(batch_size, 1, 1, 1)) y_hat, embed = self.forward(x) if self.use_mixup: y_mix = y * lam.reshape(batch_size, 1) + y[rn_indices] * (1.
- lam.reshape(batch_size, 1)) samples_loss = F.binary_cross_entropy_with_logits( y_hat, y_mix, reduction="none") loss = samples_loss.mean() samples_loss = samples_loss.detach() else: samples_loss = F.binary_cross_entropy_with_logits(y_hat, y, reduction="none") loss = samples_loss.mean() samples_loss = samples_loss.detach() results = {"loss": loss, } return results def training_epoch_end(self, outputs): avg_loss = torch.stack([x['loss'] for x in outputs]).mean() logs = {'train.loss': avg_loss, 'step': self.current_epoch} self.log_dict(logs, sync_dist=True) def predict(self, batch, batch_idx: int, dataloader_idx: int = None): x, f, y = batch if self.mel: x = self.mel_forward(x) y_hat, _ = self.forward(x) return f, y_hat def validation_step(self, batch, batch_idx, dataloader_idx): x, f, y = batch if self.mel: x = self.mel_forward(x) results = {} model_name = [("", self.net)] if self.do_swa: model_name = model_name + [("swa_", self.net_swa)] for net_name, net in model_name: y_hat, _ = net(x) samples_loss = F.binary_cross_entropy_with_logits(y_hat, y) loss = samples_loss.mean() out = torch.sigmoid(y_hat.detach()) # self.log("validation.loss", loss, prog_bar=True, on_epoch=True, on_step=False) results = {**results, net_name + "val_loss": loss, net_name + "out": out, net_name + "target": y.detach()} results = {k: v.cpu() for k, v in results.items()} return results def validation_epoch_end(self, outputs): for idx, one_outputs in enumerate(outputs): set_name = self.valid_names[idx] + "_" model_name = [("", self.net)] if self.do_swa: model_name = model_name + [("swa_", self.net_swa)] for net_name, net in model_name: avg_loss = torch.stack([x[net_name + 'val_loss'] for x in one_outputs]).mean() out = torch.cat([x[net_name + 'out'] for x in one_outputs], dim=0) target = torch.cat([x[net_name + 'target'] for x in one_outputs], dim=0) try: average_precision = metrics.average_precision_score( target.float().numpy(), out.float().numpy(), average=None) except ValueError: 
average_precision = np.array([np.nan] * 200) try: roc = metrics.roc_auc_score(target.numpy(), out.numpy(), average=None) except ValueError: roc = np.array([np.nan] * 200) logs = {set_name + net_name + 'val.loss': torch.as_tensor(avg_loss).cuda(), set_name + net_name + 'ap': torch.as_tensor(average_precision.mean()).cuda(), set_name + net_name + 'roc': torch.as_tensor(roc.mean()).cuda(), 'step': torch.as_tensor(self.current_epoch).cuda()} # torch.save(average_precision, f"ap_perclass_{average_precision.mean()}.pt") # print(average_precision) self.log_dict(logs, sync_dist=True) if self.distributed_mode: allout = self.all_gather(out) alltarget = self.all_gather(target) average_precision = metrics.average_precision_score( alltarget.reshape(-1, alltarget.shape[-1]).cpu().numpy(), allout.reshape(-1, allout.shape[-1]).cpu().numpy(), average=None) if self.trainer.is_global_zero: logs = {set_name + net_name + "allap": torch.as_tensor(average_precision.mean()).cuda(), 'step': torch.as_tensor(self.current_epoch).cuda()} self.log_dict(logs, sync_dist=False) else: self.log_dict( {set_name + net_name + "allap": logs[set_name + net_name + 'ap'], 'step': logs['step']}, sync_dist=True) def configure_optimizers(self): # REQUIRED # can return multiple optimizers and learning_rate schedulers # (LBFGS it is automatically supported, no need for closure function) optimizer = get_optimizer(self.parameters()) # torch.optim.Adam(self.parameters(), lr=self.config.lr) return { 'optimizer': optimizer, 'lr_scheduler': get_lr_scheduler(optimizer) } def configure_callbacks(self): return get_extra_checkpoint_callback() + get_extra_swa_callback() @ex.command def get_dynamic_norm(model, dyn_norm=False): if not dyn_norm: return None, None raise RuntimeError('no dynamic norm supported yet.') model_checkpoint_callback = None @ex.command def get_extra_checkpoint_callback(save_best=None): if save_best is None: return [] global model_checkpoint_callback model_checkpoint_callback = 
ModelCheckpoint(monitor="allap", verbose=True, save_top_k=save_best, mode='max', every_n_val_epochs=1, every_n_train_steps=0) return [model_checkpoint_callback] @ex.command def get_extra_swa_callback(swa=True, swa_epoch_start=10, swa_freq=3): if not swa: return [] print("\n Using swa!\n") from helpers.swa_callback import StochasticWeightAveraging return [StochasticWeightAveraging(swa_epoch_start=swa_epoch_start, swa_freq=swa_freq)] @ex.command def main(_run, _config, _log, _rnd, _seed): trainer = get_trainer(logger=get_logger()) train_loader = get_train_loader() val_loader = get_validate_loader() eval_loader = get_eval_loader() # eval_loader = get_eval_loader() modul = M(ex) trainer.fit( modul, train_dataloader=train_loader, val_dataloaders=[val_loader, eval_loader], ) ## evaluate best model on eval set #trainer.val_dataloaders = None modul.val_dataloaders = None return {"done": True} @ex.command def model_speed_test(_run, _config, _log, _rnd, _seed, speed_test_batch_size=100): ''' Test training speed of a model @param _run: @param _config: @param _log: @param _rnd: @param _seed: @param speed_test_batch_size: the batch size during the test @return: ''' modul = M(ex) modul = modul.cuda() batch_size = speed_test_batch_size print(f"\nBATCH SIZE : {batch_size}\n") test_length = 100 print(f"\ntest_length : {test_length}\n") x = torch.ones([batch_size, 1, 128, 998]).cuda() target = torch.ones([batch_size, 527]).cuda() # one passe net = modul.net # net(x) scaler = torch.cuda.amp.GradScaler() torch.backends.cudnn.benchmark = True # net = torch.jit.trace(net,(x,)) optimizer = torch.optim.SGD(net.parameters(), lr=0.001) print("warmup") import time torch.cuda.synchronize() t1 = time.time() for i in range(10): with torch.cuda.amp.autocast(): y_hat, embed = net(x) loss = F.binary_cross_entropy_with_logits(y_hat, target, reduction="none").mean() scaler.scale(loss).backward() scaler.step(optimizer) scaler.update() torch.cuda.synchronize() t2 = time.time() print('warmup done:', 
(t2 - t1)) torch.cuda.synchronize() t1 = time.time() print("testing speed") for i in range(test_length): with torch.cuda.amp.autocast(): y_hat, embed = net(x) loss = F.binary_cross_entropy_with_logits(y_hat, target, reduction="none").mean() scaler.scale(loss).backward() scaler.step(optimizer) scaler.update() torch.cuda.synchronize() t2 = time.time() print('test done:', (t2 - t1)) print("average speed: ", (test_length * batch_size) / (t2 - t1), " specs/second") @ex.command def evaluate_only(_run, _config, _log, _rnd, _seed): # force overriding the config; overrides are not logged (not recommended) trainer = get_trainer() train_loader = get_train_loader() val_loader = get_validate_loader() modul = M(ex) modul.val_dataloader = None #trainer.val_dataloaders = None print(f"\n\nValidation len={len(val_loader)}\n") res = trainer.validate(modul, dataloaders=val_loader) print("\n\n Validation:") print(res) @ex.command def test_loaders(): ''' get one sample from each loader for debugging @return: ''' for i, b in enumerate(ex.datasets.training.get_iter()): print(b) break for i, b in enumerate(ex.datasets.test.get_iter()): print(b) break def set_default_json_pickle(obj): if isinstance(obj, set): return list(obj) raise TypeError @ex.command def preload_mp3(all_y=CMD("/basedataset.preload_mp3")): ''' read the dataset sequentially, useful if you have a network cache @param all_y: the dataset preload command @return: ''' print(all_y.shape) def multiprocessing_run(rank, word_size): print("rank ", rank, os.getpid()) print("word_size ", word_size) os.environ['NODE_RANK'] = str(rank) os.environ['CUDA_VISIBLE_DEVICES'] = os.environ['CUDA_VISIBLE_DEVICES'].split(",")[rank] argv = sys.argv if rank != 0: print(f"Unobserved {os.getpid()} with rank {rank}") argv = argv + ["-u"] # only rank 0 is observed if "with" not in argv: argv = argv + ["with"] argv = argv + [f"trainer.num_nodes={word_size}", f"trainer.accelerator=ddp"] print(argv) @ex.main def default_command(): return main()
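The forked DDP children re-enter the experiment through sacred's command line, so `multiprocessing_run` rewrites `sys.argv` before handing it to `ex.run_commandline`: non-zero ranks get `-u` (unobserved by sacred), a `with` keyword is guaranteed to be present, and the DDP trainer overrides are appended. That rewriting can be sketched as a pure function (the helper name `build_child_argv` is hypothetical, introduced only for illustration):

```python
def build_child_argv(argv, rank, world_size):
    """Sketch of the argv rewriting in multiprocessing_run: only rank 0
    stays observed by sacred; every child gets the DDP overrides."""
    argv = list(argv)  # do not mutate the caller's sys.argv
    if rank != 0:
        argv = argv + ["-u"]  # run this rank unobserved
    if "with" not in argv:
        argv = argv + ["with"]  # sacred needs "with" before config updates
    return argv + [f"trainer.num_nodes={world_size}",
                   "trainer.accelerator=ddp"]
```

Appending the overrides last means they win over any user-supplied `trainer.*` values in sacred's left-to-right config resolution.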
ex.run_commandline(argv) if __name__ == '__main__': # set DDP=2 forks two processes to run on two GPUs # the environment variable "DDP" define the number of processes to fork # With two 2x 2080ti you can train the full model to .47 in around 24 hours # you may need to set NCCL_P2P_DISABLE=1 word_size = os.environ.get("DDP", None) if word_size: import random word_size = int(word_size) print(f"\n\nDDP TRAINING WITH WORD_SIZE={word_size}\n\n") os.environ['MASTER_ADDR'] = '127.0.0.1' os.environ['MASTER_PORT'] = f"{9999 + random.randint(0, 9999)}" # plz no collisions os.environ['PL_IN_DDP_SUBPROCESS'] = '1' for rank in range(word_size): pid = os.fork() if pid == 0: print("Child Forked ") multiprocessing_run(rank, word_size) exit(0) pid, exit_code = os.wait() print(pid, exit_code) exit(0) print("__main__ is running pid", os.getpid(), "in module main: ", __name__) @ex.automain def default_command(): return main() ================================================ FILE: ex_openmic.py ================================================ import os import sys import torch from pytorch_lightning.callbacks import ModelCheckpoint from sacred.config_helpers import DynamicIngredient, CMD from torch.nn import functional as F import numpy as np from ba3l.experiment import Experiment from ba3l.module import Ba3lModule from torch.utils.data import DataLoader from config_updates import add_configs from helpers.mixup import my_mixup from helpers.models_size import count_non_zero_params from helpers.ramp import exp_warmup_linear_down, cosine_cycle from helpers.workersinit import worker_init_fn from sklearn import metrics from pytorch_lightning import Trainer as plTrainer from pytorch_lightning.loggers import WandbLogger ex = Experiment("openmic") # capture the config of the trainer with the prefix "trainer", this allows to use sacred to update PL trainer config get_trainer = ex.command(plTrainer, prefix="trainer") # capture the WandbLogger and prefix it with "wandb", this allows to use sacred 
to update WandbLogger config from the command line get_logger = ex.command(WandbLogger, prefix="wandb") # Example call with all the default config: # python ex_openmic.py with trainer.precision=16 -p -m mongodb_server:27000:audioset21_balanced -c "OpenMIC PaSST base" # with 2 gpus: # DDP=2 python ex_openmic.py with trainer.precision=16 -p -m mongodb_server:27000:audioset21_balanced -c "OpenMIC PaSST base" # define datasets and loaders get_train_loader = ex.datasets.training.iter(DataLoader, static_args=dict(worker_init_fn=worker_init_fn), train=True, batch_size=6, num_workers=16, shuffle=None, dataset=CMD("/basedataset.get_training_set"), ) get_validate_loader = ex.datasets.test.iter(DataLoader, static_args=dict(worker_init_fn=worker_init_fn), validate=True, batch_size=20, num_workers=16, dataset=CMD("/basedataset.get_test_set")) @ex.config def default_conf(): cmd = " ".join(sys.argv) saque_cmd = os.environ.get("SAQUE_CMD", "").strip() saque_id = os.environ.get("SAQUE_ID", "").strip() slurm_job_id = os.environ.get("SLURM_JOB_ID", "").strip() if os.environ.get("SLURM_ARRAY_JOB_ID", False): slurm_job_id = os.environ.get("SLURM_ARRAY_JOB_ID", "").strip() + "_" + os.environ.get("SLURM_ARRAY_TASK_ID", "").strip() process_id = os.getpid() models = { "net": DynamicIngredient("models.passt.model_ing", n_classes=20, s_patchout_t=40, s_patchout_f=4), "mel": DynamicIngredient("models.preprocess.model_ing", instance_cmd="AugmentMelSTFT", n_mels=128, sr=32000, win_length=800, hopsize=320, n_fft=1024, freqm=48, timem=192, htk=False, fmin=0.0, fmax=None, norm=1, fmin_aug_range=10, fmax_aug_range=2000) } wandb = dict(project="passt_openmic", log_model=True) basedataset = DynamicIngredient("openmic.dataset.dataset", wavmix=1) # set the default for the trainer trainer = dict(max_epochs=10, gpus=1, weights_summary='full', benchmark=True, num_sanity_val_steps=0, reload_dataloaders_every_epoch=True) lr = 0.00001 use_mixup = True mixup_alpha = 0.3 # register extra possible configs 
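With sr=32000 and hopsize=320, the mel frontend configured above produces roughly 100 frames per second, so a 10-second clip becomes a spectrogram about 1000 frames wide; this is where the `[batch, 1, 128, 998]` dummy inputs in `model_speed_test` come from. A sketch of the frame arithmetic, assuming a non-centered STFT framed by `win_length` (the helper name is hypothetical; the exact count depends on the padding used in `models/preprocess.py`):

```python
def mel_frames(n_samples, win_length=800, hopsize=320, center=False):
    """Approximate frame count for the AugmentMelSTFT settings above.
    center=True models torch.stft's default mode, which pads the signal
    so that frames are centered on sample positions k * hopsize."""
    if center:
        return 1 + n_samples // hopsize
    return 1 + (n_samples - win_length) // hopsize

# 10 s at 32 kHz: 998 frames without centering, 1001 with centering
```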
add_configs(ex) @ex.command def get_scheduler_lambda(warm_up_len=5, ramp_down_start=50, ramp_down_len=50, last_lr_value=0.01, schedule_mode="exp_lin"): if schedule_mode == "exp_lin": return exp_warmup_linear_down(warm_up_len, ramp_down_len, ramp_down_start, last_lr_value) if schedule_mode == "cos_cyc": return cosine_cycle(warm_up_len, ramp_down_start, last_lr_value) raise RuntimeError(f"schedule_mode={schedule_mode} Unknown for a lambda function.") @ex.command def get_lr_scheduler(optimizer, schedule_mode): if schedule_mode in {"exp_lin", "cos_cyc"}: return torch.optim.lr_scheduler.LambdaLR(optimizer, get_scheduler_lambda()) raise RuntimeError(f"schedule_mode={schedule_mode} Unknown.") @ex.command def get_optimizer(params, lr, adamw=True, weight_decay=0.0001): if adamw: print(f"\nUsing adamw weight_decay={weight_decay}!\n") return torch.optim.AdamW(params, lr=lr, weight_decay=weight_decay) return torch.optim.Adam(params, lr=lr) class M(Ba3lModule): def __init__(self, experiment): self.mel = None self.da_net = None super(M, self).__init__(experiment) self.use_mixup = self.config.use_mixup or False self.mixup_alpha = self.config.mixup_alpha desc, sum_params, sum_non_zero = count_non_zero_params(self.net) self.experiment.info["start_sum_params"] = sum_params self.experiment.info["start_sum_params_non_zero"] = sum_non_zero # in case we need embeddings for the DA self.net.return_embed = True self.dyn_norm = self.config.dyn_norm self.do_swa = False self.distributed_mode = self.config.trainer.num_nodes > 1 def forward(self, x): return self.net(x) def mel_forward(self, x): old_shape = x.size() x = x.reshape(-1, old_shape[2]) x = self.mel(x) x = x.reshape(old_shape[0], old_shape[1], x.shape[1], x.shape[2]) if self.dyn_norm: if not hasattr(self, "tr_m") or not hasattr(self, "tr_std"): tr_m, tr_std = get_dynamic_norm(self) self.register_buffer('tr_m', tr_m) self.register_buffer('tr_std', tr_std) x = (x - self.tr_m) / self.tr_std return x def training_step(self, batch,
batch_idx): # REQUIRED x, f, y = batch if self.mel: x = self.mel_forward(x) y_mask = y[:, 20:] y = y[:, :20] > 0.5 y = y.float() orig_x = x batch_size = len(y) rn_indices, lam = None, None if self.use_mixup: rn_indices, lam = my_mixup(batch_size, self.mixup_alpha) lam = lam.to(x.device) x = x * lam.reshape(batch_size, 1, 1, 1) + x[rn_indices] * (1. - lam.reshape(batch_size, 1, 1, 1)) y_hat, embed = self.forward(x) if self.use_mixup: y_mix = y * lam.reshape(batch_size, 1) + y[rn_indices] * (1. - lam.reshape(batch_size, 1)) samples_loss = F.binary_cross_entropy_with_logits( y_hat, y_mix, reduction="none") y_mix_mask = ((y_mask > 0.5) | (y_mask[rn_indices] > 0.5)).float() samples_loss = y_mask.float() * samples_loss loss = samples_loss.mean() samples_loss = samples_loss.detach() else: samples_loss = F.binary_cross_entropy_with_logits(y_hat, y, reduction="none") samples_loss = y_mask.float() * samples_loss loss = samples_loss.mean() samples_loss = samples_loss.detach() results = {"loss": loss, } return results def training_epoch_end(self, outputs): avg_loss = torch.stack([x['loss'] for x in outputs]).mean() logs = {'train.loss': avg_loss, 'step': self.current_epoch} self.log_dict(logs, sync_dist=True) def predict(self, batch, batch_idx: int, dataloader_idx: int = None): x, f, y = batch if self.mel: x = self.mel_forward(x) y_hat, _ = self.forward(x) return f, y_hat def validation_step(self, batch, batch_idx): x, f, y = batch if self.mel: x = self.mel_forward(x) y_mask = y[:, 20:] y = y[:, :20] > 0.5 y = y.float() results = {} model_name = [("", self.net)] if self.do_swa: model_name = model_name + [("swa_", self.net_swa)] for net_name, net in model_name: y_hat, _ = net(x) samples_loss = F.binary_cross_entropy_with_logits(y_hat, y) samples_loss = y_mask.float() * samples_loss loss = samples_loss.mean() out = torch.sigmoid(y_hat.detach()) # self.log("validation.loss", loss, prog_bar=True, on_epoch=True, on_step=False) results = {**results, net_name + "val_loss": loss, 
net_name + "out": out, net_name + "target": y.detach(), net_name + "mask": y_mask.detach()} results = {k: v.cpu() for k, v in results.items()} return results def validation_epoch_end(self, outputs): model_name = [("", self.net)] if self.do_swa: model_name = model_name + [("swa_", self.net_swa)] for net_name, net in model_name: avg_loss = torch.stack([x[net_name + 'val_loss'] for x in outputs]).mean() out = torch.cat([x[net_name + 'out'] for x in outputs], dim=0) target = torch.cat([x[net_name + 'target'] for x in outputs], dim=0) mask = torch.cat([x[net_name + 'mask'] for x in outputs], dim=0) try: y_true = target.float().numpy() y_pred = out.float().numpy() y_mask = mask.float().numpy() average_precision = np.array([metrics.average_precision_score( y_true[:, i], y_pred[:, i], sample_weight=y_mask[:, i]) for i in range(y_true.shape[1])]) except ValueError: average_precision = np.array([np.nan] * y_true.shape[1]) #torch.save(average_precision, f"ap_openmic_perclass_{average_precision.mean()}.pt") try: roc = np.array([metrics.roc_auc_score( y_true[:, i], y_pred[:, i], sample_weight=y_mask[:, i]) for i in range(y_true.shape[1])]) except ValueError: roc = np.array([np.nan] * y_true.shape[1]) logs = {net_name + 'val.loss': torch.as_tensor(avg_loss).cuda(), net_name + 'ap': torch.as_tensor(average_precision.mean()).cuda(), net_name + 'roc': torch.as_tensor(roc.mean()).cuda(), 'step': torch.as_tensor(self.current_epoch).cuda()} self.log_dict(logs, sync_dist=True) if self.distributed_mode: allout = self.all_gather(out) alltarget = self.all_gather(target) all_mask = self.all_gather(mask) y_true = alltarget.reshape(-1, alltarget.shape[-1]).cpu().float().numpy() y_pred = allout.reshape(-1, alltarget.shape[-1]).cpu().float().numpy() y_mask = all_mask.reshape(-1, alltarget.shape[-1]).cpu().float().numpy() average_precision = np.array([metrics.average_precision_score( y_true[:, i], y_pred[:, i], sample_weight=y_mask[:, i]) for i in range(y_true.shape[1])]) if 
self.trainer.is_global_zero: logs = {net_name + "allap": torch.as_tensor(average_precision.mean()).cuda(), 'step': torch.as_tensor(self.current_epoch).cuda()} self.log_dict(logs, sync_dist=False) else: self.log_dict({net_name + "allap": logs[net_name + 'ap'], 'step': logs['step']}, sync_dist=True) def configure_optimizers(self): # REQUIRED # can return multiple optimizers and learning_rate schedulers # (LBFGS it is automatically supported, no need for closure function) optimizer = get_optimizer(self.parameters()) # torch.optim.Adam(self.parameters(), lr=self.config.lr) return { 'optimizer': optimizer, 'lr_scheduler': get_lr_scheduler(optimizer) } def configure_callbacks(self): return get_extra_checkpoint_callback() + get_extra_swa_callback() @ex.command def get_dynamic_norm(model, dyn_norm=False): if not dyn_norm: return None, None raise RuntimeError('no dynamic norm supported yet.') @ex.command def get_extra_checkpoint_callback(save_last_n=None): if save_last_n is None: return [] return [ModelCheckpoint(monitor="step", verbose=True, save_top_k=save_last_n, mode='max')] @ex.command def get_extra_swa_callback(swa=True, swa_epoch_start=2, swa_freq=1): if not swa: return [] print("\n Using swa!\n") from helpers.swa_callback import StochasticWeightAveraging return [StochasticWeightAveraging(swa_epoch_start=swa_epoch_start, swa_freq=swa_freq)] @ex.command def main(_run, _config, _log, _rnd, _seed): trainer = get_trainer(logger=get_logger()) train_loader = get_train_loader() val_loader = get_validate_loader() modul = M(ex) trainer.fit( modul, train_dataloaders=train_loader, val_dataloaders=val_loader, ) return {"done": True} @ex.command def model_speed_test(_run, _config, _log, _rnd, _seed, speed_test_batch_size=100): ''' Test training speed of a model @param _run: @param _config: @param _log: @param _rnd: @param _seed: @param speed_test_batch_size: the batch size during the test @return: ''' modul = M(ex) modul = modul.cuda() batch_size = speed_test_batch_size 
print(f"\nBATCH SIZE : {batch_size}\n") test_length = 100 print(f"\ntest_length : {test_length}\n") x = torch.ones([batch_size, 1, 128, 998]).cuda() target = torch.ones([batch_size, 527]).cuda() # one pass net = modul.net # net(x) scaler = torch.cuda.amp.GradScaler() torch.backends.cudnn.benchmark = True # net = torch.jit.trace(net,(x,)) optimizer = torch.optim.SGD(net.parameters(), lr=0.001) print("warmup") import time torch.cuda.synchronize() t1 = time.time() for i in range(10): with torch.cuda.amp.autocast(): y_hat, embed = net(x) loss = F.binary_cross_entropy_with_logits(y_hat, target, reduction="none").mean() scaler.scale(loss).backward() scaler.step(optimizer) scaler.update() torch.cuda.synchronize() t2 = time.time() print('warmup done:', (t2 - t1)) torch.cuda.synchronize() t1 = time.time() print("testing speed") for i in range(test_length): with torch.cuda.amp.autocast(): y_hat, embed = net(x) loss = F.binary_cross_entropy_with_logits(y_hat, target, reduction="none").mean() scaler.scale(loss).backward() scaler.step(optimizer) scaler.update() torch.cuda.synchronize() t2 = time.time() print('test done:', (t2 - t1)) print("average speed: ", (test_length * batch_size) / (t2 - t1), " specs/second") @ex.command def evaluate_only(_run, _config, _log, _rnd, _seed): # force overriding the config, not logged = not recommended trainer = get_trainer() train_loader = get_train_loader() val_loader = get_validate_loader() modul = M(ex) modul.val_dataloader = None #trainer.val_dataloaders = None print(f"\n\nValidation len={len(val_loader)}\n") res = trainer.validate(modul, dataloaders=val_loader) print("\n\n Validation:") print(res) @ex.command def test_loaders(): ''' get one sample from each loader for debugging @return: ''' for i, b in enumerate(ex.datasets.training.get_iter()): print(b) break for i, b in enumerate(ex.datasets.test.get_iter()): print(b) break def set_default_json_pickle(obj): if isinstance(obj, set): return list(obj) raise TypeError @ex.command def
preload_mp3(all_y=CMD("/basedataset.preload_mp3")): ''' read the dataset sequentially, useful if you have a network cache @param all_y: the dataset preload command @return: ''' print(all_y.shape) def multiprocessing_run(rank, word_size): print("rank ", rank, os.getpid()) print("word_size ", word_size) os.environ['NODE_RANK'] = str(rank) os.environ['CUDA_VISIBLE_DEVICES'] = os.environ['CUDA_VISIBLE_DEVICES'].split(",")[rank] argv = sys.argv if rank != 0: print(f"Unobserved {os.getpid()} with rank {rank}") argv = argv + ["-u"] # only rank 0 is observed if "with" not in argv: argv = argv + ["with"] argv = argv + [f"trainer.num_nodes={word_size}", f"trainer.accelerator=ddp"] print(argv) @ex.main def default_command(): return main() ex.run_commandline(argv) if __name__ == '__main__': # set DDP=2 forks two processes to run on two GPUs # the environment variable "DDP" define the number of processes to fork # With two 2x 2080ti you can train the full model to .47 in around 24 hours # you may need to set NCCL_P2P_DISABLE=1 word_size = os.environ.get("DDP", None) if word_size: import random word_size = int(word_size) print(f"\n\nDDP TRAINING WITH WORD_SIZE={word_size}\n\n") os.environ['MASTER_ADDR'] = '127.0.0.1' os.environ['MASTER_PORT'] = f"{9999 + random.randint(0, 9999)}" # plz no collisions os.environ['PL_IN_DDP_SUBPROCESS'] = '1' for rank in range(word_size): pid = os.fork() if pid == 0: print("Child Forked ") multiprocessing_run(rank, word_size) exit(0) pid, exit_code = os.wait() print(pid, exit_code) exit(0) print("__main__ is running pid", os.getpid(), "in module main: ", __name__) @ex.automain def default_command(): return main() ================================================ FILE: fsd50k/README.md ================================================ # Experiments on FSD50K The FSD50K dataset ([Zenodo](https://zenodo.org/record/4060432)) consists of 51K audio clips annotated with 200 sound event classes taken from the Audioset ontology. 
The dataset contains 100 hours of audio and is the second-largest publicly available general-purpose sound event recognition dataset after Audioset. Furthermore, the FSD50K evaluation set is of high quality, with each evaluation label being double-checked and assessed by two to five independent annotators.

# Setup

1. Download the dataset from [Zenodo](https://zenodo.org/record/4060432) and unzip it.
2. Convert the wav files to mp3s:

    ```shell
    cd fsd50k/prepare_scripts/
    python convert_to_mp3.py path/to/fsd50k
    ```

    This will create a folder with the mp3 files inside the FSD50K directory.
3. Pack the mp3s into HDF5 files:

    ```shell
    cd fsd50k/prepare_scripts/
    python create_h5pymp3_dataset.py path/to/fsd50k
    ```

    Now you should have three new files inside `../../audioset_hdf5s/mp3/`: `FSD50K.eval_mp3.hdf`, `FSD50K.val_mp3.hdf`, `FSD50K.train_mp3.hdf`.

# Running Experiments

Similar to the runs on Audioset, PaSST-S:

```shell
# Example call with all the default config:
python ex_fsd50k.py with trainer.precision=16 -p
```

```shell
# Example call without overlap:
python ex_fsd50k.py with passt_s_swa_p16_s16_128_ap473 models.net.s_patchout_t=10 models.net.s_patchout_f=1 trainer.precision=16 -p
```

# Pre-trained models

Pre-trained models on FSD50K can be found [here](https://github.com/kkoutini/PaSST/releases/tag/v0.0.5).
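A note on the layout produced in step 3 of the setup: the packing script bit-packs each clip's 200-dimensional multi-hot label vector with `np.packbits` before storing it in the `target` dataset, and the data loader restores it at load time with `np.unpackbits(..., count=classes_num)`. A minimal numpy-only sketch of that round trip (the toy class indices here are purely illustrative):

```python
import numpy as np

n_classes = 200  # FSD50K vocabulary size

# multi-hot label vector for one clip; say classes 3 and 17 are present
labels = np.zeros((1, n_classes), dtype=np.int32)
labels[0, [3, 17]] = 1

# packing: 200 bits fit into 25 uint8 values per clip
# (as done in create_h5pymp3_dataset.py before writing 'target')
packed = np.packbits(labels, axis=-1)
print(packed.shape, packed.dtype)  # (1, 25) uint8

# unpacking at load time (as done in AudioSetDataset.__getitem__);
# `count` trims the zero padding back to exactly 200 entries
target = np.unpackbits(packed[0], axis=-1, count=n_classes).astype(np.float32)
print(target.shape, target.sum())  # (200,) 2.0
```

Packing cuts the stored label size by a factor of eight, which matters when every sample row in the HDF5 file carries its own target vector.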
To use the pre-trained models for fine-tuning or inference with minimal dependencies, refer to [PaSST-HEAR](https://github.com/kkoutini/passt_hear21). As an example, after installing passt_hear21:

```python
from hear21passt.base import get_basic_model, get_model_passt
import torch

# model wrapper, includes Melspectrogram and the default pre-trained transformer
model = get_basic_model(mode="logits")
# replace the transformer with one that outputs 200 classes
model.net = get_model_passt(arch="passt_s_swa_p16_128_ap476", n_classes=200)
# load the pre-trained model state dict with an mAP of .655 on FSD50K
state_dict = torch.load('/home/khaled/fsd50k-passt-s-f128-p16-s10-ap.655.pt')
# load the weights into the transformer
model.net.load_state_dict(state_dict)

# example inference
model.eval()
model = model.cuda()
with torch.no_grad():
    # audio_wave has the shape [batch, seconds*32000]; the sampling rate is 32k
    logits = model(audio_wave)
```

Using the model with no patch overlap, PaSST-S-N (`fsd50k-passt-s-n-f128-p16-s16-ap.642.pt`):

```python
# replace the transformer with one that outputs 200 classes
model.net = get_model_passt(arch="passt_s_p16_s16_128_ap468", fstride=16, tstride=16, n_classes=200)
# load the pre-trained model state dict with an mAP of .642 on FSD50K with no patch overlap
state_dict = torch.load('/home/khaled/fsd50k-passt-s-n-f128-p16-s16-ap.642.pt')
# load the weights into the transformer
model.net.load_state_dict(state_dict)
```

================================================ FILE: fsd50k/dataset.py ================================================ import io import os import pathlib import random import av import librosa import torchaudio from torch.utils.data import Dataset as TorchDataset, ConcatDataset, DistributedSampler, WeightedRandomSampler import torch from ba3l.ingredients.datasets import Dataset from sacred.config import DynamicIngredient, CMD from scipy.signal import convolve import numpy as np from helpers.audiodatasets import
PreprocessDataset import h5py LMODE = os.environ.get("LMODE", False) # $TMPDIR dataset = Dataset('audiodataset') @dataset.config def default_config(): name = 'audioset' # dataset name normalize = False # normalize dataset subsample = False # subsample squares from the dataset roll = True # apply roll augmentation fold = 1 base_dir = "audioset_hdf5s/" # base directory of the dataset, change it or make a link if LMODE: base_dir = "/system/user/publicdata/CP/audioset/audioset_hdf5s/" balanced_train_hdf5 = base_dir + "mp3/FSD50K.train_mp3.hdf" valid_hdf5 = base_dir + "mp3/FSD50K.val_mp3.hdf" eval_hdf5 = base_dir + "mp3/FSD50K.eval_mp3.hdf" if LMODE: balanced_train_hdf5 = balanced_train_hdf5.replace(base_dir, os.environ.get("TMPDIR", base_dir) + "/") eval_hdf5 = eval_hdf5.replace(base_dir, os.environ.get("TMPDIR", base_dir) + "/") valid_hdf5 = valid_hdf5.replace(base_dir, os.environ.get("TMPDIR", base_dir) + "/") ir_path = base_dir + "irs/" num_of_classes = 200 if LMODE: @dataset.config def LMODE_default_config(): cache_root_path = "/system/user/publicdata/CP/DCASE/cached_datasets/" def decode_mp3(mp3_arr): """ decodes an array if uint8 representing an mp3 file :rtype: np.array """ container = av.open(io.BytesIO(mp3_arr.tobytes())) stream = next(s for s in container.streams if s.type == 'audio') # print(stream) a = [] for i, packet in enumerate(container.demux(stream)): for frame in packet.decode(): a.append(frame.to_ndarray().reshape(-1)) waveform = np.concatenate(a) if waveform.dtype != 'float32': raise RuntimeError("Unexpected wave type") return waveform def pad_or_truncate(x, audio_length): """Pad all audio to specific length.""" if audio_length is None: # audio_length not specified don't do anything. 
return x if len(x) <= audio_length: return np.concatenate((x, np.zeros(audio_length - len(x), dtype=np.float32)), axis=0) else: offset = torch.randint(0, len(x) - audio_length + 1, (1,)).item() return x[offset:offset + audio_length] irs_arr = None @dataset.command def get_ir_sample(ir_path, _run, ir_augment, cut_irs_offset=None): if not ir_augment: return global irs_arr if irs_arr is None: all_paths = [path for path in pathlib.Path(os.path.expanduser(ir_path)).rglob('*.wav')] all_paths = sorted(all_paths) if cut_irs_offset is not None: all_paths = all_paths[cut_irs_offset:cut_irs_offset + 10] all_paths_name = [str(p).rsplit("/", 1)[-1] for p in all_paths] print("will use these IRs:") for i in range(len(all_paths_name)): print(i, ": ", all_paths_name[i]) _run.info["ir_devices"] = all_paths_name irs_arr = [librosa.load(p, sr=32000)[0] for p in all_paths] return irs_arr[int(np.random.randint(0, len(irs_arr)))] @dataset.command def pydub_augment(waveform, gain_augment=7, ir_augment=0): if ir_augment and torch.rand(1) < ir_augment: ir = get_ir_sample() waveform = convolve(waveform, ir, 'full') if gain_augment: gain = torch.randint(gain_augment * 2, (1,)).item() - gain_augment amp = 10 ** (gain / 20) waveform = waveform * amp return waveform class MixupDataset(TorchDataset): """ Mixing Up wave forms """ def __init__(self, dataset, beta=2, rate=0.5): self.beta = beta self.rate = rate self.dataset = dataset print(f"Mixing up waveforms from dataset of len {len(dataset)}") def __getitem__(self, index): if torch.rand(1) < self.rate: x1, f1, y1 = self.dataset[index] idx2 = torch.randint(len(self.dataset), (1,)).item() x2, f2, y2 = self.dataset[idx2] l = np.random.beta(self.beta, self.beta) l = max(l, 1. - l) x1 = x1 - x1.mean() x2 = x2 - x2.mean() x = (x1 * l + x2 * (1. - l)) x = x - x.mean() return x, f1, (y1 * l + y2 * (1. 
- l)) return self.dataset[index] def __len__(self): return len(self.dataset) class AudioSetDataset(TorchDataset): def __init__(self, hdf5_file, sample_rate=32000, classes_num=200, clip_length=10, augment=False, in_mem=False): """ Reads the mp3 bytes from HDF file decodes using av and returns a fixed length audio wav """ self.sample_rate = sample_rate self.hdf5_file = hdf5_file if in_mem: print("\nPreloading in memory\n") with open(hdf5_file, 'rb') as f: self.hdf5_file = io.BytesIO(f.read()) with h5py.File(hdf5_file, 'r') as f: self.length = len(f['audio_name']) print(f"Dataset from {hdf5_file} with length {self.length}.") self.dataset_file = None # lazy init self.clip_length = clip_length if clip_length is not None: self.clip_length = clip_length * sample_rate self.classes_num = classes_num self.augment = augment if augment: print(f"Will agument data from {hdf5_file}") def open_hdf5(self): self.dataset_file = h5py.File(self.hdf5_file, 'r') def __len__(self): return self.length def __del__(self): if self.dataset_file is not None: self.dataset_file.close() self.dataset_file = None def __getitem__(self, index): """Load waveform and target of an audio clip. Args: meta: { 'hdf5_path': str, 'index_in_hdf5': int} Returns: data_dict: { 'audio_name': str, 'waveform': (clip_samples,), 'target': (classes_num,)} """ if self.dataset_file is None: self.open_hdf5() audio_name = self.dataset_file['audio_name'][index].decode() waveform = decode_mp3(self.dataset_file['mp3'][index]) if self.augment: waveform = pydub_augment(waveform) waveform = pad_or_truncate(waveform, self.clip_length) waveform = self.resample(waveform) target = self.dataset_file['target'][index] target = np.unpackbits(target, axis=-1, count=self.classes_num).astype(np.float32) return waveform.reshape(1, -1), audio_name, target def resample(self, waveform): """Resample. 
Args: waveform: (clip_samples,) Returns: (resampled_clip_samples,) """ if self.sample_rate == 32000: return waveform elif self.sample_rate == 16000: return waveform[0:: 2] elif self.sample_rate == 8000: return waveform[0:: 4] else: raise Exception('Incorrect sample rate!') @dataset.command def get_base_training_set(balanced_train_hdf5, clip_length=10): ds = AudioSetDataset(balanced_train_hdf5, augment=True, clip_length=clip_length) return ds @dataset.command def preload_mp3(balanced_train_hdf5, unbalanced_train_hdf5, num_of_classes): # Preload mp3 sequential from disk, OS will cache the chunks in memory. # Useful if the hdf file is on a NFS mount, saving the random access. for hdf5_file in [balanced_train_hdf5, unbalanced_train_hdf5]: print(f"\n \n will now preload {hdf5_file} \n\n ") with h5py.File(hdf5_file, 'r') as dataset_file: target = dataset_file['mp3'][:] print(len(target)) print(f"\n \n done with {hdf5_file} \n\n ") return target[1000] @dataset.command def get_ft_weighted_sampler(samples_weights=CMD(".get_ft_cls_balanced_sample_weights"), epoch_len=100000, sampler_replace=False): num_nodes = int(os.environ.get('num_nodes', 1)) ddp = int(os.environ.get('DDP', 1)) num_nodes = max(ddp, num_nodes) print("num_nodes= ", num_nodes) rank = int(os.environ.get('NODE_RANK', 0)) return DistributedSamplerWrapper(sampler=WeightedRandomSampler(samples_weights, num_samples=epoch_len, replacement=sampler_replace), dataset=range(epoch_len), num_replicas=num_nodes, rank=rank, ) @dataset.command def get_base_eval_set(eval_hdf5, variable_eval=None): if variable_eval: print("Variable length eval!!") ds = AudioSetDataset(eval_hdf5, clip_length=None) else: ds = AudioSetDataset(eval_hdf5) return ds @dataset.command def get_base_valid_set(valid_hdf5, variable_eval=None): if variable_eval: print("Variable length valid_set !!") ds = AudioSetDataset(valid_hdf5, clip_length=None) else: ds = AudioSetDataset(valid_hdf5) return ds @dataset.command(prefix='roll_conf') def 
get_roll_func(axis=1, shift=None, shift_range=50): print("rolling...") def roll_func(b): x, i, y = b x = torch.as_tensor(x) sf = shift if shift is None: sf = int(np.random.random_integers(-shift_range, shift_range)) global FirstTime return x.roll(sf, axis), i, y return roll_func @dataset.command def get_training_set(normalize, roll, wavmix=False): ds = get_base_training_set() get_ir_sample() if normalize: print("normalized train!") fill_norms() ds = PreprocessDataset(ds, norm_func) if roll: ds = PreprocessDataset(ds, get_roll_func()) if wavmix: ds = MixupDataset(ds) return ds @dataset.command def get_valid_set(normalize): ds = get_base_valid_set() if normalize: print("normalized test!") fill_norms() ds = PreprocessDataset(ds, norm_func) return ds @dataset.command def get_eval_set(normalize): ds = get_base_eval_set() if normalize: print("normalized test!") fill_norms() ds = PreprocessDataset(ds, norm_func) return ds @dataset.command def print_conf(_config): print("Config of ", dataset.path, id(dataset)) print(_config) print() class DistributedSamplerWrapper(DistributedSampler): def __init__( self, sampler, dataset, num_replicas=None, rank=None, shuffle: bool = True): super(DistributedSamplerWrapper, self).__init__( dataset, num_replicas, rank, shuffle) # source: @awaelchli https://github.com/PyTorchLightning/pytorch-lightning/issues/3238 self.sampler = sampler def __iter__(self): if self.sampler.generator is None: self.sampler.generator = torch.Generator() self.sampler.generator.manual_seed(self.seed + self.epoch) indices = list(self.sampler) if self.epoch == 0: print(f"\n DistributedSamplerWrapper : {indices[:10]} \n\n") indices = indices[self.rank:self.total_size:self.num_replicas] return iter(indices) if __name__ == "__main__": from sacred import Experiment ex = Experiment("test_dataset", ingredients=[dataset]) @ex.automain def default_command(): ex.current_run.get_command_function("print_config")() get_base_training_set() ds = get_test_set() print(ds[0]) ds = 
get_training_set() print(ds[0]) print("get_base_training_set", len(get_base_training_set())) print("get_base_test_set", len(get_base_test_set())) print("get_training_set", len(get_training_set())) print("get_test_set", len(get_test_set())) ================================================ FILE: fsd50k/prepare_scripts/convert_to_mp3.py ================================================ import multiprocessing import glob import os import sys if len(sys.argv) > 1: FSD50K_base = sys.argv[1] # the path of the FSD50K base directory as downloaded from Zenodo. else: FSD50K_base = "/home/khaled/shared/FSD50K/" # the path of the FSD50K base directory as downloaded from Zenodo. print("Pass the path to FSD50K: python convert_to_mp3.py path/to/fsd50k") outputp = FSD50K_base + "/mp3/" # the path to the output mp3s. all_num = 0 def process_folder(fol="balanced_train_segments"): print("now working on ", fol) os.makedirs(outputp + fol, exist_ok=True) all_files = list(glob.glob(FSD50K_base + fol + "/*.wav")) print(f"it has {len(all_files)}") global all_num all_num = len(all_files) cmds = [(i, file, outputp + fol + "/" + os.path.basename(file)[:-3]) for i, file in enumerate(all_files)] print(cmds[0]) with multiprocessing.Pool(processes=20) as pool: pool.starmap(process_one, cmds) def process_one(i, f1, f2): if i % 100 == 0: print(f"{i}/{all_num} \t", f1) os.system(f"ffmpeg -hide_banner -nostats -loglevel error -n -i {f1} -codec:a mp3 -ar 32000 {f2}mp3") print("We will convert the following folders to mp3: ") folders = ['FSD50K.eval_audio', 'FSD50K.dev_audio'] print(folders) for fol in folders: process_folder(fol) ================================================ FILE: fsd50k/prepare_scripts/create_h5pymp3_dataset.py ================================================ # %% import sys import h5py import pandas as pd import numpy as np import csv import os # %% from numpy import dtype if len(sys.argv) > 1: FSD50K_base = sys.argv[1] # the path of the FSD50K base directory as downloaded from Zenodo.
else: FSD50K_base = "/home/khaled/shared/FSD50K/" # the path of the FSD50K base directory as downloaded from Zenodo. print("Pass the path to FSD50K: python convert_to_mp3.py path/to/fsd50k") base_dir = "../../audioset_hdf5s/" # the path to store the hdf files. #### balanced_csv = FSD50K_base + "FSD50K.ground_truth/dev.csv" eval_csv = FSD50K_base + "FSD50K.ground_truth/eval.csv" class_idx_csv = FSD50K_base + "FSD50K.ground_truth/vocabulary.csv" mp3_path = "/home/khaled/shared/FSD50K/mp3/" # %% df = pd.read_csv(class_idx_csv, header=None, index_col=0) classes_list = list(df[1].values) assert sorted(classes_list) == classes_list id_to_ix = {id: i for i, id in enumerate(classes_list)} ix_to_id = {i: id for i, id in enumerate(classes_list)} # %% # Load labels df = pd.read_csv(balanced_csv) train = df[df.split == "train"] val = df[df.split == "val"] eval = pd.read_csv(eval_csv) # %% def get_labels(df): y = np.zeros((len(df), 200), dtype=np.int32) for i, target in enumerate(df.labels.values): for t in target.split(","): y[i, id_to_ix[t]] = 1 return df.fname.values, y # %% for set_name, df, prefix in [("train", train, "FSD50K.dev_audio/"), ("val", val, "FSD50K.dev_audio/"), ("eval", eval, "FSD50K.eval_audio/")]: print("now working on ", set_name, prefix, "len=", len(df)) # files, y = torch.load(read_file+".pth") files, y = get_labels(df) y = np.packbits(y, axis=-1) packed_len = y.shape[1] print(files[0], "classes: ", packed_len, y.dtype) available_size = len(files) dt = h5py.vlen_dtype(np.dtype('uint8')) save_file = "FSD50K."
+ set_name if os.path.isfile(base_dir + "mp3/" + save_file + "_mp3.hdf"): print(base_dir + "mp3/" + save_file + "_mp3.hdf", "exists!\n\n\n continue") continue with h5py.File(base_dir + "mp3/" + save_file + "_mp3.hdf", 'w') as hf: audio_name = hf.create_dataset('audio_name', shape=((available_size,)), dtype='S20') waveform = hf.create_dataset('mp3', shape=((available_size,)), dtype=dt) target = hf.create_dataset('target', shape=((available_size, packed_len)), dtype=y.dtype) for i, file in enumerate(files): if i % 1000 == 0: print(f"{i}/{available_size}") f = f"{file}.mp3" a = np.fromfile(mp3_path + prefix + f, dtype='uint8') audio_name[i] = f waveform[i] = a target[i] = y[i] print(a.shape) print("Done!", prefix) ================================================ FILE: helpers/audiodatasets.py ================================================ import hashlib import os import time import torch from torch.utils.data import Dataset from os.path import expanduser import logging def h6(w): return hashlib.md5(w.encode('utf-8')).hexdigest()[:6] class AudioPreprocessDataset(Dataset): """A base preprocessing dataset representing a Dataset of files that are loaded and preprocessed on the fly. Access elements via __getitem__ to return: preprocessor(x), sample_id, label supporting integer indexing in range from 0 to len(self) exclusive.
""" def __init__(self, files, labels, label_encoder, base_dir, preprocessor, return_tensor=True, ordered_ids=None): self.files = files if ordered_ids is None: ordered_ids = files else: print("AudioPreprocessDataset: ordered_ids is not None using it instead of files !!!") self.ordered_ids = ordered_ids self.labels = labels self.label_encoder = label_encoder self.base_dir = base_dir self.preprocessor = preprocessor self.return_tensor = return_tensor def __getitem__(self, index): x = self.preprocessor(self.base_dir + self.files[index]) if self.return_tensor and not isinstance(x, torch.Tensor): x = torch.from_numpy(x) return x, self.ordered_ids[index], self.labels[index] def get_ordered_ids(self): return self.ordered_ids def get_ordered_labels(self): return self.labels def __len__(self): return len(self.ordered_ids) class ObjectCacher: def __init__(self, get_obj_func, dataset_name, obj_name="", cache_path="~/shared/kofta_cached_datasets/", verbose=True): self.dataset_name = dataset_name self.obj_name = obj_name cache_path = expanduser(cache_path) self.cache_path = os.path.join(cache_path, dataset_name) try: startTime = time.time() xpath = self.get_obj_cache_path() if verbose: logging.info( "attempting to load x from cache at " + xpath + "...") self.obj = torch.load(xpath) if verbose: endTime = time.time() logging.info( "loaded " + xpath + " from cache in %s " % (endTime - startTime)) except IOError: if verbose: logging.info( "Invalid cache " + xpath + " , recomputing") self.obj = get_obj_func() saveStartTime = time.time() dirpath=os.path.dirname(xpath) try: original_umask = os.umask(0) os.makedirs(dirpath, exist_ok=True) finally: os.umask(original_umask) torch.save(self.obj, xpath) if verbose: endTime = time.time() logging.info( "loaded " + obj_name + " in %s, and cached in %s, total %s seconds " % ( (saveStartTime - startTime), (endTime - saveStartTime), (endTime - startTime))) def get_obj_cache_path(self): return os.path.join(self.cache_path, self.obj_name + 
"_obj.pt") def get(self): return self.obj class PreprocessDataset(Dataset): """A bases preprocessing dataset representing a preprocessing step of a Dataset preprossessed on the fly. supporting integer indexing in range from 0 to len(self) exclusive. """ def __init__(self, dataset, preprocessor): self.dataset = dataset if not callable(preprocessor): print("preprocessor: ", preprocessor) raise ValueError('preprocessor should be callable') self.preprocessor = preprocessor def __getitem__(self, index): return self.preprocessor(self.dataset[index]) def __len__(self): return len(self.dataset) class FilesCachedDataset(Dataset): def __init__(self, get_dataset_func, dataset_name, x_name="", cache_path="~/shared/kofta_cached_datasets/", ): """ Cached the dataset in small torch.save files (1 file per sample). The dataset is suitable for SSDs being used bcache from a slow harddrive with a small @param get_dataset_func: fuction gets called if the file cache is invalid @param dataset_name: the folder containing the dataset @param x_name: tag for the version @param cache_path: cache_path """ self.dataset = None def getDataset(): if self.dataset == None: self.dataset = get_dataset_func() return self.dataset self.get_dataset_func = getDataset self.x_name = x_name cache_path = expanduser(cache_path) self.cache_path = os.path.join(cache_path, dataset_name, "files_cache", self.x_name) try: original_umask = os.umask(0) os.makedirs(self.cache_path, exist_ok=True) finally: os.umask(original_umask) def __getitem__(self, index): cpath = os.path.join(self.cache_path, str(index) + ".pt") try: return torch.load(cpath) except FileNotFoundError: tup = self.get_dataset_func()[index] torch.save(tup, cpath) return tup def get_ordered_labels(self): return self.get_dataset_func().get_ordered_labels() def get_ordered_ids(self): return self.get_dataset_func().get_ordered_ids() def get_xcache_path(self): return os.path.join(self.cache_path, self.x_name + "_x.pt") def get_ycache_path(self): return 
os.path.join(self.cache_path, self.y_name + "_y.pt") def get_sidcache_path(self): return os.path.join(self.cache_path, self.y_name + "_sid.pt") def __len__(self): return len(self.get_dataset_func()) class SelectionDataset(Dataset): """A dataset that selects a subsample from a dataset based on a set of sample ids. supporting integer indexing in range from 0 to len(self) exclusive. """ def __init__(self, dataset, sample_ids): self.available_indexes = [] self.dataset = dataset self.reselect(sample_ids) self.sample_ids = sample_ids def reselect(self, sample_ids): reverse_dict = dict([(sid, i) for i, sid in enumerate(self.dataset.get_ordered_ids())]) self.available_indexes = [reverse_dict[sid] for sid in sample_ids] def get_ordered_ids(self): return self.sample_ids def get_ordered_labels(self): labels=self.dataset.get_ordered_labels() return [labels[i] for i in self.available_indexes] #raise NotImplementedError("Maybe reconsider caching only a selection Dataset. why not select after cache?") def __getitem__(self, index): return self.dataset[self.available_indexes[index]] def __len__(self): return len(self.available_indexes) class SimpleSelectionDataset(Dataset): """A dataset that selects a subsample from a dataset based on a set of sample ids. supporting integer indexing in range from 0 to len(self) exclusive. 
""" def __init__(self, dataset, available_indexes ): self.available_indexes = available_indexes self.dataset = dataset def __getitem__(self, index): return self.dataset[self.available_indexes[index]] def __len__(self): return len(self.available_indexes) ================================================ FILE: helpers/mixup.py ================================================ import numpy as np import torch def my_mixup(size, alpha): rn_indices = torch.randperm(size) lambd = np.random.beta(alpha, alpha, size).astype(np.float32) lambd = np.concatenate([lambd[:, None], 1 - lambd[:, None]], 1).max(1) lam = torch.FloatTensor(lambd) # data = data * lam + data2 * (1 - lam) # targets = targets * lam + targets2 * (1 - lam) return rn_indices, lam ================================================ FILE: helpers/models_size.py ================================================ def count_non_zero_params(model): sum_params = 0 sum_non_zero = 0 desc = "" def calc_params(model): nonlocal desc, sum_params, sum_non_zero skip = "" if "batchnorm" in type(model).__name__.lower(): for k,p in [("running_mean", model.running_mean), ("running_var", model.running_var)]: nonzero = p[p != 0].numel() total = p.numel() desc += f"type {type(model).__name__}, {k}, {total}, {nonzero}, {p.dtype}, {skip} " + "\n" if skip != "skip": sum_params += total sum_non_zero += nonzero for k, p in model.named_parameters(recurse=False): nonzero = p[p != 0].numel() total = p.numel() desc += f"type {type(model).__name__}, {k}, {total}, {nonzero}, {p.dtype}, {skip} " + "\n" if skip != "skip": sum_params += total sum_non_zero += nonzero model.apply(calc_params) return desc, sum_params, sum_non_zero ================================================ FILE: helpers/ramp.py ================================================ import numpy as np from ba3l.ingredients.ingredient import Ingredient # credit: https://github.com/iBelieveCJM/Tricks-of-Semi-supervisedDeepLeanring-Pytorch/blob/master/utils/ramps.py def pseudo_rampup(T1, 
T2): def warpper(epoch): if epoch > T1: alpha = (epoch - T1) / (T2 - T1) if epoch > T2: alpha = 1.0 else: alpha = 0.0 return alpha return warpper def exp_rampup(rampup_length): """Exponential rampup from https://arxiv.org/abs/1610.02242""" def warpper(epoch): if epoch < rampup_length: epoch = np.clip(epoch, 0.5, rampup_length) phase = 1.0 - epoch / rampup_length return float(np.exp(-5.0 * phase * phase)) else: return 1.0 return warpper def linear_rampup(rampup_length): """Linear rampup""" def warpper(epoch): if epoch < rampup_length: return epoch / rampup_length else: return 1.0 return warpper def linear_rampdown(rampdown_length, start=0, last_value=0): """Linear rampup -(start)- (rampdown_length) \ _(for the rest) """ def warpper(epoch): if epoch <= start: return 1. elif epoch - start < rampdown_length: return last_value + (1. - last_value) * (rampdown_length - epoch + start) / rampdown_length else: return last_value return warpper def exp_rampdown(rampdown_length, num_epochs): """Exponential rampdown from https://arxiv.org/abs/1610.02242""" def warpper(epoch): if epoch >= (num_epochs - rampdown_length): ep = .5 * (epoch - (num_epochs - rampdown_length)) return float(np.exp(-(ep * ep) / rampdown_length)) else: return 1.0 return warpper def cosine_rampdown(rampdown_length, num_epochs): """Cosine rampdown from https://arxiv.org/abs/1608.03983""" def warpper(epoch): if epoch >= (num_epochs - rampdown_length): ep = .5 * (epoch - (num_epochs - rampdown_length)) return float(.5 * (np.cos(np.pi * ep / rampdown_length) + 1)) else: return 1.0 return warpper def exp_warmup(rampup_length, rampdown_length, num_epochs): rampup = exp_rampup(rampup_length) rampdown = exp_rampdown(rampdown_length, num_epochs) def warpper(epoch): return rampup(epoch) * rampdown(epoch) return warpper def exp_warmup_linear_down(warmup, rampdown_length, start_rampdown, last_value): rampup = exp_rampup(warmup) rampdown = linear_rampdown(rampdown_length, start_rampdown, last_value) def warpper(epoch): 
return rampup(epoch) * rampdown(epoch) return warpper def test_warmup(): warmup = exp_warmup(20, 100, 150) for ep in range(500): print(warmup(ep)) def test_warmupl(): warmup = exp_warmup_linear_down(20, 100, 50, 0.001) for ep in range(500): print(warmup(ep)) def cosine_cycle(cycle_len=20,ramp_down_start=100,last_lr_value=0.01): """Cosine rampdown from https://arxiv.org/abs/1608.03983""" ramp_down_start = cycle_len+ (ramp_down_start-1)//cycle_len*(cycle_len) print("adjusted ramp_down_start:",ramp_down_start) def warpper(epoch): ep = (epoch+cycle_len//2.)/(1.*cycle_len) if epoch>ramp_down_start: return last_lr_value return float(last_lr_value + (1.-last_lr_value)* .5 * (np.cos(2.*np.pi * ep) + 1)) return warpper if __name__ == '__main__': test= exp_warmup_linear_down(20, 100, 50, 150) for i in range(250): print(test(i)) ================================================ FILE: helpers/swa_callback.py ================================================ # Adapted from PyTorch Lightning so that it only does the averaging # Copyright The PyTorch Lightning team. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
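To illustrate how the factories in `helpers/ramp.py` compose, here is a self-contained sketch of `exp_warmup_linear_down` (schedule logic copied from the file above, with `math` in place of numpy; the concrete values 20/100/50/0.01 are example parameters, not the project's defaults):

```python
import math


def exp_rampup(rampup_length):
    # exponential warm-up from https://arxiv.org/abs/1610.02242
    def wrapper(epoch):
        if epoch < rampup_length:
            epoch = min(max(epoch, 0.5), rampup_length)  # np.clip equivalent
            phase = 1.0 - epoch / rampup_length
            return math.exp(-5.0 * phase * phase)
        return 1.0
    return wrapper


def linear_rampdown(rampdown_length, start=0, last_value=0):
    # 1.0 until `start`, linear decay to `last_value`, then held
    def wrapper(epoch):
        if epoch <= start:
            return 1.0
        if epoch - start < rampdown_length:
            return last_value + (1.0 - last_value) * (rampdown_length - epoch + start) / rampdown_length
        return last_value
    return wrapper


def exp_warmup_linear_down(warmup, rampdown_length, start_rampdown, last_value):
    rampup = exp_rampup(warmup)
    rampdown = linear_rampdown(rampdown_length, start_rampdown, last_value)
    return lambda epoch: rampup(epoch) * rampdown(epoch)


# warm up over 20 epochs, stay at 1.0, then from epoch 50 decay linearly
# over 100 epochs to a floor of 0.01
sched = exp_warmup_linear_down(20, 100, 50, 0.01)
print(sched(0))    # small warm-up value (well below 0.01)
print(sched(20))   # 1.0 -- warm-up finished, ramp-down not started
print(sched(100))  # 0.505 -- halfway down: 0.01 + 0.99 * 50/100
print(sched(200))  # 0.01 -- floor
```

The multiplier returned by `sched(epoch)` is typically applied to a base learning rate each epoch.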
r""" Stochastic Weight Averaging Callback ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ """ from copy import deepcopy from typing import Callable, Optional, Union import torch from torch import nn import pytorch_lightning as pl from pytorch_lightning.callbacks import Callback from pytorch_lightning.utilities.exceptions import MisconfigurationException from torch.optim.swa_utils import SWALR _AVG_FN = Callable[[torch.Tensor, torch.Tensor, torch.LongTensor], torch.FloatTensor] class StochasticWeightAveraging(Callback): def __init__( self, swa_epoch_start: Union[int, float] = 0.8, swa_freq: Union[int, float] = 3, swa_lrs: Optional[Union[float, list]] = None, annealing_epochs: int = 10, annealing_strategy: str = "cos", avg_fn: Optional[_AVG_FN] = None, device: Optional[Union[torch.device, str]] = None, ): r""" Implements the Stochastic Weight Averaging (SWA) Callback to average a model. Stochastic Weight Averaging was proposed in ``Averaging Weights Leads to Wider Optima and Better Generalization`` by Pavel Izmailov, Dmitrii Podoprikhin, Timur Garipov, Dmitry Vetrov and Andrew Gordon Wilson (UAI 2018). This documentation is highly inspired by PyTorch's work on SWA. The callback arguments follow the scheme defined in PyTorch's ``swa_utils`` package. For a SWA explanation, please take a look `here `_. .. warning:: ``StochasticWeightAveraging`` is in beta and subject to change. .. warning:: ``StochasticWeightAveraging`` is currently not supported for multiple optimizers/schedulers. SWA can easily be activated directly from the Trainer as follow: .. code-block:: python Trainer(stochastic_weight_avg=True) Arguments: swa_epoch_start: If provided as int, the procedure will start from the ``swa_epoch_start``-th epoch. If provided as float between 0 and 1, the procedure will start from ``int(swa_epoch_start * max_epochs)`` epoch swa_lrs: the learning rate value for all param groups together or separately for each group. 
annealing_epochs: number of epochs in the annealing phase (default: 10) annealing_strategy: Specifies the annealing strategy (default: "cos"): - ``"cos"``. For cosine annealing. - ``"linear"`` For linear annealing avg_fn: the averaging function used to update the parameters; the function must take in the current value of the :class:`AveragedModel` parameter, the current value of :attr:`model` parameter and the number of models already averaged; if None, equally weighted average is used (default: ``None``) device: if provided, the averaged model will be stored on the ``device``. When None is provided, it will infer the `device` from ``pl_module``. (default: ``"cpu"``) """ err_msg = "swa_epoch_start should be a >0 integer or a float between 0 and 1." if isinstance(swa_epoch_start, int) and swa_epoch_start < 1: raise MisconfigurationException(err_msg) if isinstance(swa_epoch_start, float) and not (0 <= swa_epoch_start <= 1): raise MisconfigurationException(err_msg) wrong_type = not isinstance(swa_lrs, (float, list)) wrong_float = isinstance(swa_lrs, float) and swa_lrs <= 0 wrong_list = isinstance(swa_lrs, list) and not all(lr > 0 and isinstance(lr, float) for lr in swa_lrs) if (swa_lrs is not None and (wrong_type or wrong_float or wrong_list)): raise MisconfigurationException("The `swa_lrs` should be a positive float or a list of positive float.") if avg_fn is not None and not isinstance(avg_fn, Callable): raise MisconfigurationException("The `avg_fn` should be callable.") if device is not None and not isinstance(device, (torch.device, str)): raise MisconfigurationException(f"device is expected to be a torch.device or a str. 
Found {device}") self.swa_freq = swa_freq self._swa_epoch_start = swa_epoch_start self._swa_lrs = swa_lrs self._annealing_epochs = annealing_epochs self._annealing_strategy = annealing_strategy self._avg_fn = avg_fn or self.avg_fn self._device = device self._model_contains_batch_norm = None self._average_model = None @property def swa_start(self) -> int: return max(self._swa_epoch_start - 1, 0) # 0-based @property def swa_end(self) -> int: return self._max_epochs - 1 # 0-based @staticmethod def pl_module_contains_batch_norm(pl_module: 'pl.LightningModule'): return any(isinstance(module, nn.modules.batchnorm._BatchNorm) for module in pl_module.modules()) def setup(self, trainer: 'pl.Trainer', pl_module: 'pl.LightningModule', stage:str): # copy the model before moving it to accelerator device. self._average_model = deepcopy(pl_module.net) def on_fit_start(self, trainer: 'pl.Trainer', pl_module: 'pl.LightningModule'): if len(trainer.optimizers) != 1: raise MisconfigurationException("SWA currently works with 1 `optimizer`.") if len(trainer.lr_scheduler_configs) > 1: raise MisconfigurationException("SWA currently not supported for more than 1 `lr_scheduler`.") if isinstance(self._swa_epoch_start, float): self._swa_epoch_start = int(trainer.max_epochs * self._swa_epoch_start) self._model_contains_batch_norm = self.pl_module_contains_batch_norm(pl_module) self._max_epochs = trainer.max_epochs if self._model_contains_batch_norm: print("\n\n_model_contains_batch_norm\n\n") # virtually increase max_epochs to perform batch norm update on latest epoch. trainer.max_epochs += 1 def on_train_epoch_start(self, trainer: 'pl.Trainer', pl_module: 'pl.LightningModule'): if trainer.current_epoch == self.swa_start: print(f"\n\n SWA START at {trainer.current_epoch}\n\n") # move average model to request device. 
self._average_model = self._average_model.to(self._device or pl_module.device) optimizers = trainer.optimizers for param_group in optimizers[0].param_groups: if self._swa_lrs is None: initial_lr = param_group["lr"] elif isinstance(self._swa_lrs, float): initial_lr = self._swa_lrs else: initial_lr = self._swa_lrs[0] param_group["initial_lr"] = initial_lr self._swa_lrs = initial_lr self._swa_scheduler = SWALR( optimizers[0], swa_lr=initial_lr, anneal_epochs=self._annealing_epochs, anneal_strategy=self._annealing_strategy, last_epoch=trainer.max_epochs if self._annealing_strategy == "cos" else -1 ) self.n_averaged = torch.tensor(0, dtype=torch.long, device=pl_module.device) pl_module.net_swa = self._average_model if (self.swa_start <= trainer.current_epoch <= self.swa_end) and ((trainer.current_epoch - self.swa_start) % self.swa_freq == 0): self.update_parameters(self._average_model, pl_module.net, self.n_averaged, self.avg_fn) pl_module.net_swa = self._average_model def on_train_epoch_end(self, trainer: 'pl.Trainer',pl_module: 'pl.LightningModule', *args): trainer.fit_loop._skip_backward = False def on_train_end(self, trainer: 'pl.Trainer', pl_module: 'pl.LightningModule'): pass def on_validation_epoch_start(self, trainer, pl_module) -> None: """Called when the val epoch begins.""" if (self.swa_start <= trainer.current_epoch <= self.swa_end) and ((trainer.current_epoch - self.swa_start) % self.swa_freq == 0): pl_module.do_swa= True else: pl_module.do_swa = False @staticmethod def transfer_weights(src_pl_module: 'pl.LightningModule', dst_pl_module: 'pl.LightningModule'): for src_param, dst_param in zip(src_pl_module.parameters(), dst_pl_module.parameters()): dst_param.detach().copy_(src_param.to(dst_param.device)) def reset_batch_norm_and_save_state(self, pl_module): """ Adapted from https://github.com/pytorch/pytorch/blob/v1.7.1/torch/optim/swa_utils.py#L140-L154 """ self.momenta = {} for module in pl_module.modules(): if not isinstance(module, 
nn.modules.batchnorm._BatchNorm): continue module.running_mean = torch.zeros_like( module.running_mean, device=pl_module.device, dtype=module.running_mean.dtype ) module.running_var = torch.ones_like( module.running_var, device=pl_module.device, dtype=module.running_var.dtype ) self.momenta[module] = module.momentum module.momentum = None module.num_batches_tracked *= 0 def reset_momenta(self): """ Adapted from https://github.com/pytorch/pytorch/blob/v1.7.1/torch/optim/swa_utils.py#L164-L165 """ for bn_module in self.momenta.keys(): bn_module.momentum = self.momenta[bn_module] @staticmethod def update_parameters( average_model, model, n_averaged: torch.LongTensor, avg_fn: _AVG_FN ): """ Adapted from https://github.com/pytorch/pytorch/blob/v1.7.1/torch/optim/swa_utils.py#L104-L112 """ for p_swa, p_model in zip(average_model.parameters(), model.parameters()): device = p_swa.device p_swa_ = p_swa.detach() p_model_ = p_model.detach().to(device) src = p_model_ if n_averaged == 0 else avg_fn(p_swa_, p_model_, n_averaged.to(device)) p_swa_.copy_(src) n_averaged += 1 @staticmethod def avg_fn( averaged_model_parameter: torch.Tensor, model_parameter: torch.Tensor, num_averaged: torch.LongTensor ) -> torch.FloatTensor: """ Adapted from https://github.com/pytorch/pytorch/blob/v1.7.1/torch/optim/swa_utils.py#L95-L97 """ return averaged_model_parameter + (model_parameter - averaged_model_parameter) / (num_averaged + 1) ================================================ FILE: helpers/swa_legacy.py ================================================ # Adapted from PyTorch Lightning so that it only does the averaging # Copyright The PyTorch Lightning team. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
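The `avg_fn` above is the standard incremental-mean update: after `n` parameters have been folded in, `avg + (p - avg) / (n + 1)` keeps the running arithmetic mean, which is how `update_parameters` accumulates the SWA weights. A plain-Python sketch with scalars standing in for parameter tensors:

```python
def avg_fn(averaged, new, num_averaged):
    # incremental mean: mean_{n+1} = mean_n + (x - mean_n) / (n + 1)
    return averaged + (new - averaged) / (num_averaged + 1)


# fold three "checkpoints" into a running average, mirroring update_parameters:
# the first value is copied directly (n_averaged == 0), later ones are averaged in
avg, n_averaged = None, 0
for param in [1.0, 2.0, 3.0]:
    avg = param if n_averaged == 0 else avg_fn(avg, param, n_averaged)
    n_averaged += 1

print(avg)  # 2.0 -- the arithmetic mean of the three values
```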
# You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. r""" Stochastic Weight Averaging Callback ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ """ from copy import deepcopy from typing import Callable, Optional, Union import torch from torch import nn import pytorch_lightning as pl from pytorch_lightning.callbacks.base import Callback from pytorch_lightning.trainer.optimizers import _get_default_scheduler_config from pytorch_lightning.utilities import rank_zero_warn from pytorch_lightning.utilities.exceptions import MisconfigurationException from torch.optim.swa_utils import SWALR _AVG_FN = Callable[[torch.Tensor, torch.Tensor, torch.LongTensor], torch.FloatTensor] class StochasticWeightAveraging(Callback): def __init__( self, swa_epoch_start: Union[int, float] = 0.8, swa_freq: Union[int, float] = 3, swa_lrs: Optional[Union[float, list]] = None, annealing_epochs: int = 10, annealing_strategy: str = "cos", avg_fn: Optional[_AVG_FN] = None, device: Optional[Union[torch.device, str]] = None, ): r""" Implements the Stochastic Weight Averaging (SWA) Callback to average a model. Stochastic Weight Averaging was proposed in ``Averaging Weights Leads to Wider Optima and Better Generalization`` by Pavel Izmailov, Dmitrii Podoprikhin, Timur Garipov, Dmitry Vetrov and Andrew Gordon Wilson (UAI 2018). This documentation is highly inspired by PyTorch's work on SWA. The callback arguments follow the scheme defined in PyTorch's ``swa_utils`` package. For a SWA explanation, please take a look `here `_. .. warning:: ``StochasticWeightAveraging`` is in beta and subject to change. .. 
warning:: ``StochasticWeightAveraging`` is currently not supported for multiple optimizers/schedulers. SWA can easily be activated directly from the Trainer as follow: .. code-block:: python Trainer(stochastic_weight_avg=True) Arguments: swa_epoch_start: If provided as int, the procedure will start from the ``swa_epoch_start``-th epoch. If provided as float between 0 and 1, the procedure will start from ``int(swa_epoch_start * max_epochs)`` epoch swa_lrs: the learning rate value for all param groups together or separately for each group. annealing_epochs: number of epochs in the annealing phase (default: 10) annealing_strategy: Specifies the annealing strategy (default: "cos"): - ``"cos"``. For cosine annealing. - ``"linear"`` For linear annealing avg_fn: the averaging function used to update the parameters; the function must take in the current value of the :class:`AveragedModel` parameter, the current value of :attr:`model` parameter and the number of models already averaged; if None, equally weighted average is used (default: ``None``) device: if provided, the averaged model will be stored on the ``device``. When None is provided, it will infer the `device` from ``pl_module``. (default: ``"cpu"``) """ err_msg = "swa_epoch_start should be a >0 integer or a float between 0 and 1." 
if isinstance(swa_epoch_start, int) and swa_epoch_start < 1: raise MisconfigurationException(err_msg) if isinstance(swa_epoch_start, float) and not (0 <= swa_epoch_start <= 1): raise MisconfigurationException(err_msg) wrong_type = not isinstance(swa_lrs, (float, list)) wrong_float = isinstance(swa_lrs, float) and swa_lrs <= 0 wrong_list = isinstance(swa_lrs, list) and not all( lr > 0 and isinstance(lr, float) for lr in swa_lrs) if (swa_lrs is not None and (wrong_type or wrong_float or wrong_list)): raise MisconfigurationException( "The `swa_lrs` should be a positive float or a list of positive float.") if avg_fn is not None and not isinstance(avg_fn, Callable): raise MisconfigurationException("The `avg_fn` should be callable.") if device is not None and not isinstance(device, (torch.device, str)): raise MisconfigurationException( f"device is expected to be a torch.device or a str. Found {device}") self.swa_freq = swa_freq self._swa_epoch_start = swa_epoch_start self._swa_lrs = swa_lrs self._annealing_epochs = annealing_epochs self._annealing_strategy = annealing_strategy self._avg_fn = avg_fn or self.avg_fn self._device = device self._model_contains_batch_norm = None self._average_model = None @property def swa_start(self) -> int: return max(self._swa_epoch_start - 1, 0) # 0-based @property def swa_end(self) -> int: return self._max_epochs - 1 # 0-based @staticmethod def pl_module_contains_batch_norm(pl_module: 'pl.LightningModule'): return any(isinstance(module, nn.modules.batchnorm._BatchNorm) for module in pl_module.modules()) def on_before_accelerator_backend_setup(self, trainer: 'pl.Trainer', pl_module: 'pl.LightningModule'): # copy the model before moving it to accelerator device. 
self._average_model = deepcopy(pl_module.net) def on_fit_start(self, trainer: 'pl.Trainer', pl_module: 'pl.LightningModule'): optimizers = trainer.optimizers lr_schedulers = trainer.lr_schedulers if len(optimizers) != 1: raise MisconfigurationException( "SWA currently works with 1 `optimizer`.") if len(lr_schedulers) > 1: raise MisconfigurationException( "SWA currently not supported for more than 1 `lr_scheduler`.") if isinstance(self._swa_epoch_start, float): self._swa_epoch_start = int( trainer.max_epochs * self._swa_epoch_start) self._model_contains_batch_norm = self.pl_module_contains_batch_norm( pl_module) self._max_epochs = trainer.max_epochs if self._model_contains_batch_norm: print("\n\n_model_contains_batch_norm\n\n") # virtually increase max_epochs to perform batch norm update on latest epoch. trainer.max_epochs += 1 def on_train_epoch_start(self, trainer: 'pl.Trainer', pl_module: 'pl.LightningModule'): if trainer.current_epoch == self.swa_start: print(f"\n\n SWA START at {trainer.current_epoch}\n\n") # move average model to request device. 
self._average_model = self._average_model.to( self._device or pl_module.device) optimizers = trainer.optimizers for param_group in optimizers[0].param_groups: if self._swa_lrs is None: initial_lr = param_group["lr"] elif isinstance(self._swa_lrs, float): initial_lr = self._swa_lrs else: initial_lr = self._swa_lrs[0] param_group["initial_lr"] = initial_lr self._swa_lrs = initial_lr self._swa_scheduler = SWALR( optimizers[0], swa_lr=initial_lr, anneal_epochs=self._annealing_epochs, anneal_strategy=self._annealing_strategy, last_epoch=trainer.max_epochs if self._annealing_strategy == "cos" else -1 ) self.n_averaged = torch.tensor( 0, dtype=torch.long, device=pl_module.device) pl_module.net_swa = self._average_model if (self.swa_start <= trainer.current_epoch <= self.swa_end) and ((trainer.current_epoch - self.swa_start) % self.swa_freq == 0): self.update_parameters(self._average_model, pl_module.net, self.n_averaged, self.avg_fn) pl_module.net_swa = self._average_model def on_train_epoch_end(self, trainer: 'pl.Trainer', pl_module: 'pl.LightningModule', *args): trainer.train_loop._skip_backward = False def on_train_end(self, trainer: 'pl.Trainer', pl_module: 'pl.LightningModule'): pass def on_validation_epoch_start(self, trainer, pl_module) -> None: """Called when the val epoch begins.""" if (self.swa_start <= trainer.current_epoch <= self.swa_end) and ((trainer.current_epoch - self.swa_start) % self.swa_freq == 0): pl_module.do_swa = True else: pl_module.do_swa = False @staticmethod def transfer_weights(src_pl_module: 'pl.LightningModule', dst_pl_module: 'pl.LightningModule'): for src_param, dst_param in zip(src_pl_module.parameters(), dst_pl_module.parameters()): dst_param.detach().copy_(src_param.to(dst_param.device)) def reset_batch_norm_and_save_state(self, pl_module): """ Adapted from https://github.com/pytorch/pytorch/blob/v1.7.1/torch/optim/swa_utils.py#L140-L154 """ self.momenta = {} for module in pl_module.modules(): if not isinstance(module, 
nn.modules.batchnorm._BatchNorm): continue module.running_mean = torch.zeros_like( module.running_mean, device=pl_module.device, dtype=module.running_mean.dtype ) module.running_var = torch.ones_like( module.running_var, device=pl_module.device, dtype=module.running_var.dtype ) self.momenta[module] = module.momentum module.momentum = None module.num_batches_tracked *= 0 def reset_momenta(self): """ Adapted from https://github.com/pytorch/pytorch/blob/v1.7.1/torch/optim/swa_utils.py#L164-L165 """ for bn_module in self.momenta.keys(): bn_module.momentum = self.momenta[bn_module] @staticmethod def update_parameters( average_model, model, n_averaged: torch.LongTensor, avg_fn: _AVG_FN ): """ Adapted from https://github.com/pytorch/pytorch/blob/v1.7.1/torch/optim/swa_utils.py#L104-L112 """ for p_swa, p_model in zip(average_model.parameters(), model.parameters()): device = p_swa.device p_swa_ = p_swa.detach() p_model_ = p_model.detach().to(device) src = p_model_ if n_averaged == 0 else avg_fn( p_swa_, p_model_, n_averaged.to(device)) p_swa_.copy_(src) n_averaged += 1 @staticmethod def avg_fn( averaged_model_parameter: torch.Tensor, model_parameter: torch.Tensor, num_averaged: torch.LongTensor ) -> torch.FloatTensor: """ Adapted from https://github.com/pytorch/pytorch/blob/v1.7.1/torch/optim/swa_utils.py#L95-L97 """ return averaged_model_parameter + (model_parameter - averaged_model_parameter) / (num_averaged + 1) ================================================ FILE: helpers/workersinit.py ================================================ import torch import numpy as np import random def worker_init_fn(x): seed = (torch.initial_seed() + x * 1000) % 2 ** 31 # problem with nearly seeded randoms np.random.seed(seed) random.seed(seed) torch.manual_seed(seed) return ================================================ FILE: models/helpers/vit_helpers.py ================================================ """ Adapted from 
https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py Credit to @leo19941227 for remove timm dependencies here : https://github.com/s3prl/passt_hear21/blob/48a0dc1b824641ca59884ced53f5b86053fed141/hear21passt/models/helpers/vit_helpers.py """ import math import logging import warnings from copy import deepcopy import torch from torch import nn try: from timm.models._hub import download_cached_file except ModuleNotFoundError: from timm.models.hub import download_cached_file # Global variables for rarely used pretrained checkpoint download progress and hash check. # Use set_pretrained_download_progress / set_pretrained_check_hash functions to toggle. _DOWNLOAD_PROGRESS = True _CHECK_HASH = False _logger = logging.getLogger(__name__) def adapt_input_conv(in_chans, conv_weight): conv_type = conv_weight.dtype conv_weight = ( conv_weight.float() ) # Some weights are in torch.half, ensure it's float for sum on CPU O, I, J, K = conv_weight.shape if in_chans == 1: if I > 3: assert conv_weight.shape[1] % 3 == 0 # For models with space2depth stems conv_weight = conv_weight.reshape(O, I // 3, 3, J, K) conv_weight = conv_weight.sum(dim=2, keepdim=False) else: conv_weight = conv_weight.sum(dim=1, keepdim=True) elif in_chans != 3: if I != 3: raise NotImplementedError("Weight format not supported by conversion.") else: # NOTE this strategy should be better than random init, but there could be other combinations of # the original RGB input layer weights that'd work better for specific cases. 
        repeat = int(math.ceil(in_chans / 3))
        conv_weight = conv_weight.repeat(1, repeat, 1, 1)[:, :in_chans, :, :]
        conv_weight *= 3 / float(in_chans)
    conv_weight = conv_weight.to(conv_type)
    return conv_weight


def load_pretrained(
    model,
    default_cfg=None,
    num_classes=1000,
    in_chans=3,
    filter_fn=None,
    strict=True,
    progress=False,
):
    """Load pretrained checkpoint

    Args:
        model (nn.Module): PyTorch model module
        default_cfg (Optional[Dict]): default configuration for pretrained weights / target dataset
        num_classes (int): num_classes for model
        in_chans (int): in_chans for model
        filter_fn (Optional[Callable]): state_dict filter fn for load (takes state_dict, model as args)
        strict (bool): strict load of checkpoint
        progress (bool): enable progress bar for weight download
    """
    default_cfg = default_cfg or getattr(model, "default_cfg", None) or {}
    pretrained_url = default_cfg.get("url", None)
    if not pretrained_url:
        _logger.warning(
            "No pretrained weights exist for this model. Using random initialization."
        )
        return
    _logger.info(f"Loading pretrained weights from url ({pretrained_url})")
    pretrained_loc = download_cached_file(
        pretrained_url,
        check_hash=_CHECK_HASH,
        progress=_DOWNLOAD_PROGRESS,
    )
    state_dict = torch.load(pretrained_loc, map_location="cpu")
    if filter_fn is not None:
        # for backwards compat with filter fns that take one arg, try one first, then two
        try:
            state_dict = filter_fn(state_dict)
        except TypeError:
            state_dict = filter_fn(state_dict, model)
    input_convs = default_cfg.get("first_conv", None)
    if input_convs is not None and in_chans != 3:
        if isinstance(input_convs, str):
            input_convs = (input_convs,)
        for input_conv_name in input_convs:
            weight_name = input_conv_name + ".weight"
            try:
                state_dict[weight_name] = adapt_input_conv(
                    in_chans, state_dict[weight_name]
                )
                _logger.info(
                    f"Converted input conv {input_conv_name} pretrained weights from 3 to {in_chans} channel(s)"
                )
            except NotImplementedError as e:
                del state_dict[weight_name]
                strict = False
                _logger.warning(
                    f"Unable to
convert pretrained {input_conv_name} weights, using random init for this layer." ) classifiers = default_cfg.get("classifier", None) label_offset = default_cfg.get("label_offset", 0) if classifiers is not None: if isinstance(classifiers, str): classifiers = (classifiers,) if num_classes != default_cfg["num_classes"]: for classifier_name in classifiers: # completely discard fully connected if model num_classes doesn't match pretrained weights del state_dict[classifier_name + ".weight"] del state_dict[classifier_name + ".bias"] strict = False elif label_offset > 0: for classifier_name in classifiers: # special case for pretrained weights with an extra background class in pretrained weights classifier_weight = state_dict[classifier_name + ".weight"] state_dict[classifier_name + ".weight"] = classifier_weight[ label_offset: ] classifier_bias = state_dict[classifier_name + ".bias"] state_dict[classifier_name + ".bias"] = classifier_bias[label_offset:] model.load_state_dict(state_dict, strict=strict) def overlay_external_default_cfg(default_cfg, kwargs): """Overlay 'external_default_cfg' in kwargs on top of default_cfg arg.""" external_default_cfg = kwargs.pop("external_default_cfg", None) if external_default_cfg: default_cfg.pop("url", None) # url should come from external cfg default_cfg.pop("hf_hub", None) # hf hub id should come from external cfg default_cfg.update(external_default_cfg) def filter_kwargs(kwargs, names): if not kwargs or not names: return for n in names: kwargs.pop(n, None) def set_default_kwargs(kwargs, names, default_cfg): for n in names: # for legacy reasons, model __init__args uses img_size + in_chans as separate args while # default_cfg has one input_size=(C, H ,W) entry if n == "img_size": input_size = default_cfg.get("input_size", None) if input_size is not None: assert len(input_size) == 3 kwargs.setdefault(n, input_size[-2:]) elif n == "in_chans": input_size = default_cfg.get("input_size", None) if input_size is not None: assert 
len(input_size) == 3 kwargs.setdefault(n, input_size[0]) else: default_val = default_cfg.get(n, None) if default_val is not None: kwargs.setdefault(n, default_cfg[n]) def update_default_cfg_and_kwargs(default_cfg, kwargs, kwargs_filter): """Update the default_cfg and kwargs before passing to model FIXME this sequence of overlay default_cfg, set default kwargs, filter kwargs could/should be replaced by an improved configuration mechanism Args: default_cfg: input default_cfg (updated in-place) kwargs: keyword args passed to model build fn (updated in-place) kwargs_filter: keyword arg keys that must be removed before model __init__ """ # Overlay default cfg values from `external_default_cfg` if it exists in kwargs overlay_external_default_cfg(default_cfg, kwargs) # Set model __init__ args that can be determined by default_cfg (if not already passed as kwargs) default_kwarg_names = ("num_classes", "global_pool", "in_chans") if default_cfg.get("fixed_input_size", False): # if fixed_input_size exists and is True, model takes an img_size arg that fixes its input size default_kwarg_names += ("img_size",) set_default_kwargs(kwargs, names=default_kwarg_names, default_cfg=default_cfg) # Filter keyword args for task specific model variants (some 'features only' models, etc.) filter_kwargs(kwargs, names=kwargs_filter) def drop_path(x, drop_prob: float = 0.0, training: bool = False): """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks). This is the same as the DropConnect impl I created for EfficientNet, etc networks, however, the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper... See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for changing the layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use 'survival rate' as the argument. 
""" if drop_prob == 0.0 or not training: return x keep_prob = 1 - drop_prob shape = (x.shape[0],) + (1,) * ( x.ndim - 1 ) # work with diff dim tensors, not just 2D ConvNets random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device) random_tensor.floor_() # binarize output = x.div(keep_prob) * random_tensor return output class DropPath(nn.Module): """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).""" def __init__(self, drop_prob=None): super(DropPath, self).__init__() self.drop_prob = drop_prob def forward(self, x): return drop_path(x, self.drop_prob, self.training) from torch.nn.init import _calculate_fan_in_and_fan_out def _no_grad_trunc_normal_(tensor, mean, std, a, b): # Cut & paste from PyTorch official master until it's in a few official releases - RW # Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf def norm_cdf(x): # Computes standard normal cumulative distribution function return (1.0 + math.erf(x / math.sqrt(2.0))) / 2.0 if (mean < a - 2 * std) or (mean > b + 2 * std): warnings.warn( "mean is more than 2 std from [a, b] in nn.init.trunc_normal_. " "The distribution of values may be incorrect.", stacklevel=2, ) with torch.no_grad(): # Values are generated by using a truncated uniform distribution and # then using the inverse CDF for the normal distribution. # Get upper and lower cdf values l = norm_cdf((a - mean) / std) u = norm_cdf((b - mean) / std) # Uniformly fill tensor with values from [l, u], then translate to # [2l-1, 2u-1]. 
        tensor.uniform_(2 * l - 1, 2 * u - 1)

        # Use inverse cdf transform for normal distribution to get truncated
        # standard normal
        tensor.erfinv_()

        # Transform to proper mean, std
        tensor.mul_(std * math.sqrt(2.0))
        tensor.add_(mean)

        # Clamp to ensure it's in the proper range
        tensor.clamp_(min=a, max=b)
        return tensor


def trunc_normal_(tensor, mean=0.0, std=1.0, a=-2.0, b=2.0):
    r"""Fills the input Tensor with values drawn from a truncated normal distribution.

    The values are effectively drawn from the normal distribution
    :math:`\mathcal{N}(\text{mean}, \text{std}^2)` with values outside
    :math:`[a, b]` redrawn until they are within the bounds. The method used for
    generating the random values works best when :math:`a \leq \text{mean} \leq b`.

    Args:
        tensor: an n-dimensional `torch.Tensor`
        mean: the mean of the normal distribution
        std: the standard deviation of the normal distribution
        a: the minimum cutoff value
        b: the maximum cutoff value

    Examples:
        >>> w = torch.empty(3, 5)
        >>> nn.init.trunc_normal_(w)
    """
    return _no_grad_trunc_normal_(tensor, mean, std, a, b)


def variance_scaling_(tensor, scale=1.0, mode="fan_in", distribution="normal"):
    fan_in, fan_out = _calculate_fan_in_and_fan_out(tensor)
    if mode == "fan_in":
        denom = fan_in
    elif mode == "fan_out":
        denom = fan_out
    elif mode == "fan_avg":
        denom = (fan_in + fan_out) / 2
    variance = scale / denom
    if distribution == "truncated_normal":
        # constant is stddev of standard normal truncated to (-2, 2)
        trunc_normal_(tensor, std=math.sqrt(variance) / 0.87962566103423978)
    elif distribution == "normal":
        tensor.normal_(std=math.sqrt(variance))
    elif distribution == "uniform":
        bound = math.sqrt(3 * variance)
        tensor.uniform_(-bound, bound)
    else:
        raise ValueError(f"invalid distribution {distribution}")


def lecun_normal_(tensor):
    variance_scaling_(tensor, mode="fan_in", distribution="truncated_normal")


def build_model_with_cfg(
    model_cls,
    variant: str,
    pretrained: bool,
    default_cfg: dict,
    model_cfg=None,
    feature_cfg=None,
    pretrained_strict:
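This `trunc_normal_` helper mirrors `torch.nn.init.trunc_normal_`, which is available in recent PyTorch releases; the key property to illustrate is that every drawn value lands inside `[a, b]` because sampling inverts the normal CDF over the truncated interval.

```python
import torch

# draw many samples and check the truncation bounds hold exactly
w = torch.empty(10000)
torch.nn.init.trunc_normal_(w, mean=0.0, std=1.0, a=-2.0, b=2.0)
assert float(w.min()) >= -2.0
assert float(w.max()) <= 2.0
```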
    bool = True,
    pretrained_filter_fn=None,
    pretrained_custom_load=False,
    kwargs_filter=None,
    **kwargs,
):
    """Build model with specified default_cfg and optional model_cfg

    This helper fn aids in the construction of a model including:
      * handling default_cfg and associated pretrained weight loading
      * passing through optional model_cfg for models with config based arch spec
      * features_only model adaptation
      * pruning config / model adaptation

    Args:
        model_cls (nn.Module): model class
        variant (str): model variant name
        pretrained (bool): load pretrained weights
        default_cfg (dict): model's default pretrained/task config
        model_cfg (Optional[Dict]): model's architecture config
        feature_cfg (Optional[Dict]): feature extraction adapter config
        pretrained_strict (bool): load pretrained weights strictly
        pretrained_filter_fn (Optional[Callable]): filter callable for pretrained weights
        pretrained_custom_load (bool): use custom load fn, to load numpy or other non PyTorch weights
        kwargs_filter (Optional[Tuple]): kwargs to filter before passing to model
        **kwargs: model args passed through to model __init__
    """
    pruned = kwargs.pop("pruned", False)
    features = False
    feature_cfg = feature_cfg or {}

    default_cfg = deepcopy(default_cfg) if default_cfg else {}
    update_default_cfg_and_kwargs(default_cfg, kwargs, kwargs_filter)
    default_cfg.setdefault("architecture", variant)

    # Setup for feature extraction wrapper done at end of this fn
    if kwargs.pop("features_only", False):
        features = True
        feature_cfg.setdefault("out_indices", (0, 1, 2, 3, 4))
        if "out_indices" in kwargs:
            feature_cfg["out_indices"] = kwargs.pop("out_indices")

    # Build the model
    model = (
        model_cls(**kwargs) if model_cfg is None else model_cls(cfg=model_cfg, **kwargs)
    )
    model.default_cfg = default_cfg

    # For classification models, check class attr, then kwargs, then default to 1k, otherwise 0 for feats
    num_classes_pretrained = (
        0 if features else getattr(model, "num_classes", kwargs.get("num_classes", 1000))
    )
    if pretrained:
        assert not
pretrained_custom_load, "URL should not contain npz for PASST models" load_pretrained( model, num_classes=num_classes_pretrained, in_chans=kwargs.get("in_chans", 3), filter_fn=pretrained_filter_fn, strict=pretrained_strict, ) return model ================================================ FILE: models/passt.py ================================================ """ Most of this code comes from the timm library. We tried to disentangle from the timm library version. Adapted from https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py """ import math import logging import warnings from functools import partial import collections from collections import OrderedDict from copy import deepcopy from itertools import repeat import torch import torch.nn as nn import torch.nn.functional as F from .helpers.vit_helpers import update_default_cfg_and_kwargs, DropPath, trunc_normal_, build_model_with_cfg _logger = logging.getLogger() # From PyTorch internals def _ntuple(n): def parse(x): if isinstance(x, collections.abc.Iterable) and not isinstance(x, str): return tuple(x) return tuple(repeat(x, n)) return parse to_2tuple = _ntuple(2) IMAGENET_DEFAULT_MEAN = (0.485, 0.456, 0.406) IMAGENET_DEFAULT_STD = (0.229, 0.224, 0.225) IMAGENET_INCEPTION_MEAN = (0.5, 0.5, 0.5) IMAGENET_INCEPTION_STD = (0.5, 0.5, 0.5) def _cfg(url='', **kwargs): return { 'url': url, 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': None, 'crop_pct': .9, 'interpolation': 'bicubic', 'fixed_input_size': True, 'mean': IMAGENET_INCEPTION_MEAN, 'std': IMAGENET_INCEPTION_STD, 'first_conv': 'patch_embed.proj', 'classifier': 'head', **kwargs } default_cfgs = { # patch models (weights from official Google JAX impl) 'vit_tiny_patch16_224': _cfg( url='https://storage.googleapis.com/vit_models/augreg/' 'Ti_16-i21k-300ep-lr_0.001-aug_none-wd_0.03-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.03-res_224.npz'), 'vit_tiny_patch16_384': _cfg( 
url='https://storage.googleapis.com/vit_models/augreg/' 'Ti_16-i21k-300ep-lr_0.001-aug_none-wd_0.03-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.03-res_384.npz', input_size=(3, 384, 384), crop_pct=1.0), 'vit_small_patch32_224': _cfg( url='https://storage.googleapis.com/vit_models/augreg/' 'S_32-i21k-300ep-lr_0.001-aug_light1-wd_0.03-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.03-res_224.npz'), 'vit_small_patch32_384': _cfg( url='https://storage.googleapis.com/vit_models/augreg/' 'S_32-i21k-300ep-lr_0.001-aug_light1-wd_0.03-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.03-res_384.npz', input_size=(3, 384, 384), crop_pct=1.0), 'vit_small_patch16_224': _cfg( url='https://storage.googleapis.com/vit_models/augreg/' 'S_16-i21k-300ep-lr_0.001-aug_light1-wd_0.03-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.03-res_224.npz'), 'vit_small_patch16_384': _cfg( url='https://storage.googleapis.com/vit_models/augreg/' 'S_16-i21k-300ep-lr_0.001-aug_light1-wd_0.03-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.03-res_384.npz', input_size=(3, 384, 384), crop_pct=1.0), 'vit_base_patch32_224': _cfg( url='https://storage.googleapis.com/vit_models/augreg/' 'B_32-i21k-300ep-lr_0.001-aug_medium1-wd_0.03-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.03-res_224.npz'), 'vit_base_patch32_384': _cfg( url='https://storage.googleapis.com/vit_models/augreg/' 'B_32-i21k-300ep-lr_0.001-aug_light1-wd_0.1-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.03-res_384.npz', input_size=(3, 384, 384), crop_pct=1.0), 'vit_base_patch16_224': _cfg( url='https://storage.googleapis.com/vit_models/augreg/' 'B_16-i21k-300ep-lr_0.001-aug_medium1-wd_0.1-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.01-res_224.npz'), 'vit_base_patch16_384': _cfg( url='https://storage.googleapis.com/vit_models/augreg/' 'B_16-i21k-300ep-lr_0.001-aug_medium1-wd_0.1-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.01-res_384.npz', input_size=(3, 384, 384), crop_pct=1.0), 'vit_large_patch32_224': _cfg( url='', # no official model weights for this combo, only for in21k 
), 'vit_large_patch32_384': _cfg( url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_large_p32_384-9b920ba8.pth', input_size=(3, 384, 384), crop_pct=1.0), 'vit_large_patch16_224': _cfg( url='https://storage.googleapis.com/vit_models/augreg/' 'L_16-i21k-300ep-lr_0.001-aug_medium1-wd_0.1-do_0.1-sd_0.1--imagenet2012-steps_20k-lr_0.01-res_224.npz'), 'vit_large_patch16_384': _cfg( url='https://storage.googleapis.com/vit_models/augreg/' 'L_16-i21k-300ep-lr_0.001-aug_medium1-wd_0.1-do_0.1-sd_0.1--imagenet2012-steps_20k-lr_0.01-res_384.npz', input_size=(3, 384, 384), crop_pct=1.0), # patch models, imagenet21k (weights from official Google JAX impl) 'vit_tiny_patch16_224_in21k': _cfg( url='https://storage.googleapis.com/vit_models/augreg/Ti_16-i21k-300ep-lr_0.001-aug_none-wd_0.03-do_0.0-sd_0.0.npz', num_classes=21843), 'vit_small_patch32_224_in21k': _cfg( url='https://storage.googleapis.com/vit_models/augreg/S_32-i21k-300ep-lr_0.001-aug_light1-wd_0.03-do_0.0-sd_0.0.npz', num_classes=21843), 'vit_small_patch16_224_in21k': _cfg( url='https://storage.googleapis.com/vit_models/augreg/S_16-i21k-300ep-lr_0.001-aug_light1-wd_0.03-do_0.0-sd_0.0.npz', num_classes=21843), 'vit_base_patch32_224_in21k': _cfg( url='https://storage.googleapis.com/vit_models/augreg/B_32-i21k-300ep-lr_0.001-aug_medium1-wd_0.03-do_0.0-sd_0.0.npz', num_classes=21843), 'vit_base_patch16_224_in21k': _cfg( url='https://storage.googleapis.com/vit_models/augreg/B_16-i21k-300ep-lr_0.001-aug_medium1-wd_0.1-do_0.0-sd_0.0.npz', num_classes=21843), 'vit_large_patch32_224_in21k': _cfg( url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_large_patch32_224_in21k-9046d2e7.pth', num_classes=21843), 'vit_large_patch16_224_in21k': _cfg( url='https://storage.googleapis.com/vit_models/augreg/L_16-i21k-300ep-lr_0.001-aug_medium1-wd_0.1-do_0.1-sd_0.1.npz', num_classes=21843), 'vit_huge_patch14_224_in21k': _cfg( 
url='https://storage.googleapis.com/vit_models/imagenet21k/ViT-H_14.npz', hf_hub='timm/vit_huge_patch14_224_in21k', num_classes=21843), # SAM trained models (https://arxiv.org/abs/2106.01548) 'vit_base_patch32_sam_224': _cfg( url='https://storage.googleapis.com/vit_models/sam/ViT-B_32.npz'), 'vit_base_patch16_sam_224': _cfg( url='https://storage.googleapis.com/vit_models/sam/ViT-B_16.npz'), # deit models (FB weights) 'deit_tiny_patch16_224': _cfg( url='https://dl.fbaipublicfiles.com/deit/deit_tiny_patch16_224-a1311bcf.pth', mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD), 'deit_small_patch16_224': _cfg( url='https://dl.fbaipublicfiles.com/deit/deit_small_patch16_224-cd65a155.pth', mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD), 'deit_base_patch16_224': _cfg( url='https://dl.fbaipublicfiles.com/deit/deit_base_patch16_224-b5f2ef4d.pth', mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD), 'deit_base_patch16_384': _cfg( url='https://dl.fbaipublicfiles.com/deit/deit_base_patch16_384-8de9b5d1.pth', mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, input_size=(3, 384, 384), crop_pct=1.0), 'deit_tiny_distilled_patch16_224': _cfg( url='https://dl.fbaipublicfiles.com/deit/deit_tiny_distilled_patch16_224-b40b3cf7.pth', mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, classifier=('head', 'head_dist')), 'deit_small_distilled_patch16_224': _cfg( url='https://dl.fbaipublicfiles.com/deit/deit_small_distilled_patch16_224-649709d9.pth', mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, classifier=('head', 'head_dist')), 'deit_base_distilled_patch16_224': _cfg( url='https://dl.fbaipublicfiles.com/deit/deit_base_distilled_patch16_224-df68dfff.pth', mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, classifier=('head', 'head_dist')), 'deit_base_distilled_patch16_384': _cfg( url='https://dl.fbaipublicfiles.com/deit/deit_base_distilled_patch16_384-d0272ac0.pth', mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, input_size=(3, 384, 384), crop_pct=1.0, 
classifier=('head', 'head_dist')), # ViT ImageNet-21K-P pretraining by MILL 'vit_base_patch16_224_miil_in21k': _cfg( url='https://miil-public-eu.oss-eu-central-1.aliyuncs.com/model-zoo/ImageNet_21K_P/models/timm/vit_base_patch16_224_in21k_miil.pth', mean=(0, 0, 0), std=(1, 1, 1), crop_pct=0.875, interpolation='bilinear', num_classes=11221, ), 'vit_base_patch16_224_miil': _cfg( url='https://miil-public-eu.oss-eu-central-1.aliyuncs.com/model-zoo/ImageNet_21K_P/models/timm' '/vit_base_patch16_224_1k_miil_84_4.pth', mean=(0, 0, 0), std=(1, 1, 1), crop_pct=0.875, interpolation='bilinear', ), # PaSST 'passt_s_swa_p16_128_ap476': _cfg( url='https://github.com/kkoutini/PaSST/releases/download/v0.0.1-audioset/passt-s-f128-p16-s10-ap.476-swa.pt', mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, input_size=(1, 128, 998), crop_pct=1.0, classifier=('head.1', 'head_dist'), num_classes=527), 'passt_s_kd_p16_128_ap486': _cfg( url='https://github.com/kkoutini/PaSST/releases/download/v.0.0.9/passt-s-kd-ap.486.pt', mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, input_size=(1, 128, 998), crop_pct=1.0, classifier=('head.1', 'head_dist'), num_classes=527), 'passt_l_kd_p16_128_ap47': _cfg( url='https://github.com/kkoutini/PaSST/releases/download/v.0.0.10/passt-l-kd-ap.47.pt', mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, input_size=(1, 128, 998), crop_pct=1.0, classifier=('head.1', 'head_dist'), num_classes=527), 'passt_s_swa_p16_128_ap4761': _cfg( url='https://github.com/kkoutini/PaSST/releases/download/v0.0.2-audioset/passt-s-f128-p16-s10-ap.4761-swa.pt', mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, input_size=(1, 128, 998), crop_pct=1.0, classifier=('head.1', 'head_dist'), num_classes=527), 'passt_s_p16_128_ap472': _cfg( url='https://github.com/kkoutini/PaSST/releases/download/v0.0.2-audioset/passt-s-f128-p16-s10-ap.472.pt', mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, input_size=(1, 128, 998), crop_pct=1.0, classifier=('head.1', 'head_dist'), 
num_classes=527), 'passt_s_p16_s16_128_ap468': _cfg( url='https://github.com/kkoutini/PaSST/releases/download/v0.0.2-audioset/passt-s-f128-p16-s16-ap.468.pt', mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, input_size=(1, 128, 998), crop_pct=1.0, classifier=('head.1', 'head_dist'), num_classes=527), 'passt_s_swa_p16_s16_128_ap473': _cfg( url='https://github.com/kkoutini/PaSST/releases/download/v0.0.2-audioset/passt-s-f128-p16-s16-ap.473-swa.pt', mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, input_size=(1, 128, 998), crop_pct=1.0, classifier=('head.1', 'head_dist'), num_classes=527), 'passt_s_swa_p16_s14_128_ap471': _cfg( url='https://github.com/kkoutini/PaSST/releases/download/v0.0.2-audioset/passt-s-f128-p16-s14-ap.471-swa.pt', mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, input_size=(1, 128, 998), crop_pct=1.0, classifier=('head.1', 'head_dist'), num_classes=527), 'passt_s_p16_s14_128_ap469': _cfg( url='https://github.com/kkoutini/PaSST/releases/download/v0.0.2-audioset/passt-s-f128-p16-s14-ap.469.pt', mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, input_size=(1, 128, 998), crop_pct=1.0, classifier=('head.1', 'head_dist'), num_classes=527), 'passt_s_swa_p16_s12_128_ap473': _cfg( url='https://github.com/kkoutini/PaSST/releases/download/v0.0.2-audioset/passt-s-f128-p16-s12-ap.473-swa.pt', mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, input_size=(1, 128, 998), crop_pct=1.0, classifier=('head.1', 'head_dist'), num_classes=527), 'passt_s_p16_s12_128_ap470': _cfg( url='https://github.com/kkoutini/PaSST/releases/download/v0.0.2-audioset/passt-s-f128-p16-s12-ap.470.pt', mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, input_size=(1, 128, 998), crop_pct=1.0, classifier=('head.1', 'head_dist'), num_classes=527), 'passt_s_swa_f128_stfthop100_p16_s10_ap473': _cfg( url='https://github.com/kkoutini/PaSST/releases/download/v0.0.3-audioset/passt-s-f128-stfthop100-p16-s10-ap.473-swa.pt', mean=IMAGENET_DEFAULT_MEAN, 
    std=IMAGENET_DEFAULT_STD, input_size=(1, 128, 3200), crop_pct=1.0,
        classifier=('head.1', 'head_dist'), num_classes=527),
    'passt_s_swa_f128_stfthop160_p16_s10_ap473': _cfg(
        url='https://github.com/kkoutini/PaSST/releases/download/v0.0.3-audioset/passt-s-f128-stfthop160-p16-s10-ap.473-swa.pt',
        mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, input_size=(1, 128, 2000), crop_pct=1.0,
        classifier=('head.1', 'head_dist'), num_classes=527),
    'passt-s-f128-20sec-p16-s10-ap474-swa': _cfg(
        url='https://github.com/kkoutini/PaSST/releases/download/v0.0.5/passt-s-f128-20sec-p16-s10-ap.474-swa.pt',
        mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, input_size=(1, 128, 2000), crop_pct=1.0,
        classifier=('head.1', 'head_dist'), num_classes=527),
    'passt-s-f128-30sec-p16-s10-ap473-swa': _cfg(
        url='https://github.com/kkoutini/PaSST/releases/download/v0.0.5/passt-s-f128-30sec-p16-s10-ap.473-swa.pt',
        mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, input_size=(1, 128, 3000), crop_pct=1.0,
        classifier=('head.1', 'head_dist'), num_classes=527),
    'openmic2008_passt_u_f128_p16_s10_ap85_swa': _cfg(
        url='https://github.com/kkoutini/PaSST/releases/download/v0.0.4-openmic/openmic2008.passt-u-f128-p16-s10-ap.85-swa.pt',
        mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, input_size=(1, 128, 3200), crop_pct=1.0,
        classifier=('head.1', 'head_dist'), num_classes=20),
    'openmic2008_passt_u_f128_p16_s10_ap85': _cfg(
        url='https://github.com/kkoutini/PaSST/releases/download/v0.0.4-openmic/openmic2008.passt-u-f128-p16-s10-ap.85.pt',
        mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, input_size=(1, 128, 2000), crop_pct=1.0,
        classifier=('head.1', 'head_dist'), num_classes=20),
}


def adapt_input_conv(in_chans, conv_weight):
    conv_type = conv_weight.dtype
    conv_weight = conv_weight.float()  # Some weights are in torch.half, ensure it's float for sum on CPU
    O, I, J, K = conv_weight.shape
    if in_chans == 1:
        if I > 3:
            assert conv_weight.shape[1] % 3 == 0  # For models with space2depth stems
            conv_weight
= conv_weight.reshape(O, I // 3, 3, J, K) conv_weight = conv_weight.sum(dim=2, keepdim=False) else: conv_weight = conv_weight.sum(dim=1, keepdim=True) elif in_chans != 3: if I != 3: raise NotImplementedError('Weight format not supported by conversion.') else: # NOTE this strategy should be better than random init, but there could be other combinations of # the original RGB input layer weights that'd work better for specific cases. repeat = int(math.ceil(in_chans / 3)) conv_weight = conv_weight.repeat(1, repeat, 1, 1)[:, :in_chans, :, :] conv_weight *= (3 / float(in_chans)) conv_weight = conv_weight.to(conv_type) return conv_weight class Mlp(nn.Module): """ MLP as used in Vision Transformer, MLP-Mixer and related networks """ def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.): super().__init__() out_features = out_features or in_features hidden_features = hidden_features or in_features self.fc1 = nn.Linear(in_features, hidden_features) self.act = act_layer() self.fc2 = nn.Linear(hidden_features, out_features) self.drop = nn.Dropout(drop) def forward(self, x): x = self.fc1(x) x = self.act(x) x = self.drop(x) x = self.fc2(x) x = self.drop(x) return x first_RUN = True PLUS1_TRICK = False class PatchEmbed(nn.Module): """ 2D Image to Patch Embedding """ def __init__(self, img_size=224, patch_size=16, stride=16, in_chans=3, embed_dim=768, norm_layer=None, flatten=True): super().__init__() img_size = to_2tuple(img_size) patch_size = to_2tuple(patch_size) stride = to_2tuple(stride) self.img_size = img_size self.patch_size = patch_size self.stride = stride self.grid_size = (img_size[0] // stride[0], img_size[1] // stride[1]) self.num_patches = self.grid_size[0] * self.grid_size[1] self.flatten = flatten self.embed_dim = embed_dim self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=stride) self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity() def forward(self, x): B, C, H, W = x.shape if not 
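The single-channel branch of `adapt_input_conv` above can be checked with a small sketch (tensor names are illustrative): an RGB kernel collapsed to mono by summing over the input-channel axis gives the same activations on a gray input, since `sum_c(w_c * x) == (sum_c w_c) * x`.

```python
import torch
import torch.nn.functional as F

w_rgb = torch.randn(8, 3, 16, 16)         # (out, in=3, kH, kW), e.g. a patch-embed kernel
w_mono = w_rgb.sum(dim=1, keepdim=True)   # (out, 1, kH, kW), as adapt_input_conv does

gray = torch.randn(1, 1, 32, 32)
rgb = gray.repeat(1, 3, 1, 1)             # identical channels: R == G == B
out_rgb = F.conv2d(rgb, w_rgb, stride=16)
out_mono = F.conv2d(gray, w_mono, stride=16)
assert torch.allclose(out_rgb, out_mono, atol=1e-3)
```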
(H == self.img_size[0] and W == self.img_size[1]): warnings.warn(f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]}).") # to do maybe replace weights x = self.proj(x) if self.flatten: x = x.flatten(2).transpose(1, 2) # BCHW -> BNC x = self.norm(x) if first_RUN: print("self.norm(x)", x.size()) return x class Attention(nn.Module): def __init__(self, dim, num_heads=8, qkv_bias=False, attn_drop=0., proj_drop=0.): super().__init__() self.num_heads = num_heads head_dim = dim // num_heads self.scale = head_dim ** -0.5 self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) self.attn_drop = nn.Dropout(attn_drop) self.proj = nn.Linear(dim, dim) self.proj_drop = nn.Dropout(proj_drop) def forward(self, x): B, N, C = x.shape qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple) attn = (q @ k.transpose(-2, -1)) * self.scale if PLUS1_TRICK: # +1 trick attn = torch.cat([attn, torch.zeros(attn.shape[:-1]+(1,), dtype=attn.dtype, device=attn.device)], dim=-1) attn = attn.softmax(dim=-1) if PLUS1_TRICK: # +1 trick attn = attn[...,:-1] attn = self.attn_drop(attn) x = (attn @ v).transpose(1, 2).reshape(B, N, C) x = self.proj(x) x = self.proj_drop(x) return x class Block(nn.Module): def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, drop=0., attn_drop=0., drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm): super().__init__() self.norm1 = norm_layer(dim) self.attn = Attention(dim, num_heads=num_heads, qkv_bias=qkv_bias, attn_drop=attn_drop, proj_drop=drop) # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here self.drop_path = DropPath(drop_path) if drop_path > 0. 
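The `PLUS1_TRICK` branch in `Attention` above appends a constant zero logit before the softmax (an implicit extra key to attend to) and then drops its probability mass, so a row can sum to less than one and a head may effectively attend to nothing. A small sketch of the arithmetic, with illustrative logits:

```python
import torch

logits = torch.tensor([[4.0, 2.0, -1.0]])
# append a zero logit (the "virtual token"), softmax, then drop its mass
plus1 = torch.cat([logits, torch.zeros(logits.shape[:-1] + (1,))], dim=-1)
attn = plus1.softmax(dim=-1)[..., :-1]
plain = logits.softmax(dim=-1)
assert float(attn.sum()) < 1.0      # some mass leaked to the zero logit
assert float(plain.sum()) > 0.999   # ordinary softmax sums to one
```

Renormalizing the truncated row recovers the plain softmax, which shows the trick only rescales attention rather than reordering it.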
else nn.Identity() self.norm2 = norm_layer(dim) mlp_hidden_dim = int(dim * mlp_ratio) self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) def forward(self, x): x = x + self.drop_path(self.attn(self.norm1(x))) x = x + self.drop_path(self.mlp(self.norm2(x))) return x class PaSST(nn.Module): """ Based on the implementation of Vision Transformer in timm library. Take a look at the get_model function, adapting the weights of pretrained imagenet models. """ def __init__(self, u_patchout=0, s_patchout_t=0, s_patchout_f=0, img_size=(128, 998), patch_size=16, stride=16, in_chans=1, num_classes=527, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4., qkv_bias=True, representation_size=None, distilled=False, drop_rate=0., attn_drop_rate=0., drop_path_rate=0., embed_layer=PatchEmbed, norm_layer=None, act_layer=None, weight_init=''): """ Args: u_patchout: Unstructured Patchout integer, number of items to be removed from the final sequence s_patchout_t: structured Patchout time integer, number of columns to be removed from the patches grid s_patchout_f: structured Patchout Frequency integer, number of rows to be removed from the patches grid img_size (int, tuple): input image size patch_size (int, tuple): patch size in_chans (int): number of input channels num_classes (int): number of classes for classification head embed_dim (int): embedding dimension depth (int): depth of transformer num_heads (int): number of attention heads mlp_ratio (int): ratio of mlp hidden dim to embedding dim qkv_bias (bool): enable bias for qkv if True representation_size (Optional[int]): enable and set representation layer (pre-logits) to this value if set distilled (bool): model includes a distillation token and head as in DeiT models drop_rate (float): dropout rate attn_drop_rate (float): attention dropout rate drop_path_rate (float): stochastic depth rate embed_layer (nn.Module): patch embedding layer norm_layer: (nn.Module): normalization layer 
weight_init: (str): weight init scheme """ super().__init__() self.num_classes = num_classes self.u_patchout = u_patchout self.s_patchout_t = s_patchout_t self.s_patchout_f = s_patchout_f self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models self.num_tokens = 2 if distilled else 1 norm_layer = norm_layer or partial(nn.LayerNorm, eps=1e-6) act_layer = act_layer or nn.GELU self.patch_embed = embed_layer( img_size=img_size, patch_size=patch_size, stride=stride, in_chans=in_chans, embed_dim=embed_dim, flatten=False) num_patches = self.patch_embed.num_patches self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) self.dist_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) if distilled else None # PaSST # refer to https://arxiv.org/abs/2110.05069 Section 2 self.new_pos_embed = nn.Parameter(torch.zeros(1, self.num_tokens, embed_dim)) # for C and D tokens self.freq_new_pos_embed = nn.Parameter(torch.zeros(1, embed_dim, self.patch_embed.grid_size[0], 1)) # | f self.time_new_pos_embed = nn.Parameter(torch.zeros(1, embed_dim, 1, self.patch_embed.grid_size[1])) # __ t #### self.pos_drop = nn.Dropout(p=drop_rate) dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule self.blocks = nn.Sequential(*[ Block( dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer, act_layer=act_layer) for i in range(depth)]) self.norm = norm_layer(embed_dim) # Representation layer if representation_size and not distilled: self.num_features = representation_size self.pre_logits = nn.Sequential(OrderedDict([ ('fc', nn.Linear(embed_dim, representation_size)), ('act', nn.Tanh()) ])) else: self.pre_logits = nn.Identity() # Classifier head(s) self.head = nn.Sequential(nn.LayerNorm(self.num_features), nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity()) self.head_dist = None if distilled: 
self.head_dist = nn.Linear(self.embed_dim, self.num_classes) if num_classes > 0 else nn.Identity() self.init_weights(weight_init) def init_weights(self, mode=''): assert mode in ('jax', 'jax_nlhb', 'nlhb', '') head_bias = -math.log(self.num_classes) if 'nlhb' in mode else 0. trunc_normal_(self.new_pos_embed, std=.02) trunc_normal_(self.freq_new_pos_embed, std=.02) trunc_normal_(self.time_new_pos_embed, std=.02) if self.dist_token is not None: trunc_normal_(self.dist_token, std=.02) if mode.startswith('jax'): # leave cls token as zeros to match jax impl raise RuntimeError("Not supported yet") else: trunc_normal_(self.cls_token, std=.02) self.apply(_init_vit_weights) def _init_weights(self, m): # this fn left here for compat with downstream users _init_vit_weights(m) @torch.jit.ignore def no_weight_decay(self): return {'new_pos_embed', 'freq_new_pos_embed', 'time_new_pos_embed', 'cls_token', 'dist_token'} def get_classifier(self): if self.dist_token is None: return self.head else: return self.head, self.head_dist def reset_classifier(self, num_classes, global_pool=''): self.num_classes = num_classes self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity() if self.num_tokens == 2: self.head_dist = nn.Linear(self.embed_dim, self.num_classes) if num_classes > 0 else nn.Identity() def forward_features(self, x): global first_RUN # not jit friendly? 
use trace instead x = self.patch_embed(x) # [b, e, f, t] B_dim, E_dim, F_dim, T_dim = x.shape # slow if first_RUN: print(" patch_embed : ", x.shape) # Adding Time/Freq information if first_RUN: print(" self.time_new_pos_embed.shape", self.time_new_pos_embed.shape) time_new_pos_embed = self.time_new_pos_embed if x.shape[-1] < time_new_pos_embed.shape[-1]: if self.training: toffset = torch.randint(1 + time_new_pos_embed.shape[-1] - x.shape[-1], (1,)).item() if first_RUN: print(f" CUT with randomoffset={toffset} time_new_pos_embed.shape", time_new_pos_embed.shape) time_new_pos_embed = time_new_pos_embed[:, :, :, toffset:toffset + x.shape[-1]] else: time_new_pos_embed = time_new_pos_embed[:, :, :, :x.shape[-1]] if first_RUN: print(" CUT time_new_pos_embed.shape", time_new_pos_embed.shape) else: warnings.warn( f"the patches shape:{x.shape} are larger than the expected time encodings {time_new_pos_embed.shape}, x will be cut") x = x[:, :, :, :time_new_pos_embed.shape[-1]] x = x + time_new_pos_embed if first_RUN: print(" self.freq_new_pos_embed.shape", self.freq_new_pos_embed.shape) x = x + self.freq_new_pos_embed # Structured Patchout https://arxiv.org/abs/2110.05069 Section 2.2 if self.training and self.s_patchout_t: if first_RUN: print(f"X Before time Patchout of {self.s_patchout_t} ", x.size()) # ([1, 768, 1, 82]) random_indices = torch.randperm(T_dim)[:T_dim - self.s_patchout_t].sort().values x = x[:, :, :, random_indices] if first_RUN: print("X after time Patchout", x.size()) if self.training and self.s_patchout_f: if first_RUN: print(f"X Before Freq Patchout of {self.s_patchout_f} ", x.size()) # [1, 768, 12, 1] random_indices = torch.randperm(F_dim)[:F_dim - self.s_patchout_f].sort().values x = x[:, :, random_indices, :] if first_RUN: print(" \n X after freq Patchout: ", x.size()) ### # Flatten the sequence x = x.flatten(2).transpose(1, 2) # Unstructured Patchout if first_RUN: print("X flattened", x.size()) if self.training and self.u_patchout: seq_len = x.shape[1] 
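The structured time-patchout step above (Section 2.2 of the PaSST paper) keeps a sorted random subset of time columns of the `(B, E, F, T)` patch grid during training; a minimal sketch with illustrative sizes:

```python
import torch

torch.manual_seed(0)
T_dim, s_patchout_t = 10, 4
x = torch.arange(2 * 3 * 1 * T_dim, dtype=torch.float).reshape(2, 3, 1, T_dim)
# drop s_patchout_t whole time columns, keeping the rest in temporal order
keep = torch.randperm(T_dim)[: T_dim - s_patchout_t].sort().values
x = x[:, :, :, keep]
assert x.shape == (2, 3, 1, T_dim - s_patchout_t)
```

The `.sort().values` is what preserves temporal order among the surviving columns, so the time positional embedding already added to `x` stays meaningful.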
random_indices = torch.randperm(seq_len)[:seq_len - self.u_patchout].sort().values x = x[:, random_indices, :] if first_RUN: print("X After Unstructured Patchout", x.size()) #### # Add the C/D tokens if first_RUN: print(" self.new_pos_embed.shape", self.new_pos_embed.shape) cls_tokens = self.cls_token.expand(B_dim, -1, -1) + self.new_pos_embed[:, :1, :] if first_RUN: print(" self.cls_tokens.shape", cls_tokens.shape) if self.dist_token is None: x = torch.cat((cls_tokens, x), dim=1) else: dist_token = self.dist_token.expand(B_dim, -1, -1) + self.new_pos_embed[:, 1:, :] if first_RUN: print(" self.dist_token.shape", dist_token.shape) x = torch.cat((cls_tokens, dist_token, x), dim=1) if first_RUN: print(" final sequence x", x.shape) x = self.pos_drop(x) x = self.blocks(x) if first_RUN: print(f" after {len(self.blocks)} atten blocks x", x.shape) x = self.norm(x) if self.dist_token is None: return self.pre_logits(x[:, 0]) else: return x[:, 0], x[:, 1] def forward(self, x): global first_RUN if first_RUN: print("x", x.size()) x = self.forward_features(x) if self.head_dist is not None: features = (x[0] + x[1]) / 2 if first_RUN: print("forward_features", features.size()) x = self.head(features) if first_RUN: print("head", x.size()) first_RUN = False return x, features else: features = x if first_RUN: print("forward_features", features.size()) x = self.head(x) if first_RUN: print("head", x.size()) first_RUN = False return x, features def _init_vit_weights(module: nn.Module, name: str = '', head_bias: float = 0., jax_impl: bool = False): """ ViT weight initialization * When called without n, head_bias, jax_impl args it will behave exactly the same as my original init for compatibility with prev hparam / downstream use cases (ie DeiT). 
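Earlier in `forward_features`, inputs shorter than the trained length cause the `(1, E, 1, T)` time positional embedding to be sliced: at a random offset during training, from the start at inference. A sketch with illustrative sizes:

```python
import torch

E, T_trained, T_in = 4, 99, 60
time_pos = torch.randn(1, E, 1, T_trained)
# training branch: random temporal offset into the trained embedding
toffset = int(torch.randint(1 + T_trained - T_in, (1,)).item())
train_crop = time_pos[:, :, :, toffset:toffset + T_in]
# eval branch: deterministic crop from the start
eval_crop = time_pos[:, :, :, :T_in]
assert train_crop.shape == eval_crop.shape == (1, E, 1, T_in)
```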
* When called w/ valid n (module name) and jax_impl=True, will (hopefully) match JAX impl """ if isinstance(module, nn.Linear): if name.startswith('head'): nn.init.zeros_(module.weight) nn.init.constant_(module.bias, head_bias) elif name.startswith('pre_logits'): lecun_normal_(module.weight) nn.init.zeros_(module.bias) else: if jax_impl: nn.init.xavier_uniform_(module.weight) if module.bias is not None: if 'mlp' in name: nn.init.normal_(module.bias, std=1e-6) else: nn.init.zeros_(module.bias) else: trunc_normal_(module.weight, std=.02) if module.bias is not None: nn.init.zeros_(module.bias) elif jax_impl and isinstance(module, nn.Conv2d): # NOTE conv was left to pytorch default in my original init lecun_normal_(module.weight) if module.bias is not None: nn.init.zeros_(module.bias) elif isinstance(module, (nn.LayerNorm, nn.GroupNorm, nn.BatchNorm2d)): nn.init.zeros_(module.bias) nn.init.ones_(module.weight) def resize_pos_embed(posemb, posemb_new, num_tokens=1, gs_new=(), mode='bicubic'): # Rescale the grid of position embeddings when loading from state_dict. 
Adapted from # https://github.com/google-research/vision_transformer/blob/00883dd691c63a6830751563748663526e811cee/vit_jax/checkpoint.py#L224 _logger.info('Resized position embedding: %s to %s with %s cls/dis tokens', posemb.shape, posemb_new.shape, num_tokens) ntok_new = posemb_new.shape[1] if num_tokens: posemb_tok, posemb_grid = posemb[:, :num_tokens], posemb[0, num_tokens:] ntok_new -= num_tokens else: posemb_tok, posemb_grid = posemb[:, :0], posemb[0] gs_old = int(math.sqrt(len(posemb_grid))) if not len(gs_new): # backwards compatibility gs_new = [int(math.sqrt(ntok_new))] * 2 assert len(gs_new) >= 2 _logger.info('Position embedding grid-size from %s to %s', [gs_old, gs_old], gs_new) posemb_grid = posemb_grid.reshape(1, gs_old, gs_old, -1).permute(0, 3, 1, 2) posemb_grid = F.interpolate(posemb_grid, size=gs_new, mode=mode, align_corners=False) posemb_grid = posemb_grid.permute(0, 2, 3, 1).reshape(1, gs_new[0] * gs_new[1], -1) posemb = torch.cat([posemb_tok, posemb_grid], dim=1) return posemb def adapt_image_pos_embed_to_passt(posemb, num_tokens=1, gs_new=(), mode='bicubic'): # Rescale the grid of position embeddings when loading from state_dict. 
Adapted from # https://github.com/google-research/vision_transformer/blob/00883dd691c63a6830751563748663526e811cee/vit_jax/checkpoint.py#L224 _logger.info('Resized position embedding: %s to %s with %s cls/dis tokens', posemb.shape, gs_new, num_tokens) if num_tokens: posemb_tok, posemb_grid = posemb[:, :num_tokens], posemb[0, num_tokens:] else: posemb_tok, posemb_grid = posemb[:, :0], posemb[0] gs_old = int(math.sqrt(len(posemb_grid))) assert len(gs_new) >= 2 _logger.info('Position embedding grid-size from %s to %s', [gs_old, gs_old], gs_new) posemb_grid = posemb_grid.reshape(1, gs_old, gs_old, -1).permute(0, 3, 1, 2) posemb_grid = F.interpolate(posemb_grid, size=gs_new, mode=mode, align_corners=False) freq_new_pos_embed = posemb_grid.mean(dim=3, keepdim=True) time_new_pos_embed = posemb_grid.mean(dim=2, keepdim=True) _logger.info('New Position cls/dstl embedding %s', posemb_tok.shape) _logger.info('New FREQ Position embedding %s', freq_new_pos_embed.shape) _logger.info('New TIME Position embedding %s', time_new_pos_embed.shape) return posemb_tok, freq_new_pos_embed, time_new_pos_embed def checkpoint_filter_fn(state_dict, model): """ convert patch embedding weight from manual patchify + linear proj to conv""" out_dict = {} if 'model' in state_dict: # For deit models state_dict = state_dict['model'] state_dict = {k: v for k, v in state_dict.items()} if "time_new_pos_embed" not in state_dict: # we are working with ImageNet model _logger.info("Adapting pos embedding from ImageNet pretrained model to PaSST.") v = state_dict.pop("pos_embed") new_pos_embed, freq_new_pos_embed, time_new_pos_embed = adapt_image_pos_embed_to_passt( v, getattr(model, 'num_tokens', 1), model.patch_embed.grid_size) state_dict["new_pos_embed"] = new_pos_embed state_dict["freq_new_pos_embed"] = freq_new_pos_embed state_dict["time_new_pos_embed"] = time_new_pos_embed for k, v in state_dict.items(): if 'patch_embed.proj.weight' in k and len(v.shape) < 4: # For old models that I trained prior to 
conv based patchification O, I, H, W = model.patch_embed.proj.weight.shape v = v.reshape(O, -1, H, W) elif k == 'pos_embed' and v.shape != model.pos_embed.shape: # this should never occur v = resize_pos_embed( v, model.pos_embed, getattr(model, 'num_tokens', 1), model.patch_embed.grid_size) out_dict[k] = v return out_dict def _create_vision_transformer(variant, pretrained=False, default_cfg=None, **kwargs): default_cfg = default_cfg or default_cfgs[variant] if kwargs.get('features_only', None): raise RuntimeError('features_only not implemented for Vision Transformer models.') # NOTE this extra code to support handling of repr size for in21k pretrained models default_num_classes = default_cfg['num_classes'] num_classes = kwargs.get('num_classes', default_num_classes) repr_size = kwargs.pop('representation_size', None) if repr_size is not None and num_classes != default_num_classes: # Remove representation layer if fine-tuning. This may not always be the desired action, # but I feel better than doing nothing by default for fine-tuning. Perhaps a better interface? _logger.warning("Removing representation layer for fine-tuning.") repr_size = None model = build_model_with_cfg( PaSST, variant, pretrained, default_cfg=default_cfg, representation_size=repr_size, pretrained_filter_fn=checkpoint_filter_fn, pretrained_custom_load='npz' in default_cfg['url'], **kwargs) return model def vit_huge_patch14_224_in21k(pretrained=False, **kwargs): """ ViT-Huge model (ViT-H/14) from original paper (https://arxiv.org/abs/2010.11929). ImageNet-21k weights @ 224x224, source https://github.com/google-research/vision_transformer. 
NOTE: this model has a representation layer but the 21k classifier head is zero'd out in original weights """ model_kwargs = dict( patch_size=14, embed_dim=1280, depth=32, num_heads=16, representation_size=1280, **kwargs) model = _create_vision_transformer('vit_huge_patch14_224_in21k', pretrained=pretrained, **model_kwargs) return model def deit_base_distilled_patch16_384(pretrained=False, **kwargs): """ DeiT-base distilled model @ 384x384 from paper (https://arxiv.org/abs/2012.12877). ImageNet-1k weights from https://github.com/facebookresearch/deit. """ print("\n\n Loading DEIT BASE 384\n\n") model_kwargs = dict(patch_size=16, embed_dim=768, depth=12, num_heads=12, **kwargs) model = _create_vision_transformer( 'deit_base_distilled_patch16_384', pretrained=pretrained, distilled=True, **model_kwargs) return model def passt_s_swa_p16_128_ap476(pretrained=False, **kwargs): """ PaSST pre-trained on AudioSet """ print("\n\n Loading PaSST pre-trained on AudioSet Patch 16 stride 10 structured patchout mAP=476 SWA \n\n") model_kwargs = dict(patch_size=16, embed_dim=768, depth=12, num_heads=12, **kwargs) if model_kwargs.get("stride") != (10, 10): warnings.warn( f"This model was pre-trained with strides {(10, 10)}, but now you set (fstride,tstride) to {model_kwargs.get('stride')}.") model = _create_vision_transformer( 'passt_s_swa_p16_128_ap476', pretrained=pretrained, distilled=True, **model_kwargs) return model def passt_s_kd_p16_128_ap486(pretrained=False, **kwargs): """ PaSST pre-trained on AudioSet """ print("\n\n Loading PaSST pre-trained on AudioSet (with KD) Patch 16 stride 10 structured patchout mAP=486 \n\n") model_kwargs = dict(patch_size=16, embed_dim=768, depth=12, num_heads=12, **kwargs) if model_kwargs.get("stride") != (10, 10): warnings.warn( f"This model was pre-trained with strides {(10, 10)}, but now you set (fstride,tstride) to {model_kwargs.get('stride')}.") model = _create_vision_transformer( 'passt_s_kd_p16_128_ap486', pretrained=pretrained, 
distilled=True, **model_kwargs) return model def passt_l_kd_p16_128_ap47(pretrained=False, **kwargs): """ PaSST pre-trained on AudioSet """ print("\n\n Loading PaSST-L (light, reduced depth=7) pre-trained on AudioSet (with KD) Patch 16 stride 10 structured patchout mAP=4708 \n\n") model_kwargs = dict(patch_size=16, embed_dim=768, depth=7, num_heads=12, **kwargs) if model_kwargs.get("stride") != (10, 10): warnings.warn( f"This model was pre-trained with strides {(10, 10)}, but now you set (fstride,tstride) to {model_kwargs.get('stride')}.") model = _create_vision_transformer( 'passt_l_kd_p16_128_ap47', pretrained=pretrained, distilled=True, **model_kwargs) return model def passt_s_swa_p16_128_ap4761(pretrained=False, **kwargs): """ PaSST pre-trained on AudioSet """ print("\n\n Loading PaSST pre-trained on AudioSet Patch 16 stride 10 structured patchout mAP=4763 SWA \n\n") model_kwargs = dict(patch_size=16, embed_dim=768, depth=12, num_heads=12, **kwargs) if model_kwargs.get("stride") != (10, 10): warnings.warn( f"This model was pre-trained with strides {(10, 10)}, but now you set (fstride,tstride) to {model_kwargs.get('stride')}.") model = _create_vision_transformer( 'passt_s_swa_p16_128_ap4761', pretrained=pretrained, distilled=True, **model_kwargs) return model def passt_s_p16_128_ap472(pretrained=False, **kwargs): """ PaSST pre-trained on AudioSet """ print("\n\n Loading PaSST pre-trained on AudioSet Patch 16 stride 10 structured patchout mAP=472 \n\n") model_kwargs = dict(patch_size=16, embed_dim=768, depth=12, num_heads=12, **kwargs) if model_kwargs.get("stride") != (10, 10): warnings.warn( f"This model was pre-trained with strides {(10, 10)}, but now you set (fstride,tstride) to {model_kwargs.get('stride')}.") model = _create_vision_transformer( 'passt_s_p16_128_ap472', pretrained=pretrained, distilled=True, **model_kwargs) return model def passt_s_p16_s12_128_ap470(pretrained=False, **kwargs): """ PaSST pre-trained on AudioSet """ print("\n\n Loading PaSST 
pre-trained on AudioSet Patch 16 stride 12 structured patchout mAP=470 \n\n")
    model_kwargs = dict(patch_size=16, embed_dim=768, depth=12, num_heads=12, **kwargs)
    if model_kwargs.get("stride") != (12, 12):
        warnings.warn(
            f"This model was pre-trained with strides {(12, 12)}, but now you set (fstride,tstride) to {model_kwargs.get('stride')}.")
    model = _create_vision_transformer(
        'passt_s_p16_s12_128_ap470', pretrained=pretrained, distilled=True, **model_kwargs)
    return model


def passt_s_f128_20sec_p16_s10_ap474_swa(pretrained=False, **kwargs):
    print("\n\n Loading PaSST trained on AudioSet with 20 second time encodings, with STFT hop of 160 \n\n")
    model_kwargs = dict(patch_size=16, embed_dim=768, depth=12, num_heads=12, **kwargs)
    model = _create_vision_transformer(
        'passt-s-f128-20sec-p16-s10-ap474-swa', pretrained=pretrained, distilled=True, **model_kwargs)
    return model


def passt_s_f128_30sec_p16_s10_ap473_swa(pretrained=False, **kwargs):
    print("\n\n Loading PaSST trained on AudioSet with 30 second time encodings, with STFT hop of 160 \n\n")
    model_kwargs = dict(patch_size=16, embed_dim=768, depth=12, num_heads=12, **kwargs)
    model = _create_vision_transformer(
        'passt-s-f128-30sec-p16-s10-ap473-swa', pretrained=pretrained, distilled=True, **model_kwargs)
    return model


def passt_s_swa_p16_s12_128_ap473(pretrained=False, **kwargs):
    """ PaSST pre-trained on AudioSet """
    print("\n\n Loading PaSST pre-trained on AudioSet Patch 16 stride 12 structured patchout mAP=473 \n\n")
    model_kwargs = dict(patch_size=16, embed_dim=768, depth=12, num_heads=12, **kwargs)
    if model_kwargs.get("stride") != (12, 12):
        warnings.warn(
            f"This model was pre-trained with strides {(12, 12)}, but now you set (fstride,tstride) to {model_kwargs.get('stride')}.")
    model = _create_vision_transformer(
        'passt_s_swa_p16_s12_128_ap473', pretrained=pretrained, distilled=True, **model_kwargs)
    return model


def passt_s_p16_s14_128_ap469(pretrained=False, **kwargs):
    """ PaSST pre-trained on AudioSet """
    print("\n\n Loading PaSST pre-trained on AudioSet Patch 16 stride 14 structured patchout mAP=469 \n\n")
    model_kwargs = dict(patch_size=16, embed_dim=768, depth=12, num_heads=12, **kwargs)
    if model_kwargs.get("stride") != (14, 14):
        warnings.warn(
            f"This model was pre-trained with strides {(14, 14)}, but now you set (fstride,tstride) to {model_kwargs.get('stride')}.")
    model = _create_vision_transformer(
        'passt_s_p16_s14_128_ap469', pretrained=pretrained, distilled=True, **model_kwargs)
    return model


def passt_s_swa_p16_s14_128_ap471(pretrained=False, **kwargs):
    """ PaSST pre-trained on AudioSet """
    print("\n\n Loading PaSST pre-trained on AudioSet Patch 16 stride 14 structured patchout mAP=471 \n\n")
    model_kwargs = dict(patch_size=16, embed_dim=768, depth=12, num_heads=12, **kwargs)
    if model_kwargs.get("stride") != (14, 14):
        warnings.warn(
            f"This model was pre-trained with strides {(14, 14)}, but now you set (fstride,tstride) to {model_kwargs.get('stride')}.")
    model = _create_vision_transformer(
        'passt_s_swa_p16_s14_128_ap471', pretrained=pretrained, distilled=True, **model_kwargs)
    return model


def passt_s_swa_p16_s16_128_ap473(pretrained=False, **kwargs):
    """ PaSST pre-trained on AudioSet """
    print("\n\n Loading PaSST pre-trained on AudioSet Patch 16 stride 16 structured patchout mAP=473 \n\n")
    model_kwargs = dict(patch_size=16, embed_dim=768, depth=12, num_heads=12, **kwargs)
    if model_kwargs.get("stride") != (16, 16):
        warnings.warn(
            f"This model was pre-trained with strides {(16, 16)}, but now you set (fstride,tstride) to {model_kwargs.get('stride')}.")
    model = _create_vision_transformer(
        'passt_s_swa_p16_s16_128_ap473', pretrained=pretrained, distilled=True, **model_kwargs)
    return model


def passt_s_p16_s16_128_ap468(pretrained=False, **kwargs):
    """ PaSST pre-trained on AudioSet """
    print("\n\n Loading PaSST pre-trained on AudioSet Patch 16 stride 16 structured patchout mAP=468 \n\n")
    model_kwargs = dict(patch_size=16, embed_dim=768, depth=12, num_heads=12, **kwargs)
    if model_kwargs.get("stride") != (16, 16):
        warnings.warn(
            f"This model was pre-trained with strides {(16, 16)}, but now you set (fstride,tstride) to {model_kwargs.get('stride')}.")
    model = _create_vision_transformer(
        'passt_s_p16_s16_128_ap468', pretrained=pretrained, distilled=True, **model_kwargs)
    return model


from ba3l.ingredients.ingredient import Ingredient

model_ing = Ingredient("passt")
model_ing.add_config(instance_cmd="get_model")


@model_ing.command
def fix_embedding_layer(model, embed="default"):
    if embed == "default":
        return model
    if embed == "overlap":
        model.patch_embed = PatchEmbedAdaptiveMean(replace=model.patch_embed)
    if embed == "am_keepconv":
        model.patch_embed = PatchEmbedAdaptiveMeanKeepConv(replace=model.patch_embed)
    return model


@model_ing.command
def lighten_model(model, cut_depth=0):
    if cut_depth == 0:
        return model
    if cut_depth:
        if cut_depth < 0:
            print(f"\n Reducing model depth by removing every {-cut_depth} layer \n\n")
        else:
            print(f"\n Reducing model depth by {cut_depth} \n\n")
        if len(model.blocks) < cut_depth + 2:
            raise ValueError(f"Cut depth for a ViT with {len(model.blocks)} "
                             f"layers should be between 1 and {len(model.blocks) - 2}")
        print(f"\n Before Cutting it was {len(model.blocks)} \n\n")
        old_blocks = list(model.blocks.children())
        if cut_depth < 0:
            print(f"cut_depth={cut_depth}")
            old_blocks = [old_blocks[0]] + old_blocks[1:-1:-cut_depth] + [old_blocks[-1]]
        else:
            old_blocks = [old_blocks[0]] + old_blocks[cut_depth + 1:]
        model.blocks = nn.Sequential(*old_blocks)
        print(f"\n After Cutting it is {len(model.blocks)} \n\n")
    return model


@model_ing.command
def get_model(arch="passt_s_kd_p16_128_ap486", pretrained=True, n_classes=527, in_channels=1,
              fstride=10, tstride=10, input_fdim=128, input_tdim=998,
              u_patchout=0, s_patchout_t=0, s_patchout_f=0,
              ):
    """
    :param arch: Base ViT or DeiT architecture
    :param pretrained: use pretrained model on ImageNet
    :param n_classes: number of classes
    :param in_channels: number of input channels: 1 for mono
    :param
fstride: the patches stride over frequency. :param tstride: the patches stride over time. :param input_fdim: the expected input frequency bins. :param input_tdim: the expected input time bins. :param u_patchout: number of input patches to drop in Unstructured Patchout as defined in https://arxiv.org/abs/2110.05069 :param s_patchout_t: number of input time frames to drop Structured Patchout as defined in https://arxiv.org/abs/2110.05069 :param s_patchout_f: number of input frequency bins to drop Structured Patchout as defined in https://arxiv.org/abs/2110.05069 :param audioset_pretrain: use pretrained models on Audioset. :return: """ model_func = None input_size = (input_fdim, input_tdim) stride = (fstride, tstride) if arch == "passt_deit_bd_p16_384": # base deit model_func = deit_base_distilled_patch16_384 elif arch == "passt_s_kd_p16_128_ap486": # pretrained model_func = passt_s_kd_p16_128_ap486 elif arch == "passt_l_kd_p16_128_ap47": # pretrained passt-L model_func = passt_l_kd_p16_128_ap47 elif arch == "passt_s_swa_p16_128_ap476": # pretrained model_func = passt_s_swa_p16_128_ap476 elif arch == "passt_s_swa_p16_128_ap4761": model_func = passt_s_swa_p16_128_ap4761 elif arch == "passt_s_p16_128_ap472": model_func = passt_s_p16_128_ap472 elif arch == "passt_s_p16_s16_128_ap468": model_func = passt_s_p16_s16_128_ap468 elif arch == "passt_s_swa_p16_s16_128_ap473": model_func = passt_s_swa_p16_s16_128_ap473 elif arch == "passt_s_swa_p16_s14_128_ap471": model_func = passt_s_swa_p16_s14_128_ap471 elif arch == "passt_s_p16_s14_128_ap469": model_func = passt_s_p16_s14_128_ap469 elif arch == "passt_s_swa_p16_s12_128_ap473": model_func = passt_s_swa_p16_s12_128_ap473 elif arch == "passt_s_p16_s12_128_ap470": model_func = passt_s_p16_s12_128_ap470 elif arch == "passt_s_f128_20sec_p16_s10_ap474": model_func = passt_s_f128_20sec_p16_s10_ap474_swa elif arch == "passt_s_f128_30sec_p16_s10_ap473": model_func = passt_s_f128_30sec_p16_s10_ap473_swa if model_func is None: raise 
RuntimeError(f"Unknown model {arch}") model = model_func(pretrained=pretrained, num_classes=n_classes, in_chans=in_channels, img_size=input_size, stride=stride, u_patchout=u_patchout, s_patchout_t=s_patchout_t, s_patchout_f=s_patchout_f) model = fix_embedding_layer(model) model = lighten_model(model) print(model) return model class EnsembelerModel(nn.Module): def __init__(self, models): super(EnsembelerModel, self).__init__() self.models = nn.ModuleList(models) def forward(self, x): # ModuleList can act as an iterable, or be indexed using ints all_out = None for i, m in enumerate(self.models): out, _ = m(x) if all_out is None: all_out = out else: all_out = out + all_out all_out = all_out / len(self.models) return all_out, all_out @model_ing.command def get_ensemble_model(arch_list=[]): # arch_list = [(passt_s_swa_p16_128_ap476,fstride,tstride)] models_list = [get_model(arch=arch, fstride=fstride, tstride=tstride) for arch, fstride, tstride in arch_list] model = EnsembelerModel(models_list) print(model) return model ================================================ FILE: models/preprocess.py ================================================ import torch.nn as nn import torchaudio from torch.nn.functional import conv1d, conv2d import torch from ba3l.ingredients.ingredient import Ingredient model_ing = Ingredient("spectrograms") sz_float = 4 # size of a float epsilon = 10e-8 # fudge factor for normalization @model_ing.command class AugmentMelSTFT(nn.Module): def __init__(self, n_mels=128, sr=32000, win_length=800, hopsize=320, n_fft=1024, freqm=48, timem=192, htk=False, fmin=0.0, fmax=None, norm=1, fmin_aug_range=1, fmax_aug_range=1000): torch.nn.Module.__init__(self) # adapted from: https://github.com/CPJKU/kagglebirds2020/commit/70f8308b39011b09d41eb0f4ace5aa7d2b0e806e # Similar config to the spectrograms used in AST: https://github.com/YuanGongND/ast self.win_length = win_length self.n_mels = n_mels self.n_fft = n_fft self.sr = sr self.htk = htk self.fmin = fmin if 
fmax is None:
            fmax = sr // 2 - fmax_aug_range // 2
            print(f"Warning: FMAX is None, setting to {fmax} ")
        self.fmax = fmax
        self.norm = norm
        self.hopsize = hopsize
        self.register_buffer('window', torch.hann_window(win_length, periodic=False), persistent=False)
        assert fmin_aug_range >= 1, f"fmin_aug_range={fmin_aug_range} should be >=1; 1 means no augmentation"
        assert fmax_aug_range >= 1, f"fmax_aug_range={fmax_aug_range} should be >=1; 1 means no augmentation"
        self.fmin_aug_range = fmin_aug_range
        self.fmax_aug_range = fmax_aug_range
        self.register_buffer("preemphasis_coefficient", torch.as_tensor([[[-.97, 1]]]), persistent=False)
        if freqm == 0:
            self.freqm = torch.nn.Identity()
        else:
            self.freqm = torchaudio.transforms.FrequencyMasking(freqm, iid_masks=True)
        if timem == 0:
            self.timem = torch.nn.Identity()
        else:
            self.timem = torchaudio.transforms.TimeMasking(timem, iid_masks=True)

    def forward(self, x):
        x = nn.functional.conv1d(x.unsqueeze(1), self.preemphasis_coefficient).squeeze(1)
        x = torch.stft(x, self.n_fft, hop_length=self.hopsize, win_length=self.win_length,
                       center=True, normalized=False, window=self.window, return_complex=False)
        x = (x ** 2).sum(dim=-1)  # power mag
        fmin = self.fmin + torch.randint(self.fmin_aug_range, (1,)).item()
        fmax = self.fmax + self.fmax_aug_range // 2 - torch.randint(self.fmax_aug_range, (1,)).item()
        # don't augment eval data
        if not self.training:
            fmin = self.fmin
            fmax = self.fmax
        mel_basis, _ = torchaudio.compliance.kaldi.get_mel_banks(self.n_mels, self.n_fft, self.sr,
                                                                 fmin, fmax, vtln_low=100.0,
                                                                 vtln_high=-500., vtln_warp_factor=1.0)
        mel_basis = torch.as_tensor(torch.nn.functional.pad(mel_basis, (0, 1), mode='constant', value=0),
                                    device=x.device)
        with torch.cuda.amp.autocast(enabled=False):
            melspec = torch.matmul(mel_basis, x)
        melspec = (melspec + 0.00001).log()
        if self.training:
            melspec = self.freqm(melspec)
            melspec = self.timem(melspec)
        melspec = (melspec + 4.5) / 5.
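The fmin/fmax jitter used during training is worth isolating: each forward pass draws a slightly different mel filterbank range, while eval keeps the fixed range. A small sketch of just that sampling logic (the default values below are illustrative, not the experiment config):

```python
import torch

def sample_mel_range(fmin=0.0, fmax=16000.0, fmin_aug_range=10, fmax_aug_range=2000, training=True):
    """Mirror the fmin/fmax jitter applied in training; eval returns the fixed range."""
    if not training:
        return fmin, fmax
    jittered_fmin = fmin + torch.randint(fmin_aug_range, (1,)).item()
    jittered_fmax = fmax + fmax_aug_range // 2 - torch.randint(fmax_aug_range, (1,)).item()
    return jittered_fmin, jittered_fmax

lo, hi = sample_mel_range()
print(lo, hi)  # e.g. fmin in [0, 9], fmax in [15001, 17000]
```

With these assumed defaults the upper edge is jittered symmetrically around `fmax + fmax_aug_range // 2 - fmax_aug_range / 2`, i.e. roughly centered on the nominal `fmax`.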
# fast normalization
        return melspec

    def extra_repr(self):
        return 'winsize={}, hopsize={}'.format(self.win_length, self.hopsize)



================================================ FILE: openmic/README.md ================================================
# Experiments on OpenMIC-2018

[OpenMIC-2018](https://github.com/cosmir/openmic-2018) ([zenodo](https://zenodo.org/record/1432913#.W6dPeJNKjOR)) is a dataset for polyphonic instrument identification.

## Preparing the dataset

Use `openmic/prepare_scripts/download_preprocess.py` to download and pack the dataset:

```shell
cd openmic/prepare_scripts/
python download_preprocess.py
```

When the script completes, you should have two files inside `audioset_hdf5s/mp3/`: `openmic_train.csv_mp3.hdf` and `openmic_test.csv_mp3.hdf`. These files contain the mp3s of the dataset and the labels.

## Fine-tuning pretrained PaSST on OpenMIC-2018

As with AudioSet, you can use:

```shell
# Example call with all the default config:
python ex_openmic.py with trainer.precision=16 -p
```

```shell
# with 2 gpus:
DDP=2 python ex_openmic.py with trainer.precision=16 -p
```

================================================ FILE: openmic/dataset.py ================================================
import io
import os
import pathlib
import random

import av
import librosa
import torchaudio
from torch.utils.data import Dataset as TorchDataset, ConcatDataset, DistributedSampler, WeightedRandomSampler
import torch
from ba3l.ingredients.datasets import Dataset
import pandas as pd
from sacred.config import DynamicIngredient, CMD
from scipy.signal import convolve
from sklearn import preprocessing
import numpy as np
import h5py

from helpers.audiodatasets import PreprocessDataset

LMODE = os.environ.get("LMODE", False)

dataset = Dataset('openMIC')


@dataset.config
def default_config():
    name = 'openmic2008'  # dataset name
    normalize = False  # normalize dataset
    subsample = False  # subsample squares from the dataset
    roll = True  # apply roll augmentation
    fold = 1
    base_dir = "audioset_hdf5s/"  # base directory of the dataset as downloaded
    if LMODE:
        base_dir = "/system/user/publicdata/CP/audioset/audioset_hdf5s/"
    openmic_train_hdf5 = base_dir + "mp3/openmic_train.csv_mp3.hdf"
    openmic_test_hdf5 = base_dir + "mp3/openmic_test.csv_mp3.hdf"
    ir_path = base_dir + "irs/"
    num_of_classes = 20


def decode_mp3(mp3_arr):
    """
    decodes an array of uint8 representing an mp3 file
    :rtype: np.array
    """
    container = av.open(io.BytesIO(mp3_arr.tobytes()))
    stream = next(s for s in container.streams if s.type == 'audio')
    # print(stream)
    a = []
    for i, packet in enumerate(container.demux(stream)):
        for frame in packet.decode():
            a.append(frame.to_ndarray().reshape(-1))
    waveform = np.concatenate(a)
    if waveform.dtype != 'float32':
        raise RuntimeError("Unexpected wave type")
    return waveform


def pad_or_truncate(x, audio_length):
    """Pad all audio to specific length."""
    if len(x) <= audio_length:
        return np.concatenate((x, np.zeros(audio_length - len(x), dtype=np.float32)), axis=0)
    else:
        return x[0: audio_length]


irs_arr = None


@dataset.command
def get_ir_sample(ir_path, _run, ir_augment, cut_irs_offset=None):
    if not ir_augment:
        return
    global irs_arr
    if irs_arr is None:
        all_paths = [path for path in pathlib.Path(os.path.expanduser(ir_path)).rglob('*.wav')]
        all_paths = sorted(all_paths)
        if cut_irs_offset is not None:
            all_paths = all_paths[cut_irs_offset:cut_irs_offset + 10]
        all_paths_name = [str(p).rsplit("/", 1)[-1] for p in all_paths]
        print("will use these IRs:")
        for i in range(len(all_paths_name)):
            print(i, ": ", all_paths_name[i])
        _run.info["ir_devices"] = all_paths_name
        irs_arr = [librosa.load(p, sr=32000)[0] for p in all_paths]
    return irs_arr[int(np.random.randint(0, len(irs_arr)))]


@dataset.command
def pydub_augment(waveform, gain_augment=7, ir_augment=0):
    if ir_augment and torch.rand(1) < ir_augment:
        ir = get_ir_sample()
        waveform = convolve(waveform, ir, 'full')
    if gain_augment:
        gain = torch.randint(gain_augment * 2, (1,)).item() - gain_augment
        amp = 10 ** (gain / 20)
        waveform = waveform * amp
    return waveform


class MixupDataset(TorchDataset):
    """ Mixing up waveforms """

    def __init__(self, dataset, beta=2, rate=0.5):
        self.beta = beta
        self.rate = rate
        self.dataset = dataset
        print(f"Mixing up waveforms from dataset of len {len(dataset)}")

    def __getitem__(self, index):
        x1, f1, y1 = self.dataset[index]
        y1 = torch.as_tensor(y1)
        if torch.rand(1) < self.rate:
            idx2 = torch.randint(len(self.dataset), (1,)).item()
            x2, f2, y2 = self.dataset[idx2]
            y2 = torch.as_tensor(y2)
            l = np.random.beta(self.beta, self.beta)
            l = max(l, 1. - l)
            x1 = x1 - x1.mean()
            x2 = x2 - x2.mean()
            x = (x1 * l + x2 * (1. - l))
            x = x - x.mean()
            assert len(y1) == 40, "this only works for OpenMIC (20 labels + 20 masks)"
            y_mask1 = (torch.as_tensor(y1[20:]) > 0.5).float()
            y_mask2 = (torch.as_tensor(y2[20:]) > 0.5).float()
            y1[:20] *= y_mask1
            y2[:20] *= y_mask2
            yres = (y1 * l + y2 * (1. - l))
            yres[20:] = torch.stack([y_mask1, y_mask2]).max(dim=0).values
            return x, f1, yres
        return x1, f1, y1

    def __len__(self):
        return len(self.dataset)


class AudioSetDataset(TorchDataset):
    def __init__(self, hdf5_file, sample_rate=32000, classes_num=527, clip_length=10,
                 augment=False, in_mem=False):
        """
        Reads the mp3 bytes from the HDF file, decodes them using av and returns a fixed-length audio wav
        """
        self.sample_rate = sample_rate
        self.hdf5_file = hdf5_file
        if in_mem:
            print("\nPreloading in memory\n")
            with open(hdf5_file, 'rb') as f:
                self.hdf5_file = io.BytesIO(f.read())
        with h5py.File(hdf5_file, 'r') as f:
            self.length = len(f['audio_name'])
            print(f"Dataset from {hdf5_file} with length {self.length}.")
        self.dataset_file = None  # lazy init
        self.clip_length = clip_length * sample_rate
        self.classes_num = classes_num
        self.augment = augment
        if augment:
            print(f"Will augment data from {hdf5_file}")

    def open_hdf5(self):
        self.dataset_file = h5py.File(self.hdf5_file, 'r')

    def __len__(self):
        return self.length

    def __del__(self):
        if self.dataset_file is
not None: self.dataset_file.close() self.dataset_file = None def __getitem__(self, index): """Load waveform and target of an audio clip. Args: meta: { 'hdf5_path': str, 'index_in_hdf5': int} Returns: data_dict: { 'audio_name': str, 'waveform': (clip_samples,), 'target': (classes_num,)} """ if self.dataset_file is None: self.open_hdf5() audio_name = self.dataset_file['audio_name'][index].decode() waveform = decode_mp3(self.dataset_file['mp3'][index]) if self.augment: waveform = pydub_augment(waveform) waveform = pad_or_truncate(waveform, self.clip_length) waveform = self.resample(waveform) target = self.dataset_file['target'][index] target = target.astype(np.float32) return waveform.reshape(1, -1), audio_name, target def resample(self, waveform): """Resample. Args: waveform: (clip_samples,) Returns: (resampled_clip_samples,) """ if self.sample_rate == 32000: return waveform elif self.sample_rate == 16000: return waveform[0:: 2] elif self.sample_rate == 8000: return waveform[0:: 4] else: raise Exception('Incorrect sample rate!') @dataset.command def get_base_training_set(openmic_train_hdf5): ds = AudioSetDataset(openmic_train_hdf5, augment=True) return ds @dataset.command def get_ft_weighted_sampler(samples_weights=CMD(".get_ft_cls_balanced_sample_weights"), epoch_len=100000, sampler_replace=False): num_nodes = int(os.environ.get('num_nodes', 1)) ddp = int(os.environ.get('DDP', 1)) num_nodes = max(ddp, num_nodes) print("num_nodes= ", num_nodes) rank = int(os.environ.get('NODE_RANK', 0)) return DistributedSamplerWrapper(sampler=WeightedRandomSampler(samples_weights, num_samples=epoch_len, replacement=sampler_replace), dataset=range(epoch_len), num_replicas=num_nodes, rank=rank, ) @dataset.command def get_base_test_set(openmic_test_hdf5): ds = AudioSetDataset(openmic_test_hdf5) return ds @dataset.command(prefix='roll_conf') def get_roll_func(axis=1, shift=None, shift_range=50): print("rolling...") def roll_func(b): x, i, y = b x = torch.as_tensor(x) sf = shift if shift 
is None:
            sf = int(np.random.randint(-shift_range, shift_range + 1))  # random_integers is deprecated; randint's high is exclusive
        return x.roll(sf, axis), i, y

    return roll_func


@dataset.command
def get_training_set(normalize, roll, wavmix=False):
    ds = get_base_training_set()
    get_ir_sample()
    if normalize:
        print("normalized train!")
        fill_norms()
        ds = PreprocessDataset(ds, norm_func)
    if roll:
        ds = PreprocessDataset(ds, get_roll_func())
    if wavmix:
        ds = MixupDataset(ds)
    return ds


@dataset.command
def get_test_set(normalize):
    ds = get_base_test_set()
    if normalize:
        print("normalized test!")
        fill_norms()
        ds = PreprocessDataset(ds, norm_func)
    return ds


@dataset.command
def print_conf(_config):
    print("Config of ", dataset.path, id(dataset))
    print(_config)
    print()


class DistributedSamplerWrapper(DistributedSampler):
    def __init__(self, sampler, dataset, num_replicas=None, rank=None, shuffle: bool = True):
        super(DistributedSamplerWrapper, self).__init__(dataset, num_replicas, rank, shuffle)
        # source: @awaelchli https://github.com/PyTorchLightning/pytorch-lightning/issues/3238
        self.sampler = sampler

    def __iter__(self):
        if self.sampler.generator is None:
            self.sampler.generator = torch.Generator()
            self.sampler.generator.manual_seed(self.seed + self.epoch)
        indices = list(self.sampler)
        if self.epoch == 0:
            print(f"\n DistributedSamplerWrapper : {indices[:10]} \n\n")
        indices = indices[self.rank:self.total_size:self.num_replicas]
        return iter(indices)


if __name__ == "__main__":
    from sacred import Experiment

    ex = Experiment("test_dataset", ingredients=[dataset])

    @ex.automain
    def default_command():
        ex.current_run.get_command_function("print_config")()
        get_base_training_set()
        ds = get_test_set()
        print(ds[0])
        ds = get_training_set()
        print(ds[0])
        print("get_base_training_set", len(get_base_training_set()))
        print("get_base_test_set", len(get_base_test_set()))
        print("get_training_set", len(get_training_set()))
        print("get_test_set", len(get_test_set()))

================================================ FILE: 
openmic/prepare_scripts/download_preprocess.py ================================================ import os import tarfile import multiprocessing import glob import h5py import numpy as np from torch.hub import download_url_to_file # global constants openmicurl = "https://zenodo.org/record/1432913/files/openmic-2018-v1.0.0.tgz?download=1" download_target = "openmic-2018-v1.0.0.tgz" extract_target = download_target.replace(".tgz", "") dataset_path = os.path.join(extract_target, "openmic-2018/") mp3_path = os.path.join(dataset_path, "mp3/") hdf5s_dir = "../../audioset_hdf5s/" train_files_csv = os.path.join(dataset_path, "partitions/split01_train.csv") test_files_csv = os.path.join(dataset_path, "partitions/split01_test.csv") def download(force=False): if force or not os.path.isfile(download_target): print("Downloading OpenMIC from zenodo...") download_url_to_file(openmicurl, download_target) else: print(f"{download_target} already exists. Skipping download!") def untar(): my_tar = tarfile.open(download_target) print(f"Extracting openmic from {download_target} to {extract_target}") my_tar.extractall(extract_target) def process_folder(fol="balanced_train_segments"): print("now working on ", fol) os.makedirs(mp3_path + fol, exist_ok=True) all_files = list(glob.glob(os.path.join(dataset_path, "audio/") + "/*/*.ogg")) # openmic format print(f"it has {len(all_files)}") print(all_files[:5]) global all_num all_num = len(all_files) cmds = [(i, file, mp3_path + fol + "/" + os.path.basename(file)[:-3]) for i, file in enumerate(all_files)] print(cmds[0]) with multiprocessing.Pool(processes=20) as pool: pool.starmap(process_one, cmds) def process_one(i, f1, f2): if i % 100 == 0: print(f"{i}/{all_num} \t", f1) os.system(f"ffmpeg -hide_banner -nostats -loglevel error -n -i {f1} -codec:a mp3 -ar 32000 {f2}mp3") def make_mp3(): process_folder("audio") def read_metadata(csv_path, classes_num, id_to_ix, openmicf): """Read metadata of AudioSet from a csv file. 
    source: https://github.com/qiuqiangkong/audioset_tagging_cnn/blob/d2f4b8c18eab44737fcc0de1248ae21eb43f6aa4/utils/utilities.py#L59

    Args:
        csv_path: str

    Returns:
        meta_dict: {'audio_name': (audios_num,), 'target': (audios_num, classes_num * 2)}
    """
    with open(csv_path) as file:
        lines = file.readlines()
    lines = [line.rstrip() for line in lines]

    audios_num = len(lines)
    # each target row holds the class scores followed by the observation mask
    targets = np.zeros((audios_num, classes_num * 2), dtype=np.float32)
    audio_names = []

    for i, row in enumerate(lines):
        audio_name = '{}.mp3'.format(row)
        audio_names.append(audio_name)
        t = openmicf.Y_true[id_to_ix[row]]
        m = openmicf.Y_mask[id_to_ix[row]].astype(int)
        targets[i, :classes_num] = t
        targets[i, classes_num:] = m

    print("original_targets", len(targets))
    # keep only clips with at least one observed class
    # (np.int was removed in NumPy >= 1.24; plain int is the replacement)
    mask = targets.astype(int).sum(1) > 0
    print(len(mask), mask.sum())
    print("after: ", len(targets[mask]))
    meta_dict = {'audio_name': np.array(audio_names)[mask], 'target': targets[mask]}
    return meta_dict


def get_files_labels(balanced_csv, balanced_audio_path, d_files, openmicf,
                     prefix=None, zip_contents=None, classes_num=20):
    meta_csv = read_metadata(balanced_csv, classes_num, d_files, openmicf)
    audios_num = len(meta_csv['audio_name'])
    found = 0
    notfound = 0
    available_files = []
    available_targets = []
    for n in range(audios_num):
        audio_path = meta_csv['audio_name'][n]
        if n == 0:
            print("checking: ", balanced_audio_path + f"{prefix}/{audio_path}")
        if os.path.isfile(balanced_audio_path + f"{prefix}/{audio_path}"):
            found += 1
            available_files.append(meta_csv['audio_name'][n])
            available_targets.append(meta_csv['target'][n])
        else:
            notfound += 1
    print(f"Found {found}, not found {notfound}")
    return available_files, available_targets


def pack():
    d_files = dict()
    opmic = np.load(os.path.join(dataset_path, "openmic-2018.npz"))
    opmic.allow_pickle = True
    for i, sid in enumerate(opmic.f.sample_key):
        d_files[sid] = i
    print("len=", len(d_files))

    for read_file, prefix in [(train_files_csv, "audio/"), (test_files_csv, "audio/")]:
        print("now working on ", read_file, prefix)
        files, y = get_files_labels(read_file, mp3_path, d_files=d_files,
                                    openmicf=opmic.f, prefix=prefix)
        y = np.array(y)
        packed_len = y.shape[1]
        print(files[0], "classes: ", packed_len, y.dtype)
        available_size = len(files)
        f = files[0]
        a = np.fromfile(mp3_path + prefix + "/" + f, dtype='uint8')
        # variable-length uint8: each entry stores the raw mp3 bytes of one clip
        dt = h5py.vlen_dtype(np.dtype('uint8'))
        save_file = read_file.rsplit("/", 1)[1].replace("split01", "openmic")
        os.makedirs(hdf5s_dir + "mp3/", exist_ok=True)
        if os.path.isfile(hdf5s_dir + "mp3/" + save_file + "_mp3.hdf"):
            print(hdf5s_dir + "mp3/" + save_file + "_mp3.hdf", "exists!\n\n\n continue")
            continue
        with h5py.File(hdf5s_dir + "mp3/" + save_file + "_mp3.hdf", 'w') as hf:
            audio_name = hf.create_dataset('audio_name', shape=((available_size,)), dtype='S20')
            waveform = hf.create_dataset('mp3', shape=((available_size,)), dtype=dt)
            target = hf.create_dataset('target', shape=((available_size, packed_len)), dtype=y.dtype)
            for i, file in enumerate(files):
                if i % 1000 == 0:
                    print(f"{i}/{available_size}")
                f = file
                a = np.fromfile(mp3_path + prefix + f, dtype='uint8')
                audio_name[i] = f
                waveform[i] = a
                target[i] = y[i]
        print("Saved h5py file into ", hdf5s_dir + "mp3/" + save_file + "_mp3.hdf")
        print(a.shape)
    print("Done!", prefix)


def preprocess():
    download()
    untar()
    make_mp3()
    pack()


# the main guard is required for multiprocessing on spawn-based platforms
if __name__ == "__main__":
    preprocess()


================================================
FILE: pip_list.txt
================================================
Package                 Version
----------------------- ------------
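A hedged sketch (not code from this repository) of how a downstream consumer might use the packed targets written above: each row produced by `read_metadata` holds `classes_num` label scores (`Y_true`) followed by `classes_num` observation flags (`Y_mask`), and a masked loss should only average over the classes that were actually annotated. The values below are toy numbers.

```python
import numpy as np

classes_num = 20
packed = np.zeros(classes_num * 2, dtype=np.float32)
packed[:classes_num] = 0.9       # toy label scores (first half: Y_true)
packed[classes_num + 0] = 1.0    # pretend only classes 0 and 3 were annotated
packed[classes_num + 3] = 1.0    # (second half: Y_mask observation flags)

y_true = packed[:classes_num]
y_mask = packed[classes_num:].astype(bool)

pred = np.full(classes_num, 0.5, dtype=np.float32)
eps = 1e-7
# per-class binary cross-entropy, then averaged over observed classes only
bce = -(y_true * np.log(pred + eps) + (1 - y_true) * np.log(1 - pred + eps))
masked_loss = float(bce[y_mask].mean())
```

Storing the mask alongside the labels in one vector keeps the HDF5 layout simple: a single `target` dataset of width `classes_num * 2` instead of two parallel datasets.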
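The storage trick in `pack()` above is worth spelling out: the raw mp3 bytes of each clip are kept as a variable-length uint8 array (`h5py.vlen_dtype`) inside the HDF5 file, so no decoding happens at packing time. `np.frombuffer` is the in-memory analogue of the `np.fromfile` call used there; the byte payload below is a stand-in, not a real mp3.

```python
import numpy as np

data = b"ID3\x04\x00\x00fake-mp3-payload"   # stand-in for real mp3 bytes
a = np.frombuffer(data, dtype="uint8")       # what gets written to the hdf file
recovered = a.tobytes()                      # what a dataset reader would decode
```

The round trip is lossless byte-for-byte, which is why the dataset loader can hand the recovered bytes straight to an mp3 decoder at training time.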
absl-py                 1.4.0
aiohttp                 3.8.4
aiosignal               1.3.1
appdirs                 1.4.4
async-timeout           4.0.2
attrs                   22.2.0
audioread               3.0.0
autopep8                1.6.0
av                      10.0.0
brotlipy                0.7.0
cachetools              5.3.0
certifi                 2022.12.7
cffi                    1.15.1
charset-normalizer      2.0.4
click                   8.1.3
colorama                0.4.6
cryptography            39.0.1
decorator               5.1.1
docker-pycreds          0.4.0
docopt                  0.6.2
flit_core               3.8.0
frozenlist              1.3.3
fsspec                  2023.3.0
future                  0.18.3
gitdb                   4.0.10
GitPython               3.1.31
google-auth             2.17.1
google-auth-oauthlib    0.4.6
grpcio                  1.53.0
h5py                    3.8.0
idna                    3.4
importlib-metadata      6.1.0
joblib                  1.2.0
jsonpickle              3.0.1
kk-sacred               0.8.4
lazy_loader             0.2
librosa                 0.10.0.post2
lightning-utilities     0.8.0
llvmlite                0.39.1
Markdown                3.4.3
MarkupSafe              2.1.2
mkl-fft                 1.3.1
mkl-random              1.2.2
mkl-service             2.4.0
msgpack                 1.0.5
multidict               6.0.4
munch                   2.5.0
numba                   0.56.4
numpy                   1.23.5
oauthlib                3.2.2
packaging               23.0
pandas                  1.5.3
pathtools               0.1.2
Pillow                  9.4.0
pip                     23.0.1
pooch                   1.6.0
protobuf                4.22.1
psutil                  5.9.4
py-cpuinfo              9.0.0
pyasn1                  0.4.8
pyasn1-modules          0.2.8
pycodestyle             2.10.0
pycparser               2.21
pyDeprecate             0.3.0
pyOpenSSL               23.0.0
PySocks                 1.7.1
python-dateutil         2.8.2
pytorch-lightning       1.3.1
pytz                    2023.3
PyYAML                  5.4.1
requests                2.28.1
requests-oauthlib       1.3.1
rsa                     4.9
scikit-learn            1.2.2
scipy                   1.10.1
sentry-sdk              1.19.1
setproctitle            1.3.2
setuptools              65.6.3
six                     1.16.0
smmap                   5.0.0
soundfile               0.12.1
soxr                    0.3.4
tensorboard             2.12.0
tensorboard-data-server 0.7.0
tensorboard-plugin-wit  1.8.1
threadpoolctl           3.1.0
timm                    0.4.12
toml                    0.10.2
torch                   1.11.0
torchaudio              0.11.0
torchmetrics            0.2.0
torchvision             0.12.0
tqdm                    4.65.0
typing_extensions       4.4.0
urllib3                 1.26.15
wandb                   0.14.2
Werkzeug                2.2.3
wheel                   0.38.4
wrapt                   1.15.0
yarl                    1.8.2
zipp                    3.15.0

================================================
FILE: requirements.txt
================================================
av>=10.0.0
h5py>=3.8.0
jsonpickle>=3.0.1
kk-sacred>=0.8.4
librosa>=0.10.0.post2
timm>=0.4.12
torchmetrics>=0.2.0
pytorch-lightning<2.0.0
wandb>=0.14.2