Repository: kkoutini/PaSST
Branch: main
Commit: 2a5c818afcc2
Files: 52
Total size: 341.1 KB
Directory structure:
gitextract_blhd6h9_/
├── .github/
│ └── workflows/
│ └── pylint.yml
├── .gitignore
├── .markdownlint.json
├── LICENSE
├── README.md
├── __init__.py
├── audioset/
│ ├── README.md
│ ├── dataset.py
│ └── prepare_scripts/
│ ├── convert_to_mp3.py
│ └── create_h5pymp3_dataset.py
├── ba3l/
│ ├── __init__.py
│ ├── experiment.py
│ ├── ingredients/
│ │ ├── __init__.py
│ │ ├── datasets.py
│ │ ├── ingredient.py
│ │ ├── models.py
│ │ └── trainer.py
│ ├── module.py
│ ├── plutils/
│ │ ├── __init__.py
│ │ ├── lr_monitor.py
│ │ └── progress_bar.py
│ └── util/
│ ├── __init__.py
│ ├── functions.py
│ └── sacred_logger.py
├── config_updates.py
├── environment.yml
├── environment_old_2021.yml
├── esc50/
│ ├── README.md
│ └── dataset.py
├── ex_audioset.py
├── ex_esc50.py
├── ex_fsd50k.py
├── ex_openmic.py
├── fsd50k/
│ ├── README.md
│ ├── dataset.py
│ └── prepare_scripts/
│ ├── convert_to_mp3.py
│ └── create_h5pymp3_dataset.py
├── helpers/
│ ├── audiodatasets.py
│ ├── mixup.py
│ ├── models_size.py
│ ├── ramp.py
│ ├── swa_callback.py
│ ├── swa_legacy.py
│ └── workersinit.py
├── models/
│ ├── helpers/
│ │ └── vit_helpers.py
│ ├── passt.py
│ └── preprocess.py
├── openmic/
│ ├── README.md
│ ├── dataset.py
│ └── prepare_scripts/
│ └── download_preprocess.py
├── pip_list.txt
└── requirements.txt
================================================
FILE CONTENTS
================================================
================================================
FILE: .github/workflows/pylint.yml
================================================
name: Pylint
on: [push]
jobs:
build:
runs-on: ubuntu-latest
strategy:
matrix:
python-version: ["3.8", "3.9", "3.10"]
steps:
- uses: actions/checkout@v3
- name: Set up Python ${{ matrix.python-version }}
uses: actions/setup-python@v3
with:
python-version: ${{ matrix.python-version }}
- name: Install dependencies
run: |
python -m pip install --upgrade pip
pip install pylint
- name: Analysing the code with pylint
run: |
pylint $(git ls-files '*.py')
================================================
FILE: .gitignore
================================================
# project
/output/
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class
example.py
timit_data/
LJSpeech-1.1/
# C extensions
*.so
.idea/
# Distribution / packaging
.Python
env/
ide_layouts/
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
*.egg-info/
.installed.cfg
*.egg
# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec
# Installer logs
pip-log.txt
pip-delete-this-directory.txt
# Unit test / coverage reports
htmlcov/
.tox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
.hypothesis/
# Translations
*.mo
*.pot
# Django stuff:
*.log
local_settings.py
# Flask stuff:
instance/
.webassets-cache
# Scrapy stuff:
.scrapy
# Sphinx documentation
docs/_build/
# PyBuilder
target/
# Jupyter Notebook
.ipynb_checkpoints
# pyenv
.python-version
# celery beat schedule file
celerybeat-schedule
# SageMath parsed files
*.sage.py
# dotenv
.env
# virtualenv
.venv
venv/
ENV/
# Spyder project settings
.spyderproject
.spyproject
# Rope project settings
.ropeproject
# mkdocs documentation
/site
# mypy
.mypy_cache/
/openmic/prepare_scripts/openmic-2018-v1.0.0/
/openmic/prepare_scripts/openmic-2018-v1.0.0.tgz
/audioset_hdf5s/mp3/openmic_test.csv_mp3.hdf
/audioset_hdf5s/mp3/openmic_train.csv_mp3.hdf
environment_builds.yml
# Output
lightning_logs/*
audioset_hdf5s/*
.vscode/settings.json
wandb/*
.vscode/launch.json
================================================
FILE: .markdownlint.json
================================================
{
"MD033": {
"allowed_elements": [
"p",
"img"
]
},
"MD013": false
}
================================================
FILE: LICENSE
================================================
Apache License
Version 2.0, January 2004
http://www.apache.org/licenses/
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
1. Definitions.
"License" shall mean the terms and conditions for use, reproduction,
and distribution as defined by Sections 1 through 9 of this document.
"Licensor" shall mean the copyright owner or entity authorized by
the copyright owner that is granting the License.
"Legal Entity" shall mean the union of the acting entity and all
other entities that control, are controlled by, or are under common
control with that entity. For the purposes of this definition,
"control" means (i) the power, direct or indirect, to cause the
direction or management of such entity, whether by contract or
otherwise, or (ii) ownership of fifty percent (50%) or more of the
outstanding shares, or (iii) beneficial ownership of such entity.
"You" (or "Your") shall mean an individual or Legal Entity
exercising permissions granted by this License.
"Source" form shall mean the preferred form for making modifications,
including but not limited to software source code, documentation
source, and configuration files.
"Object" form shall mean any form resulting from mechanical
transformation or translation of a Source form, including but
not limited to compiled object code, generated documentation,
and conversions to other media types.
"Work" shall mean the work of authorship, whether in Source or
Object form, made available under the License, as indicated by a
copyright notice that is included in or attached to the work
(an example is provided in the Appendix below).
"Derivative Works" shall mean any work, whether in Source or Object
form, that is based on (or derived from) the Work and for which the
editorial revisions, annotations, elaborations, or other modifications
represent, as a whole, an original work of authorship. For the purposes
of this License, Derivative Works shall not include works that remain
separable from, or merely link (or bind by name) to the interfaces of,
the Work and Derivative Works thereof.
"Contribution" shall mean any work of authorship, including
the original version of the Work and any modifications or additions
to that Work or Derivative Works thereof, that is intentionally
submitted to Licensor for inclusion in the Work by the copyright owner
or by an individual or Legal Entity authorized to submit on behalf of
the copyright owner. For the purposes of this definition, "submitted"
means any form of electronic, verbal, or written communication sent
to the Licensor or its representatives, including but not limited to
communication on electronic mailing lists, source code control systems,
and issue tracking systems that are managed by, or on behalf of, the
Licensor for the purpose of discussing and improving the Work, but
excluding communication that is conspicuously marked or otherwise
designated in writing by the copyright owner as "Not a Contribution."
"Contributor" shall mean Licensor and any individual or Legal Entity
on behalf of whom a Contribution has been received by Licensor and
subsequently incorporated within the Work.
2. Grant of Copyright License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
copyright license to reproduce, prepare Derivative Works of,
publicly display, publicly perform, sublicense, and distribute the
Work and such Derivative Works in Source or Object form.
3. Grant of Patent License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
(except as stated in this section) patent license to make, have made,
use, offer to sell, sell, import, and otherwise transfer the Work,
where such license applies only to those patent claims licensable
by such Contributor that are necessarily infringed by their
Contribution(s) alone or by combination of their Contribution(s)
with the Work to which such Contribution(s) was submitted. If You
institute patent litigation against any entity (including a
cross-claim or counterclaim in a lawsuit) alleging that the Work
or a Contribution incorporated within the Work constitutes direct
or contributory patent infringement, then any patent licenses
granted to You under this License for that Work shall terminate
as of the date such litigation is filed.
4. Redistribution. You may reproduce and distribute copies of the
Work or Derivative Works thereof in any medium, with or without
modifications, and in Source or Object form, provided that You
meet the following conditions:
(a) You must give any other recipients of the Work or
Derivative Works a copy of this License; and
(b) You must cause any modified files to carry prominent notices
stating that You changed the files; and
(c) You must retain, in the Source form of any Derivative Works
that You distribute, all copyright, patent, trademark, and
attribution notices from the Source form of the Work,
excluding those notices that do not pertain to any part of
the Derivative Works; and
(d) If the Work includes a "NOTICE" text file as part of its
distribution, then any Derivative Works that You distribute must
include a readable copy of the attribution notices contained
within such NOTICE file, excluding those notices that do not
pertain to any part of the Derivative Works, in at least one
of the following places: within a NOTICE text file distributed
as part of the Derivative Works; within the Source form or
documentation, if provided along with the Derivative Works; or,
within a display generated by the Derivative Works, if and
wherever such third-party notices normally appear. The contents
of the NOTICE file are for informational purposes only and
do not modify the License. You may add Your own attribution
notices within Derivative Works that You distribute, alongside
or as an addendum to the NOTICE text from the Work, provided
that such additional attribution notices cannot be construed
as modifying the License.
You may add Your own copyright statement to Your modifications and
may provide additional or different license terms and conditions
for use, reproduction, or distribution of Your modifications, or
for any such Derivative Works as a whole, provided Your use,
reproduction, and distribution of the Work otherwise complies with
the conditions stated in this License.
5. Submission of Contributions. Unless You explicitly state otherwise,
any Contribution intentionally submitted for inclusion in the Work
by You to the Licensor shall be under the terms and conditions of
this License, without any additional terms or conditions.
Notwithstanding the above, nothing herein shall supersede or modify
the terms of any separate license agreement you may have executed
with Licensor regarding such Contributions.
6. Trademarks. This License does not grant permission to use the trade
names, trademarks, service marks, or product names of the Licensor,
except as required for reasonable and customary use in describing the
origin of the Work and reproducing the content of the NOTICE file.
7. Disclaimer of Warranty. Unless required by applicable law or
agreed to in writing, Licensor provides the Work (and each
Contributor provides its Contributions) on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
implied, including, without limitation, any warranties or conditions
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
PARTICULAR PURPOSE. You are solely responsible for determining the
appropriateness of using or redistributing the Work and assume any
risks associated with Your exercise of permissions under this License.
8. Limitation of Liability. In no event and under no legal theory,
whether in tort (including negligence), contract, or otherwise,
unless required by applicable law (such as deliberate and grossly
negligent acts) or agreed to in writing, shall any Contributor be
liable to You for damages, including any direct, indirect, special,
incidental, or consequential damages of any character arising as a
result of this License or out of the use or inability to use the
Work (including but not limited to damages for loss of goodwill,
work stoppage, computer failure or malfunction, or any and all
other commercial damages or losses), even if such Contributor
has been advised of the possibility of such damages.
9. Accepting Warranty or Additional Liability. While redistributing
the Work or Derivative Works thereof, You may choose to offer,
and charge a fee for, acceptance of support, warranty, indemnity,
or other liability obligations and/or rights consistent with this
License. However, in accepting such obligations, You may act only
on Your own behalf and on Your sole responsibility, not on behalf
of any other Contributor, and only if You agree to indemnify,
defend, and hold each Contributor harmless for any liability
incurred by, or claims asserted against, such Contributor by reason
of your accepting any such warranty or additional liability.
END OF TERMS AND CONDITIONS
APPENDIX: How to apply the Apache License to your work.
To apply the Apache License to your work, attach the following
boilerplate notice, with the fields enclosed by brackets "[]"
replaced with your own identifying information. (Don't include
the brackets!) The text should be enclosed in the appropriate
comment syntax for the file format. We also recommend that a
file or class name and description of purpose be included on the
same "printed page" as the copyright notice for easier
identification within third-party archives.
Copyright [yyyy] [name of copyright owner]
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
================================================
FILE: README.md
================================================
# PaSST: Efficient Training of Audio Transformers with Patchout
This is the implementation for [Efficient Training of Audio Transformers with Patchout](https://arxiv.org/abs/2110.05069)
Patchout significantly reduces the training time and GPU memory requirements to train transformers on audio spectrograms, while improving their performance.

Patchout works by dropping out some of the input patches during training.
In either an unstructured way (randomly, similar to dropout),
or entire time-frames or frequency bins of the extracted patches (similar to SpecAugment),
which corresponds to rows/columns in step 3 of the figure below.

## Table of contents
- [Pre-trained models for Inference and embeddings extractions](#pre-trained-models-for-inference-and-embeddings-extractions)
- [Getting the logits from the pretrained models](#getting-the-logits-from-the-pretrained-models)
- [Getting a pre-trained model for fine-tuning](#getting-a-pre-trained-model-for-fine-tuning)
- [Development environment](#development-environment)
- [Setting up the development experiments environment](#setting-up-the-development-experiments-environment)
- [Setting up using the exported conda environment](#setting-up-using-the-exported-conda-environment)
- [Checking the environment](#checking-the-environment)
- [Getting started](#getting-started)
- [General information](#general-information)
- [Configuring the experiment](#configuring-the-experiment)
- [Training on Audioset](#training-on-audioset)
- [Examples with Pre-trained models](#examples-with-pre-trained-models)
- [Examples fine-tuning on downstream datasets](#examples-of-fine-tuning-on-downstream-datasets)
- [Citation](#citation)
- [Contact](#contact)
## Pre-trained models for Inference and embeddings extractions
If you only want to use the embeddings generated by the pretrained models, use
your own fine-tuning framework, or you need it only for inference, you can find a stripped down version of this repo [here](https://github.com/kkoutini/passt_hear21).
The package follows [HEAR 2021 NeurIPS Challenge](https://neuralaudio.ai/hear2021-results.html) API, and can be installed:
```shell
pip install hear21passt
```
This repo is a complete framework for training the models and fine-tuning pre-trained models on Audioset on downstream tasks.
### Getting the logits from the pretrained models
```python
from hear21passt.base import get_basic_model,get_model_passt
import torch
# get the PaSST model wrapper, includes Melspectrogram and the default pre-trained transformer
model = get_basic_model(mode="logits")
print(model.mel) # Extracts mel spectrogram from raw waveforms.
print(model.net) # the transformer network.
# example inference
model.eval()
model = model.cuda()
with torch.no_grad():
# audio_wave has the shape of [batch, seconds*32000] sampling rate is 32k
# example audio_wave of batch=3 and 10 seconds
audio = torch.ones((3, 32000 * 10))*0.5
audio_wave = audio.cuda()
logits=model(audio_wave)
```
### Getting a pre-trained model for fine tuning
```python
from hear21passt.base import get_basic_model,get_model_passt
import torch
# get the PaSST model wrapper, includes Melspectrogram and the default pre-trained transformer
model = get_basic_model(mode="logits")
print(model.mel) # Extracts mel spectrogram from raw waveforms.
# optional replace the transformer with one that has the required number of classes i.e. 50
model.net = get_model_passt(arch="passt_s_swa_p16_128_ap476", n_classes=50)
print(model.net) # the transformer network.
# now model contains mel + the transformer pre-trained model ready to be fine tuned.
# It's still expecting input of the shape [batch, seconds*32000] sampling rate is 32k
model.train()
model = model.cuda()
```
## Development environment
If you want to use the same environment as in the paper, you can follow the instructions below.
### Setting up the development experiments environment
For training models from scratch or fine-tuning using the same setup as in the paper:
1. If needed, create a new environment with python 3.8 and activate it:
```bash
conda create -n passt python=3.8
conda activate passt
```
1. Install pytorch build that suits your system. For example:
```bash
conda install pytorch==1.11.0 torchvision==0.12.0 torchaudio==0.11.0 cudatoolkit=11.3 -c pytorch
```
1. Install the requirements:
```bash
pip install -r requirements.txt
```
### Setting up using the exported conda environment
Alternatively, you can use the exported conda environment `environment.yml` to create the environment.
For setting up [Mamba](https://github.com/mamba-org/mamba) is recommended since it works faster than `conda`:
```shell
conda install mamba -n base -c conda-forge
```
Now you can import the environment from `environment.yml`
```shell
mamba env create -f environment.yml
```
Now you have an environment named `ba3l`.
### Checking the environment
In order to check if your environment matched the environment we used in our runs, please check the `environment.yml` and `pip_list.txt` files, which were exported using:
```shell
conda env export --no-builds | grep -v "prefix" > environment.yml
pip list > pip_list.txt
```
## Getting started
If you want to use your setup and only use the models from this repo, you can get the models train them from scratch or fine-tune them on your own dataset as explained above [Pre-trained models for Inference and embeddings extractions](#pre-trained-models-for-inference-and-embeddings-extractions). The rest of this section explains using this repo for training and fine-tuning the models. For that, first you need to set up the development environment as explained above.
### General information
The repo is built using [sacred](https://sacred.readthedocs.io/en/) for experiment management and configuration, pytorch-lightning for training, and wandb for logging.
Each dataset has a main experiment file such as `ex_audioset.py` and `ex_openmic.py` and a dataset folder. The experiment file contains the main training and validation logic. The dataset folder contains the dataset specific code needed to download, preprocess and load the dataset for training.
In general, you can prob the experiment file for help, this will print the available commands and basic options:
```shell
python ex_audioset.py help
```
### Configuring the experiment
Each experiment has a set of default configuration options, defined in the experiment file, e.g. `ex_audioset.py`. You can override any of the configuration using the [sacred syntax](https://sacred.readthedocs.io/en/stable/command_line.html). You can use the `print_config` command to print the configuration values without training a model:
```shell
python ex_audioset.py print_config
```
You can use then use the command line interface to override any of the configuration options ([sacred syntax](https://sacred.readthedocs.io/en/stable/command_line.html)), using `with` e.g.:
```shell
python ex_audioset.py with trainer.precision=16
```
This will train on Audioset using 16-bit precision.
The overall configurations look like this:
```yaml
...
seed = 542198583 # the random seed for this experiment
slurm_job_id = ''
speed_test_batch_size = 100
swa = True
swa_epoch_start = 50
swa_freq = 5
use_mixup = True
warm_up_len = 5
weight_decay = 0.0001
basedataset:
base_dir = 'audioset_hdf5s/' # base directory of the dataset, change it or make a link
eval_hdf5 = 'audioset_hdf5s/mp3/eval_segments_mp3.hdf'
wavmix = 1
....
roll_conf:
axis = 1
shift = None
shift_range = 50
datasets:
test:
batch_size = 20
dataset = {CMD!}'/basedataset.get_test_set'
num_workers = 16
validate = True
training:
batch_size = 12
dataset = {CMD!}'/basedataset.get_full_training_set'
num_workers = 16
sampler = {CMD!}'/basedataset.get_ft_weighted_sampler'
shuffle = None
train = True
models:
mel:
freqm = 48
timem = 192
hopsize = 320
htk = False
n_fft = 1024
n_mels = 128
norm = 1
sr = 32000
...
net:
arch = 'passt_s_swa_p16_128_ap476'
fstride = 10
in_channels = 1
input_fdim = 128
input_tdim = 998
n_classes = 527
s_patchout_f = 4
s_patchout_t = 40
tstride = 10
u_patchout = 0
...
trainer:
accelerator = None
accumulate_grad_batches = 1
amp_backend = 'native'
amp_level = 'O2'
auto_lr_find = False
auto_scale_batch_size = False
...
```
There are many things that can be updated from the command line.
In short:
- All the configuration options under `trainer` are pytorch lightning trainer [api](https://pytorch-lightning.readthedocs.io/en/1.4.1/common/trainer.html#trainer-class-api). For example, to turn off cuda benchmarking add `trainer.benchmark=False` to the command line.
- `wandb` is the wandb configuration. For example, to change the wandb project `wandb.project="test_project"` to the command line.
- `models.net` are the PaSST (or the chosen NN) options. Examples: `models.net.u_patchout`, `models.net.s_patchout_f` `models.net.s_patchout_t` control the unstructured patchout and structured patchout over frequency and time. `input_fdim` and `input_tdim` are the input spectrogram dimensions over frequency and time. `models.net.fstride` and `models.net.tstride` are the strides of the input patches over frequency and time, setting these to 16 means no patch overlap.
- `models.mel` are the preprocessing options (mel spectrograms). `mel.sr` is the sampling rate, `mel.hopsize` is the hop size of the STFT window, `mel.n_mels` is the number of mel bins, `mel.freqm` and `mel.timem` are the frequency and time masking parameters of spec-augment.
There are many pre-defined configuration bundles (called named_configs) in `config_updates.py`. These include different models, setups etc...
You can list these configurations with:
```shell
python ex_audioset.py print_named_configs
```
For example, `passt_s_20sec` is a configuration bundle that sets the model to PaSST-S pre-trained on Audioset, and accepts up to 20 second clips.
## Training on Audioset
Download and prepare the dataset as explained in the [audioset page](audioset/)
The base PaSST model can be trained for example like this:
```bash
python ex_audioset.py with trainer.precision=16 models.net.arch=passt_deit_bd_p16_384 -p
```
For example using only unstructured patchout of 400:
```bash
python ex_audioset.py with trainer.precision=16 models.net.arch=passt_deit_bd_p16_384 models.net.u_patchout=400 models.net.s_patchout_f=0 models.net.s_patchout_t=0 -p
```
Multi-gpu training can be enabled by setting the environment variable `DDP`, for example with 2 gpus:
```shell
DDP=2 python ex_audioset.py with trainer.precision=16 models.net.arch=passt_deit_bd_p16_384 -p -c "PaSST base 2 GPU"
```
## Examples with Pre-trained models
Please check the [releases page](https://github.com/kkoutini/PaSST/releases/), to download pre-trained models.
In general, you can get a pretrained model on Audioset using
```python
from models.passt import get_model
model = get_model(arch="passt_s_swa_p16_128_ap476", pretrained=True, n_classes=527, in_channels=1,
fstride=10, tstride=10,input_fdim=128, input_tdim=998,
u_patchout=0, s_patchout_t=40, s_patchout_f=4)
```
this will get automatically download pretrained PaSST on audioset with with mAP of ```0.476```. the model was trained with ```s_patchout_t=40, s_patchout_f=4``` but you can change these to better fit your task/ computational needs.
There are several pretrained models availble with different strides (overlap) and with/without using SWA: `passt_s_p16_s16_128_ap468, passt_s_swa_p16_s16_128_ap473, passt_s_swa_p16_s14_128_ap471, passt_s_p16_s14_128_ap469, passt_s_swa_p16_s12_128_ap473, passt_s_p16_s12_128_ap470`.
For example, In `passt_s_swa_p16_s16_128_ap473`: `p16` mean patch size is `16x16`, `s16` means no overlap (stride=16), 128 mel bands, `ap473` refers to the performance of this model on Audioset mAP=0.479.
In general, you can get a pretrained model using:
```python
from models.passt import get_model
passt = get_model(arch="passt_s_swa_p16_s16_128_ap473", fstride=16, tstride=16)
```
Using the framework, you can evaluate this model using:
```shell
python ex_audioset.py evaluate_only with trainer.precision=16 passt_s_swa_p16_s16_128_ap473 -p
```
Ensemble of these models are provided as well:
A large ensemble giving `mAP=.4956`
```shell
python ex_audioset.py evaluate_only with trainer.precision=16 ensemble_many
```
An ensemble of 2 models with `stride=14` and `stride=16` giving `mAP=.4858`
```shell
python ex_audioset.py evaluate_only with trainer.precision=16 ensemble_s16_14
```
As well as other ensembles `ensemble_4`, `ensemble_5`
## Examples of fine-tuning on downstream datasets
1. [ESC-50: Dataset for Environmental Sound Classification](esc50/)
2. [OpenMIC-2018 dataset](openmic/)
3. [FSD50K](fsd50k/)
## Citation
The citation to the accepted paper in Interspeech 2022:
```bib
@inproceedings{koutini22passt,
author = {Khaled Koutini and
Jan Schl{\"{u}}ter and
Hamid Eghbal{-}zadeh and
Gerhard Widmer},
title = {Efficient Training of Audio Transformers with Patchout},
booktitle = {Interspeech 2022, 23rd Annual Conference of the International Speech
Communication Association, Incheon, Korea, 18-22 September 2022},
pages = {2753--2757},
publisher = {{ISCA}},
year = {2022},
url = {https://doi.org/10.21437/Interspeech.2022-227},
doi = {10.21437/Interspeech.2022-227},
}
```
## Contact
The repo will be updated, in the meantime if you have any questions or problems feel free to open an issue on GitHub, or contact the authors directly.
================================================
FILE: __init__.py
================================================
================================================
FILE: audioset/README.md
================================================
# Experiments on Audioset
Audioset has around 2M segments. The total size of the dataset with wav files with `32khz` sampling rate is around 1.2 TB. In our setup, this results in a huge IO bottleneck that slows down the training process significantly.
Therefore, we encode the dataset to mp3, pack the mp3 into HDF5 format and decode the mp3s on the fly, If you have enough cpu cores (10-16 dataloading workers) you should not notice any slowdowns.
in the `dataset.py` file we read the samples from the hdf files. Decode the mp3, do wave form augmentations and return the raw waveform of the model.
`AudioSetDataset` is the main class where reading from the hdf files.
## Preparing the dataset
### Downloading Audioset
We used the scripts provided by [PANNS](https://github.com/qiuqiangkong/audioset_tagging_cnn) to download the dataset.
### Converting to mp3
Once the Datasets are downloaded we convert all the files to mp3 using the script:
`prepare_scripts/convert_to_mp3.py`.
```bash
python convert_to_mp3.py --source pann_download_folder --out mp3_folder
```
this will significantly reduce the size of the dataset and overcome the IO bottleneck in our setup. The trade-off is that more cpu is needed during training to decode the mp3s.
We use the [av](https://pypi.org/project/av/) (check `decode_mp3` in ` dataset.py`) library to decode the mp3 in the data loading workers, this is much faster than calling ffmpeg.
As a result, approximetly 10 decoding threads should be enough keep a 2080ti busy.
you can test how much time it take to load and decode one epoch on your system:
```bash
python python ex_audioset.py test_loaders_train_speed
```
This step is not necessary if you have a more powerful setup and the `decode_mp3` also supports other ffmpeg codecs.
### packing to HDF5 files
Finally, you need to pack the mp3 files into a single HDF5 file using `create_h5pymp3_dataset.py`.
you just need to set the paths in the script to match your local paths. The script goes through the csv files and check if the corresponding mp3 file exists, then it will store it in h5py file.
The output of this step should be 3 files `balanced_train_segments_mp3.hdf`, `eval_segments_mp3.hdf` and `unbalanced_train_segments_mp3.hdf`.
Each of these files. Make sure the paths match the default config in `dataset.py`
================================================
FILE: audioset/dataset.py
================================================
import io
import os
import pathlib
import random
import av
import librosa
import torchaudio
from torch.utils.data import Dataset as TorchDataset, ConcatDataset, DistributedSampler, WeightedRandomSampler
import torch
from ba3l.ingredients.datasets import Dataset
from sacred.config import DynamicIngredient, CMD
from scipy.signal import convolve
import numpy as np
from helpers.audiodatasets import PreprocessDataset
import h5py
LMODE = os.environ.get("LMODE", False)
#$TMPDIR
dataset = Dataset('audiodataset')
@dataset.config
def default_config():
name = 'audioset' # dataset name
normalize = False # normalize dataset
subsample = False # subsample squares from the dataset
roll = True # apply roll augmentation
fold = 1
base_dir = "audioset_hdf5s/" # base directory of the dataset, change it or make a link
if LMODE:
base_dir = "/system/user/publicdata/CP/audioset/audioset_hdf5s/"
balanced_train_hdf5 = base_dir + "mp3/balanced_train_segments_mp3.hdf"
eval_hdf5 = base_dir + "mp3/eval_segments_mp3.hdf"
unbalanced_train_hdf5 = base_dir + "mp3/unbalanced_train_segments_mp3.hdf"
if LMODE:
balanced_train_hdf5 = balanced_train_hdf5.replace(base_dir, os.environ.get("TMPDIR", base_dir)+"/")
unbalanced_train_hdf5 = unbalanced_train_hdf5.replace(base_dir, os.environ.get("TMPDIR", base_dir)+"/")
eval_hdf5 = eval_hdf5.replace(base_dir, os.environ.get("TMPDIR", base_dir)+"/")
ir_path = base_dir + "irs/"
num_of_classes = 527
if LMODE:
@dataset.config
def LMODE_default_config():
cache_root_path = "/system/user/publicdata/CP/DCASE/cached_datasets/"
def decode_mp3(mp3_arr):
"""
decodes an array if uint8 representing an mp3 file
:rtype: np.array
"""
container = av.open(io.BytesIO(mp3_arr.tobytes()))
stream = next(s for s in container.streams if s.type == 'audio')
# print(stream)
a = []
for i, packet in enumerate(container.demux(stream)):
for frame in packet.decode():
a.append(frame.to_ndarray().reshape(-1))
waveform = np.concatenate(a)
if waveform.dtype != 'float32':
raise RuntimeError("Unexpected wave type")
return waveform
def pad_or_truncate(x, audio_length):
"""Pad all audio to specific length."""
if len(x) <= audio_length:
return np.concatenate((x, np.zeros(audio_length - len(x), dtype=np.float32)), axis=0)
else:
return x[0: audio_length]
irs_arr = None
@dataset.command
def get_ir_sample(ir_path, _run, ir_augment, cut_irs_offset=None):
if not ir_augment:
return
global irs_arr
if irs_arr is None:
all_paths = [path for path in pathlib.Path(os.path.expanduser(ir_path)).rglob('*.wav')]
all_paths = sorted(all_paths)
if cut_irs_offset is not None:
all_paths = all_paths[cut_irs_offset:cut_irs_offset + 10]
all_paths_name = [str(p).rsplit("/", 1)[-1] for p in all_paths]
print("will use these IRs:")
for i in range(len(all_paths_name)):
print(i, ": ", all_paths_name[i])
_run.info["ir_devices"] = all_paths_name
irs_arr = [librosa.load(p, sr=32000)[0] for p in all_paths]
return irs_arr[int(np.random.randint(0, len(irs_arr)))]
@dataset.command
def pydub_augment(waveform, gain_augment=7, ir_augment=0):
if ir_augment and torch.rand(1) < ir_augment:
ir = get_ir_sample()
waveform = convolve(waveform, ir, 'full')
if gain_augment:
gain = torch.randint(gain_augment * 2, (1,)).item() - gain_augment
amp = 10 ** (gain / 20)
waveform = waveform * amp
return waveform
class MixupDataset(TorchDataset):
""" Mixing Up wave forms
"""
def __init__(self, dataset, beta=2, rate=0.5):
self.beta = beta
self.rate = rate
self.dataset = dataset
print(f"Mixing up waveforms from dataset of len {len(dataset)}")
def __getitem__(self, index):
if torch.rand(1) < self.rate:
x1, f1, y1 = self.dataset[index]
idx2 = torch.randint(len(self.dataset), (1,)).item()
x2, f2, y2 = self.dataset[idx2]
l = np.random.beta(self.beta, self.beta)
l = max(l, 1. - l)
x1 = x1-x1.mean()
x2 = x2-x2.mean()
x = (x1 * l + x2 * (1. - l))
x = x - x.mean()
return x, f1, (y1 * l + y2 * (1. - l))
return self.dataset[index]
def __len__(self):
return len(self.dataset)
class AudioSetDataset(TorchDataset):
def __init__(self, hdf5_file, sample_rate=32000, classes_num=527, clip_length=10, augment=False, in_mem=False):
"""
Reads the mp3 bytes from HDF file decodes using av and returns a fixed length audio wav
"""
self.sample_rate = sample_rate
self.hdf5_file = hdf5_file
if in_mem:
print("\nPreloading in memory\n")
with open(hdf5_file, 'rb') as f:
self.hdf5_file = io.BytesIO(f.read())
with h5py.File(hdf5_file, 'r') as f:
self.length = len(f['audio_name'])
print(f"Dataset from {hdf5_file} with length {self.length}.")
self.dataset_file = None # lazy init
self.clip_length = clip_length * sample_rate
self.classes_num = classes_num
self.augment = augment
if augment:
print(f"Will agument data from {hdf5_file}")
def open_hdf5(self):
self.dataset_file = h5py.File(self.hdf5_file, 'r')
def __len__(self):
return self.length
def __del__(self):
if self.dataset_file is not None:
self.dataset_file.close()
self.dataset_file = None
def __getitem__(self, index):
"""Load waveform and target of an audio clip.
Args:
meta: {
'hdf5_path': str,
'index_in_hdf5': int}
Returns:
data_dict: {
'audio_name': str,
'waveform': (clip_samples,),
'target': (classes_num,)}
"""
if self.dataset_file is None:
self.open_hdf5()
audio_name = self.dataset_file['audio_name'][index].decode()
waveform = decode_mp3(self.dataset_file['mp3'][index])
if self.augment:
waveform = pydub_augment(waveform)
waveform = pad_or_truncate(waveform, self.clip_length)
waveform = self.resample(waveform)
target = self.dataset_file['target'][index]
target = np.unpackbits(target, axis=-1,
count=self.classes_num).astype(np.float32)
return waveform.reshape(1, -1), audio_name, target
def resample(self, waveform):
"""Resample.
Args:
waveform: (clip_samples,)
Returns:
(resampled_clip_samples,)
"""
if self.sample_rate == 32000:
return waveform
elif self.sample_rate == 16000:
return waveform[0:: 2]
elif self.sample_rate == 8000:
return waveform[0:: 4]
else:
raise Exception('Incorrect sample rate!')
@dataset.command
def get_base_training_set(balanced_train_hdf5):
ds = AudioSetDataset(balanced_train_hdf5, augment=True)
return ds
@dataset.command
def get_unbalanced_training_set(unbalanced_train_hdf5):
ds = AudioSetDataset(unbalanced_train_hdf5, augment=True)
return ds
@dataset.command
def get_norms_dataset(unbalanced_train_hdf5, balanced_train_hdf5):
ds = ConcatDataset(
[AudioSetDataset(balanced_train_hdf5, augment=False), AudioSetDataset(unbalanced_train_hdf5, augment=False)])
return ds
@dataset.command
def get_base_full_training_set():
sets = [get_base_training_set(), get_unbalanced_training_set()]
ds = ConcatDataset(sets)
return ds
@dataset.command
def preload_mp3(balanced_train_hdf5, unbalanced_train_hdf5, num_of_classes):
for hdf5_file in [balanced_train_hdf5, unbalanced_train_hdf5]:
print(f"\n \n will now preload {hdf5_file} \n\n ")
with h5py.File(hdf5_file, 'r') as dataset_file:
target = dataset_file['mp3'][:]
print(len(target))
print(f"\n \n done with {hdf5_file} \n\n ")
return target[1000]
@dataset.command
def get_ft_cls_balanced_sample_weights(balanced_train_hdf5, unbalanced_train_hdf5, num_of_classes,
sample_weight_offset=100, sample_weight_sum=True):
"""
:return: float tenosr of shape len(full_training_set) representing the weights of each sample.
"""
# the order of balanced_train_hdf5,unbalanced_train_hdf5 is important.
# should match get_full_training_set
all_y = []
for hdf5_file in [balanced_train_hdf5, unbalanced_train_hdf5]:
with h5py.File(hdf5_file, 'r') as dataset_file:
target = dataset_file['target']
target = np.unpackbits(target, axis=-1,
count=num_of_classes)
all_y.append(target)
all_y = np.concatenate(all_y, axis=0)
all_y = torch.as_tensor(all_y)
per_class = all_y.long().sum(0).float().reshape(1, -1) # frequencies per class
per_class = sample_weight_offset + per_class # offset low freq classes
if sample_weight_offset > 0:
print(f"Warning: sample_weight_offset={sample_weight_offset} minnow={per_class.min()}")
per_class_weights = 1000. / per_class
all_weight = all_y * per_class_weights
# print(all_weight.shape)
# print(all_weight[1510])
if sample_weight_sum:
print("\nsample_weight_sum\n")
all_weight = all_weight.sum(dim=1)
else:
all_weight, _ = all_weight.max(dim=1)
# print(all_weight.shape)
# print(all_weight[1510])
return all_weight
@dataset.command
def get_ft_weighted_sampler(samples_weights=CMD(".get_ft_cls_balanced_sample_weights"),
epoch_len=100000, sampler_replace=False):
num_nodes = int(os.environ.get('num_nodes', 1))
ddp = int(os.environ.get('DDP', 1))
num_nodes = max(ddp, num_nodes)
print("num_nodes= ", num_nodes)
rank = int(os.environ.get('NODE_RANK', 0))
return DistributedSamplerWrapper(sampler=WeightedRandomSampler(samples_weights,
num_samples=epoch_len, replacement=sampler_replace),
dataset=range(epoch_len),
num_replicas=num_nodes,
rank=rank,
)
@dataset.command
def get_base_test_set(eval_hdf5):
ds = AudioSetDataset(eval_hdf5)
return ds
@dataset.command(prefix='roll_conf')
def get_roll_func(axis=1, shift=None, shift_range=50):
print("rolling...")
def roll_func(b):
x, i, y = b
x = torch.as_tensor(x)
sf = shift
if shift is None:
sf = int(np.random.random_integers(-shift_range, shift_range))
global FirstTime
return x.roll(sf, axis), i, y
return roll_func
@dataset.command
def get_training_set(normalize, roll, wavmix=False):
ds = get_base_training_set()
get_ir_sample()
if normalize:
print("normalized train!")
fill_norms()
ds = PreprocessDataset(ds, norm_func)
if roll:
ds = PreprocessDataset(ds, get_roll_func())
if wavmix:
ds = MixupDataset(ds)
return ds
@dataset.command
def get_full_training_set(normalize, roll, wavmix=False):
ds = get_base_full_training_set()
get_ir_sample()
if normalize:
print("normalized train!")
fill_norms()
ds = PreprocessDataset(ds, norm_func)
if roll:
ds = PreprocessDataset(ds, get_roll_func())
if wavmix:
ds = MixupDataset(ds)
return ds
@dataset.command
def get_test_set(normalize):
ds = get_base_test_set()
if normalize:
print("normalized test!")
fill_norms()
ds = PreprocessDataset(ds, norm_func)
return ds
@dataset.command
def print_conf(_config):
print("Config of ", dataset.path, id(dataset))
print(_config)
print()
class DistributedSamplerWrapper(DistributedSampler):
def __init__(
self, sampler, dataset,
num_replicas=None,
rank=None,
shuffle: bool = True):
super(DistributedSamplerWrapper, self).__init__(
dataset, num_replicas, rank, shuffle)
# source: @awaelchli https://github.com/PyTorchLightning/pytorch-lightning/issues/3238
self.sampler = sampler
def __iter__(self):
if self.sampler.generator is None:
self.sampler.generator = torch.Generator()
self.sampler.generator.manual_seed(self.seed + self.epoch)
indices = list(self.sampler)
if self.epoch == 0:
print(f"\n DistributedSamplerWrapper : {indices[:10]} \n\n")
indices = indices[self.rank:self.total_size:self.num_replicas]
return iter(indices)
if __name__ == "__main__":
from sacred import Experiment
ex = Experiment("test_dataset", ingredients=[dataset])
@ex.automain
def default_command():
ex.current_run.get_command_function("print_config")()
get_base_training_set()
ds = get_test_set()
print(ds[0])
ds = get_training_set()
print(ds[0])
print("get_base_training_set", len(get_base_training_set()))
print("get_base_test_set", len(get_base_test_set()))
print("get_training_set", len(get_training_set()))
print("get_test_set", len(get_test_set()))
================================================
FILE: audioset/prepare_scripts/convert_to_mp3.py
================================================
import argparse
import multiprocessing
import glob
import os
# Replace this with the dataset downloaded using PANN scripts
source_path = "/share/cp/datasets/full_audioset/audioset201906/audios/"
# Replace with the output directory
out_path = "./mp3_audio/"
all_num = 0
def process_folder(fol="balanced_train_segments"):
print("now working on ", fol)
os.makedirs(out_path + fol, exist_ok=True)
all_files = list(glob.glob(source_path + fol + "/*.wav"))
print(f"it has {len(all_files)}")
global all_num
all_num = len(all_files)
cmds = [(i, file, out_path + fol + "/" + os.path.basename(file)[:-3]) for i, file in enumerate(all_files)]
print(cmds[0])
with multiprocessing.Pool(processes=20) as pool:
pool.starmap(process_one, cmds)
def process_one(i, f1, f2):
if i % 100 == 0:
print(f"{i}/{all_num} \t", f1)
os.system(f"ffmpeg -hide_banner -nostats -loglevel error -n -i {f1} -codec:a mp3 -ar 32000 {f2}mp3")
if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument('--source', type=str, required=False, default=None,
help='Path of folder containing the wave files, expected format to be as downloaded '
'using the PANNs script, containing balanced_train_segments, eval_segments, '
'unbalanced_train_segments folders.')
parser.add_argument('--out', type=str, required=False, default=None,
help='Directory to save out the converted mp3s.')
args = parser.parse_args()
source_path = args.source or source_path
out_path = args.out or out_path
folders = ['balanced_train_segments', 'eval_segments'] + ["unbalanced_train_segments/" + x for x in sorted(
os.listdir(source_path + "unbalanced_train_segments/"))]
print("I will work on these folders:")
print(folders)
for fol in folders:
process_folder(fol)
================================================
FILE: audioset/prepare_scripts/create_h5pymp3_dataset.py
================================================
# %%
import h5py
import pandas as pd
import numpy as np
import csv
import os
# %%
base_dir = "../../audioset_hdf5s/"
balanced_csv= base_dir+ "metadata/balanced_train_segments.csv"
eval_csv= base_dir+ "metadata/eval_segments.csv"
mp3_path = "../../mp3_audio/"
# %%
def read_metadata(csv_path, classes_num, id_to_ix):
"""Read metadata of AudioSet from a csv file.
source: https://github.com/qiuqiangkong/audioset_tagging_cnn/blob/d2f4b8c18eab44737fcc0de1248ae21eb43f6aa4/utils/utilities.py#L59
Args:
csv_path: str
Returns:
meta_dict: {'audio_name': (audios_num,), 'target': (audios_num, classes_num)}
"""
with open(csv_path, 'r') as fr:
lines = fr.readlines()
lines = lines[3:] # Remove heads
audios_num = len(lines)
targets = np.zeros((audios_num, classes_num), dtype=np.bool)
audio_names = []
for n, line in enumerate(lines):
items = line.split(', ')
"""items: ['--4gqARaEJE', '0.000', '10.000', '"/m/068hy,/m/07q6cd_,/m/0bt9lr,/m/0jbk"\n']"""
audio_name = 'Y{}.mp3'.format(items[0]) # Audios are started with an extra 'Y' when downloading
label_ids = items[3].split('"')[1].split(',')
audio_names.append(audio_name)
# Target
for id in label_ids:
ix = id_to_ix[id]
targets[n, ix] = 1
meta_dict = {'audio_name': np.array(audio_names), 'target': targets}
return meta_dict
# Load label
with open(base_dir+'metadata/class_labels_indices.csv', 'r') as f:
reader = csv.reader(f, delimiter=',')
lines = list(reader)
labels = []
ids = [] # Each label has a unique id such as "/m/068hy"
for i1 in range(1, len(lines)):
id = lines[i1][1]
label = lines[i1][2]
ids.append(id)
labels.append(label)
classes_num = len(labels)
lb_to_ix = {label : i for i, label in enumerate(labels)}
ix_to_lb = {i : label for i, label in enumerate(labels)}
id_to_ix = {id : i for i, id in enumerate(ids)}
ix_to_id = {i : id for i, id in enumerate(ids)}
# %%
def check_available(balanced_csv,balanced_audio_path,prefix=None):
meta_csv = read_metadata(balanced_csv,classes_num,id_to_ix)
audios_num = len(meta_csv['audio_name'])
found=0
notfound=0
available_files=[]
available_targets=[]
if prefix is None:
prefix = os.path.basename(balanced_csv)[:-4]
for n in range(audios_num):
audio_path = meta_csv['audio_name'][n]
#print(balanced_audio_path + f"{prefix}/{audio_path}")
if os.path.isfile(balanced_audio_path + f"{prefix}/{audio_path}" ):
found+=1
available_files.append(meta_csv['audio_name'][n])
available_targets.append(meta_csv['target'][n])
else:
notfound+=1
print(f"Found {found} . not found {notfound}")
return available_files,available_targets
# %%
# %%
# %%
for read_file,prefix in [(balanced_csv,"balanced_train_segments/"), (eval_csv,"eval_segments/"),]:
print("now working on ",read_file,prefix)
#files, y = torch.load(read_file+".pth")
files, y = check_available(read_file, mp3_path)
y = np.packbits(y, axis=-1)
packed_len = y.shape[1]
print(files[0], "classes: ",packed_len, y.dtype)
available_size = len(files)
f = files[0][:-3]+"mp3"
a = np.fromfile(mp3_path+prefix + "/"+f, dtype='uint8')
dt = h5py.vlen_dtype(np.dtype('uint8'))
save_file = prefix.split("/")[0]
with h5py.File(base_dir+ "mp3/" + save_file+"_mp3.hdf", 'w') as hf:
audio_name = hf.create_dataset('audio_name', shape=((available_size,)), dtype='S20')
waveform = hf.create_dataset('mp3', shape=((available_size,)), dtype=dt)
target = hf.create_dataset('target', shape=((available_size, packed_len)), dtype=y.dtype)
for i,file in enumerate(files):
if i%1000==0:
print(f"{i}/{available_size}")
f = file[:-3] + "mp3"
a = np.fromfile(mp3_path + prefix + f, dtype='uint8')
audio_name[i]=f
waveform[i] = a
target[i] = y[i]
print(a.shape)
print("Done!" , prefix)
# %%
print("working on unbalanced...")
all_x,all_y = None, None
for idx in range(41):
print("working on ",idx)
tmp_csv = base_dir+ f"metadata/unbalanced_partial_csvs/unbalanced_train_segments_part{idx:02}.csv"
prefix = f"unbalanced_train_segments/unbalanced_train_segments_part{idx:02}"
x,y = check_available(tmp_csv,mp3_path,prefix=prefix)
x = np.array([f"{idx:02}/"+one for one in x])
y=np.packbits(y, axis=-1)
print("x,y",x.shape, y.shape)
if all_x is None:
all_x = x
all_y = y
else:
all_x = np.concatenate((all_x,x))
all_y = np.concatenate((all_y,y))
print(f"done {idx}! all x,y",all_x.shape, all_y.shape)
print("now working on packing unbalanced")
prefix = "unbalanced_train_segments/unbalanced_train_segments_part"
files = all_x
y = all_y
packed_len = y.shape[1]
print(files[0], "classes: ",packed_len, y.dtype)
available_size = len(files)
f = files[0][:-3]+"mp3"
a = np.fromfile(mp3_path+prefix + f, dtype='uint8')
dt = h5py.vlen_dtype(np.dtype('uint8'))
save_file = prefix.split("/")[0]
with h5py.File(base_dir+ "mp3/" + save_file+"_mp3.hdf", 'w') as hf:
audio_name = hf.create_dataset('audio_name', shape=((available_size,)), dtype='S20')
waveform = hf.create_dataset('mp3', shape=((available_size,)), dtype=dt)
target = hf.create_dataset('target', shape=((available_size, packed_len)), dtype=y.dtype)
for i,file in enumerate(files):
if i%1000==0:
print(f"{i}/{available_size}")
f = file[:-3] + "mp3"
a = np.fromfile(mp3_path + prefix + f, dtype='uint8')
audio_name[i]=f
waveform[i] = a
target[i] = y[i]
print(a.shape)
print("Done!" , prefix)
================================================
FILE: ba3l/__init__.py
================================================
"""Package info"""
__version__ = "0.0.2"
__author__ = "Koutini et al."
__license__ = "Apache-2.0"
__copyright__ = "Copyright (c) 2019-, %s." % __author__
__homepage__ = "https://github.com//"
__docs__ = (
"The researcher friendly pytorch environment. Ba3l= sacred+pytorch-lightening."
)
__author_email__ = "first.last@jku.at"
================================================
FILE: ba3l/experiment.py
================================================
import inspect
from importlib import import_module
from ba3l.ingredients.datasets import Datasets
from ba3l.ingredients.models import Models, Model
#from ba3l.trainer import Trainer
from ba3l.util.sacred_logger import SacredLogger
from sacred import Experiment as Sacred_Experiment, Ingredient
from typing import Sequence, Optional, List
from sacred.commandline_options import CLIOption
from sacred.config import CMD
from sacred.host_info import HostInfoGetter
from sacred.utils import PathType, optional_kwargs_decorator
from pytorch_lightning import loggers as pl_loggers
from ba3l.util.functions import get_default_kwargs_dict
def ingredients_recursive_apply(ing, fn):
fn(ing)
for kid in ing.ingredients:
ingredients_recursive_apply(kid, fn)
def config_recursive_apply(conf, fn):
for k,v in conf.items():
if isinstance(v, dict):
config_recursive_apply(v,fn)
else:
fn(k,v)
def get_loggers(use_tensorboard_logger=False, use_sacred_logger=False):
loggers = []
if use_sacred_logger:
loggers.append( SacredLogger(expr))
if use_tensorboard_logger:
loggers.append(pl_loggers.TensorBoardLogger(sacred_logger.name))
return loggers
class Experiment(Sacred_Experiment):
"""
Main Ba3l Experiment class overrides sacred experiments.
"""
def __init__(
self,
name: Optional[str] = None,
ingredients: Sequence[Ingredient] = (),
datasets: Optional[Ingredient] = None,
models: Optional[Ingredient] = None,
interactive: bool = False,
base_dir: Optional[PathType] = None,
additional_host_info: Optional[List[HostInfoGetter]] = None,
additional_cli_options: Optional[Sequence[CLIOption]] = None,
save_git_info: bool = True,
):
"""
Create a new experiment with the given name and optional ingredients. (from Sacred)
Parameters
----------
name
Optional name of this experiment, defaults to the filename.
(Required in interactive mode)
ingredients : list[sacred.Ingredient], optional
A list of ingredients to be used with this experiment.
interactive
If set to True will allow the experiment to be run in interactive
mode (e.g. IPython or Jupyter notebooks).
However, this mode is discouraged since it won't allow storing the
source-code or reliable reproduction of the runs.
base_dir
Optional full path to the base directory of this experiment. This
will set the scope for automatic source file discovery.
additional_host_info
Optional dictionary containing as keys the names of the pieces of
host info you want to collect, and as
values the functions collecting those pieces of information.
save_git_info:
Optionally save the git commit hash and the git state
(clean or dirty) for all source files. This requires the GitPython
package.
"""
if models is None:
models = Models.get_instance()
self.models = models
if datasets is None:
datasets = Datasets.get_instance()
self.datasets = datasets
if ingredients is None:
ingredients = []
ingredients = list(ingredients) + [models, datasets]
caller_globals = inspect.stack()[1][0].f_globals
self.last_default_configuration_position = 0
super().__init__(
name=name,
ingredients=ingredients,
interactive=interactive,
base_dir=base_dir,
additional_host_info=additional_host_info,
additional_cli_options=additional_cli_options,
save_git_info=save_git_info,
caller_globals=caller_globals
)
def get_run_identifier(self):
return str(self.current_run.db_identifier) \
+ "_" + str(self.current_run._id)
def get_dataloaders(self, filter={}):
results = {}
for ds in self.datasets.get_datasets(filter):
results[ds.name] = ds.get_iterator()
if len(results) == 1:
for k, v in results.items():
return v
return results
def get_train_dataloaders(self):
return self.get_dataloaders(dict(train=True))
def get_val_dataloaders(self):
return self.get_dataloaders(dict(validate=True))
def _create_run(
self,
command_name=None,
config_updates=None,
named_configs=(),
info=None,
meta_info=None,
options=None,
dry_run=False,
):
if self.current_run is not None:
# @todo replace with logger
print("Warning: multiple runs are not yet supported")
run = super()._create_run(
command_name,
config_updates,
named_configs,
info,
meta_info,
options,
dry_run=False,
)
# self.current_run=run
# def update_current_run(ing):
# ing.current_run = run
#
# ingredients_recursive_apply(self, update_current_run)
return run
@optional_kwargs_decorator
def command(
self, function=None, prefix=None, unobserved=False, add_default_args_config=True, static_args={},
**extra_args
):
"""
Decorator to define a new Command.
a command is a function whose parameters are filled automatically by sacred.
The command can be given a prefix, to restrict its configuration space
to a subtree. (see ``capture`` for more information)
A command can be made unobserved (i.e. ignoring all observers) by
passing the unobserved=True keyword argument.
:param add_default_args_config: wether to add the default arguments of the function to the config automatically.
:param function: the function to return a Dataset Object
:param prefix: sacred configuration prefix
:param unobserved: sacred unobserved
:param static_args: static Args to be passed to the function, these arg need not to be serlizable and
are not stored in the config
:param extra_args: explicit arguments to be add to the config, you can these to override the function default
values, for example wraping a config with CMD, then the parameter will be filled with excuting the command
specified by CMD string value. CMD string have special context
:return:
"""
add_default_args_config = (not unobserved) and add_default_args_config
if add_default_args_config:
self.add_default_args_config(function, prefix, extra_args, static_args=static_args)
captured_f = self.capture(function, prefix=prefix, static_args=static_args)
captured_f.unobserved = unobserved
self.commands[function.__name__] = captured_f
return captured_f
def add_default_args_config(self, function, prefix, extra_args={}, static_args={}):
"""
adds the default parameters of a function to the ingredient config at lowest priority!
Default args config is meant remove the need to declare all the configurations manually.
:param f: the function
"""
# @todo get the doc of the params as well
config_candidate = {**get_default_kwargs_dict(function), **extra_args}
# remove "static_args" from config
for k in static_args:
config_candidate.pop(k, None)
# respect the prefix for the added default parameters
if prefix is not None:
for pr in prefix.split('.')[::-1]:
config_candidate={pr: config_candidate}
self.configurations.insert(self.last_default_configuration_position, self._create_config_dict(config_candidate, None))
self.last_default_configuration_position += 1
================================================
FILE: ba3l/ingredients/__init__.py
================================================
================================================
FILE: ba3l/ingredients/datasets.py
================================================
import inspect
import os
from functools import partial
from ba3l.util.functions import get_default_kwargs_dict
from sacred.config import CMD
from .ingredient import Ingredient
from typing import Sequence, Optional, List
from sacred.utils import PathType, optional_kwargs_decorator
from munch import DefaultFactoryMunch, Munch
def raise_(ex):
raise ex
class Dataset(Ingredient):
"""
The class that annotates a Dateset of Ba3l experiment
a Dataset can be
"""
DATASET_STRING_PREFIX = "get_dataset"
ITER_STRING_PREFIX = "get_iterator"
def __init__(
self,
name: str,
ingredients: Sequence[Ingredient] = (),
interactive: bool = False,
base_dir: Optional[PathType] = None,
save_git_info: bool = True,
):
"""
Create a new experiment with the given name and optional ingredients.
Parameters
----------
name
Optional name of this experiment, defaults to the filename.
(Required in interactive mode)
ingredients : list[sacred.Ingredient], optional
A list of ingredients to be used with this experiment.
interactive
If set to True will allow the experiment to be run in interactive
mode (e.g. IPython or Jupyter notebooks).
However, this mode is discouraged since it won't allow storing the
source-code or reliable reproduction of the runs.
base_dir
Optional full path to the base directory of this experiment. This
will set the scope for automatic source file discovery.
additional_host_info
Optional dictionary containing as keys the names of the pieces of
host info you want to collect, and as
values the functions collecting those pieces of information.
save_git_info:
Optionally save the git commit hash and the git state
(clean or dirty) for all source files. This requires the GitPython
package.
"""
caller_globals = inspect.stack()[1][0].f_globals
if name is None:
name = "dataset"
self.name = name.rsplit(".", 1)[-1]
super().__init__(
path=name,
ingredients=ingredients,
interactive=interactive,
base_dir=base_dir,
_caller_globals=caller_globals,
save_git_info=save_git_info,
)
self.get_dataset_command = None
self.get_dataset_iterator_command = None
self.current_run = None
self.get_dataset = lambda: raise_(
NotImplementedError(
"Use dataset.dataset_name.dataset to annotate the "
"get_dataset function!."
)
)
self.get_iter = lambda: raise_(
NotImplementedError(
"Use dataset.dataset_name.iter to annotate the " "get_iter function!."
)
)
@optional_kwargs_decorator
def dataset(
self, function=None, prefix=None, unobserved=False, static_args={}, **extra_args
):
"""
Decorator to define a new Dataset.
The name of the dataset is used to get an instance of the dataset, it will register a command
Datasets are sacred commands.
The command can be given a prefix, to restrict its configuration space
to a subtree. (see ``capture`` for more information)
A command can be made unobserved (i.e. ignoring all observers) by
passing the unobserved=True keyword argument.
:param function: the function to return a Dataset Object
:param prefix: sacred configuration prefix
:param unobserved: sacred unobserved
:param static_args: static Args to be passed to the function, these arg need not to be serlizable and
are not stored in the config
:param extra_args: explicit arguments to be add to the config, you can these to override the function default
values, for example wraping a config with CMD, then the parameter will be filled with excuting the command
specified by CMD string value. CMD string have special context
:return:
"""
self.add_default_args_config(function, prefix, extra_args, static_args=static_args)
captured_f = self.capture(function, prefix=prefix, static_args=static_args)
captured_f.unobserved = unobserved
self.commands[Dataset.DATASET_STRING_PREFIX] = captured_f
self.get_dataset = captured_f
self.add_config(dataset=CMD("get_dataset"))
return captured_f
@optional_kwargs_decorator
def iter(
self, function=None, prefix=None, unobserved=False, static_args={}, **extra_args
):
"""
Decorator to define a new Iterator.
The name of the iterator is used to get an instance of the iterator, it will register a command
iterator are sacred commands.
The command can be given a prefix, to restrict its configuration space
to a subtree. (see ``capture`` for more information)
A command can be made unobserved (i.e. ignoring all observers) by
passing the unobserved=True keyword argument.
:param function: the function to return a Dataset Object
:param prefix: sacred configuration prefix
:param unobserved: sacred unobserved
:param static_args: static Args to be passed to the function, these arg need not to be serlizable and
are not stored in the config
:param extra_args: explicit arguments to be add to the config, you can these to override the function default
values, for example wraping a config with CMD, then the parameter will be filled with excuting the command
specified by CMD string value. CMD string have special context
"""
self.add_default_args_config(function, prefix, extra_args, static_args=static_args)
captured_f = self.capture(function, prefix=prefix, static_args=static_args)
captured_f.unobserved = unobserved
self.commands[Dataset.ITER_STRING_PREFIX] = captured_f
self.get_iter = captured_f
return captured_f
# def get_dataset(self):
# assert self.current_run is not None, "Can only be called during a run."
# return self.commands[Datasets.DATASET_STRING_PREFIX + name]()
# # return self.current_run.get_command_function(
# # self.path + "." + Datasets.DATASET_STRING_PREFIX + name)()
# #
def __getattr__(self, k):
if k == "iterator":
return self.__getattribute__("iter")
if k == "get_iterator":
return self.__getattribute__("get_iter")
super().__getattribute__(k)
# @todo maybe run commands from here after running
class Datasets(Ingredient, Munch):
"""
The class that encapsulates all the datasets in an experiment
"""
__instance = None
@classmethod
def get_instance(cls):
if Datasets.__instance is None:
Datasets.__instance = Datasets()
return Datasets.__instance
def __init__(
self,
name: Optional[str] = None,
ingredients: Sequence[Ingredient] = (),
interactive: bool = False,
base_dir: Optional[PathType] = None,
save_git_info: bool = True,
):
"""
Create a new experiment with the given name and optional ingredients.
Parameters
----------
name
Optional name of this experiment, defaults to the filename.
(Required in interactive mode)
ingredients : list[sacred.Ingredient], optional
A list of ingredients to be used with this experiment.
interactive
If set to True will allow the experiment to be run in interactive
mode (e.g. IPython or Jupyter notebooks).
However, this mode is discouraged since it won't allow storing the
source-code or reliable reproduction of the runs.
base_dir
Optional full path to the base directory of this experiment. This
will set the scope for automatic source file discovery.
additional_host_info
Optional dictionary containing as keys the names of the pieces of
host info you want to collect, and as
values the functions collecting those pieces of information.
save_git_info:
Optionally save the git commit hash and the git state
(clean or dirty) for all source files. This requires the GitPython
package.
"""
caller_globals = inspect.stack()[1][0].f_globals
if name is None:
name = "datasets"
Ingredient.__init__(
self,
path=name,
ingredients=ingredients,
interactive=interactive,
base_dir=base_dir,
_caller_globals=caller_globals,
save_git_info=save_git_info,
)
self.get_datasets_list_command = None
self.get_dataset_command = None
self.get_dataset_iterator_command = None
self.current_run = None
self.get_dataset = None
# self.command(get_dataset_iterator_command, unobserved=True)
def __getattr__(self, k):
""" Gets key if it exists, otherwise returns the default value."""
try:
return Munch.__getattr__(self, k)
except AttributeError:
return self.__getitem__(k)
def __setattr__(self, k, v):
try:
# Throws exception if not in prototype chain
object.__getattribute__(self, k)
except AttributeError:
try:
self[k] = v
except:
raise AttributeError(k)
else:
object.__setattr__(self, k, v)
def __getitem__(self, k):
""" Gets key if it exists, otherwise returns the default value."""
try:
return Munch.__getitem__(self, k)
except KeyError:
self[k] = Dataset(
self.path + "." + k,
base_dir=self.base_dir,
save_git_info=self.save_git_info,
)
assert self
self.ingredients.append(self[k])
return self[k]
def __hash__(self):
return Ingredient.__hash__(self)
def get_datasets(self, config_conditions={}, return_datasets_names=False):
""" Return all the datasets whose configuration matches config_conditions."""
results = []
for dataset in self.ingredients:
all_ok = True
for cond_k, cond_v in config_conditions.items():
if (
self.current_run.get_config_path_value(dataset.path + "." + cond_k)
!= cond_v
):
all_ok = False
break
if all_ok:
if return_datasets_names:
results.append((dataset))
results.append(dataset)
return results
================================================
FILE: ba3l/ingredients/ingredient.py
================================================
import inspect
import os
from functools import partial
from ba3l.util.functions import get_default_kwargs_dict
from sacred import Ingredient as sacred_Ingredient
from typing import Sequence, Optional, List
from sacred.utils import PathType, optional_kwargs_decorator
from munch import DefaultFactoryMunch, Munch
def raise_(ex):
raise ex
class Ingredient(sacred_Ingredient):
"""
The class that annotates a Dateset of Ba3l experiment
a Dataset can be
"""
def __init__(
self,
path: str,
ingredients: Sequence[sacred_Ingredient] = (),
interactive: bool = False,
_caller_globals: Optional[dict] = None,
base_dir: Optional[PathType] = None,
save_git_info: bool = True,
):
"""
The Base Ingredient of all Ba3l ingredients
Parameters
----------
path
Optional name of this experiment, defaults to the filename.
(Required in interactive mode)
ingredients : list[sacred.Ingredient], optional
A list of ingredients to be used with this experiment.
interactive
If set to True will allow the experiment to be run in interactive
mode (e.g. IPython or Jupyter notebooks).
However, this mode is discouraged since it won't allow storing the
source-code or reliable reproduction of the runs.
base_dir
Optional full path to the base directory of this experiment. This
will set the scope for automatic source file discovery.
additional_host_info
Optional dictionary containing as keys the names of the pieces of
host info you want to collect, and as
values the functions collecting those pieces of information.
save_git_info:
Optionally save the git commit hash and the git state
(clean or dirty) for all source files. This requires the GitPython
package.
"""
_caller_globals = _caller_globals or inspect.stack()[1][0].f_globals
if path is None:
path = "Ingredient"
super().__init__(
path=path,
ingredients=ingredients,
interactive=interactive,
base_dir=base_dir,
_caller_globals=_caller_globals,
save_git_info=save_git_info,
)
self.current_run = None
self.last_default_configuration_position = 0
def add_default_args_config(self, function, prefix, extra_args={}, static_args={}):
"""
adds the default parameters of a function to the ingredient config at lowest priority!
Default args config is meant remove the need to declare all the configurations manually.
:param f: the function
"""
# @todo get the doc of the params as well
config_candidate = {**get_default_kwargs_dict(function), **extra_args}
# remove "static_args" from config
for k in static_args:
config_candidate.pop(k, None)
if prefix is not None:
for pr in prefix.split('.')[::-1]:
config_candidate={pr: config_candidate}
self.configurations.insert(self.last_default_configuration_position, self._create_config_dict(config_candidate, None))
self.last_default_configuration_position += 1
@optional_kwargs_decorator
def command(
self, function=None, prefix=None, unobserved=False, add_default_args_config=True, static_args={}, **extra_args
):
"""
Decorator to define a new Command.
a command is a function whose parameters are filled automatically by sacred.
The command can be given a prefix, to restrict its configuration space
to a subtree. (see ``capture`` for more information)
A command can be made unobserved (i.e. ignoring all observers) by
passing the unobserved=True keyword argument.
:param function: the function to return a Dataset Object
:param prefix: sacred configuration prefix
:param unobserved: sacred unobserved
:param static_args: static Args to be passed to the function, these arg need not to be serlizable and
are not stored in the config
:param extra_args: explicit arguments to be add to the config, you can these to override the function default
values, for example wraping a config with CMD, then the parameter will be filled with excuting the command
specified by CMD string value. CMD string have special context
:return:
"""
if add_default_args_config:
self.add_default_args_config(function, prefix, extra_args, static_args=static_args)
captured_f = self.capture(function, prefix=prefix, static_args=static_args)
captured_f.unobserved = unobserved
self.commands[function.__name__] = captured_f
return captured_f
================================================
FILE: ba3l/ingredients/models.py
================================================
import inspect
import os
from .ingredient import Ingredient
from typing import Sequence, Optional, List
from sacred.utils import PathType
import inspect
import os
from functools import partial
from ba3l.util.functions import get_default_kwargs_dict
from sacred.config import CMD
from .ingredient import Ingredient
from typing import Sequence, Optional, List
from sacred.utils import PathType, optional_kwargs_decorator
from munch import DefaultFactoryMunch, Munch
class Models(Ingredient):
"""
The class that annotates the models of Ba3l experiment
"""
__instance = None
@classmethod
def get_instance(cls):
if Models.__instance is None:
Models.__instance = Models()
return Models.__instance
def __init__(
self,
name: Optional[str] = None,
ingredients: Sequence[Ingredient] = (),
interactive: bool = False,
base_dir: Optional[PathType] = None,
save_git_info: bool = True,
):
"""
Create a new experiment with the given name and optional ingredients.
Parameters
----------
name
Optional name of this experiment, defaults to the filename.
(Required in interactive mode)
ingredients : list[sacred.Ingredient], optional
A list of ingredients to be used with this experiment.
interactive
If set to True will allow the experiment to be run in interactive
mode (e.g. IPython or Jupyter notebooks).
However, this mode is discouraged since it won't allow storing the
source-code or reliable reproduction of the runs.
base_dir
Optional full path to the base directory of this experiment. This
will set the scope for automatic source file discovery.
additional_host_info
Optional dictionary containing as keys the names of the pieces of
host info you want to collect, and as
values the functions collecting those pieces of information.
save_git_info:
Optionally save the git commit hash and the git state
(clean or dirty) for all source files. This requires the GitPython
package.
"""
caller_globals = inspect.stack()[1][0].f_globals
if name is None:
name = "models"
super().__init__(
path=name,
ingredients=ingredients,
interactive=interactive,
base_dir=base_dir,
_caller_globals=caller_globals,
save_git_info=save_git_info,
)
self.get_models_command = None
self.current_run = None
# self.command(print_config, unobserved=True)
def raise_(ex):
raise ex
class Model(Ingredient):
"""
The class that annotates a Dateset of Ba3l experiment
a Dataset can be
"""
MODEL_STRING_PREFIX = "get_instance"
def __init__(
self,
name: str,
ingredients: Sequence[Ingredient] = (),
interactive: bool = False,
base_dir: Optional[PathType] = None,
save_git_info: bool = True,
):
"""
Create a new experiment with the given name and optional ingredients.
Parameters
----------
name
Optional name of this experiment, defaults to the filename.
(Required in interactive mode)
ingredients : list[sacred.Ingredient], optional
A list of ingredients to be used with this experiment.
interactive
If set to True will allow the experiment to be run in interactive
mode (e.g. IPython or Jupyter notebooks).
However, this mode is discouraged since it won't allow storing the
source-code or reliable reproduction of the runs.
base_dir
Optional full path to the base directory of this experiment. This
will set the scope for automatic source file discovery.
additional_host_info
Optional dictionary containing as keys the names of the pieces of
host info you want to collect, and as
values the functions collecting those pieces of information.
save_git_info:
Optionally save the git commit hash and the git state
(clean or dirty) for all source files. This requires the GitPython
package.
"""
caller_globals = inspect.stack()[1][0].f_globals
if name is None:
name = "model"
self.name = name.rsplit(".", 1)[-1]
super().__init__(
path=name,
ingredients=ingredients,
interactive=interactive,
base_dir=base_dir,
_caller_globals=caller_globals,
save_git_info=save_git_info,
)
self.get_instance_command = None
self.current_run = None
self.get_instance = lambda: raise_(
NotImplementedError(
"Use dataset.dataset_name.dataset to annotate the "
"get_dataset function!."
)
)
@optional_kwargs_decorator
def instance(
self, function=None, prefix=None, unobserved=False, static_args={}, **extra_args
):
"""
Decorator to define a new model.
The name of the model is used to get an instance of the model, it will register a command
The command can be given a prefix, to restrict its configuration space
to a subtree. (see ``capture`` for more information)
A command can be made unobserved (i.e. ignoring all observers) by
passing the unobserved=True keyword argument.
:param function: the function to return a Dataset Object
:param prefix: sacred configuration prefix
:param unobserved: sacred unobserved
:param static_args: static Args to be passed to the function, these arg need not to be serlizable and
are not stored in the config
:param extra_args: explicit arguments to be add to the config, you can these to override the function default
values, for example wraping a config with CMD, then the parameter will be filled with excuting the command
specified by CMD string value. CMD string have special context
:return:
"""
self.add_default_args_config(function, prefix, extra_args, static_args=static_args)
captured_f = self.capture(function, prefix=prefix, static_args=static_args)
captured_f.unobserved = unobserved
self.commands[Model.MODEL_STRING_PREFIX] = captured_f
self.get_instance = captured_f
self.add_config(get_instance=CMD("get_instance"))
return captured_f
def __getattr__(self, k):
if k == "get_instance":
return self.__getattribute__("get_instance")
super().__getattribute__(k)
# @todo maybe run commands from here after running
================================================
FILE: ba3l/ingredients/trainer.py
================================================
import inspect
import os
from ba3l.util.functions import get_default_kwargs_dict
from .datasets import Datasets, raise_
from .models import Models
from sacred import Ingredient
from typing import Sequence, Optional, List
from sacred.utils import PathType, optional_kwargs_decorator
class Trainer(Ingredient):
"""
The class that annotates the main Trainer of Ba3l experiment
"""
TRAINER_STRING_PREFIX = "get_trainer"
================================================
FILE: ba3l/module.py
================================================
import pytorch_lightning as pl
import warnings
from abc import ABC
import torch.distributed as dist
from munch import DefaultMunch
try:
# loading for pyTorch 1.3
from torch.utils.data import IterableDataset
except ImportError:
# loading for pyTorch 1.1
import torch
warnings.warn(
"Your version of pyTorch %s does not support `IterableDataset`,"
" please upgrade to 1.2+" % torch.__version__,
ImportWarning,
)
EXIST_ITER_DATASET = False
else:
EXIST_ITER_DATASET = True
try:
from apex import amp
APEX_AVAILABLE = True
except ImportError:
APEX_AVAILABLE = False
class Ba3lModule(pl.LightningModule):
def __init__(self, experiment):
super(Ba3lModule, self).__init__()
self.experiment = experiment
self.run = experiment.current_run
self.config = DefaultMunch.fromDict(experiment.current_run.config)
for key,model in experiment.current_run.config['models'].items():
setattr(self, key, experiment.current_run.get_command_function("models."+key+"."+model['instance_cmd'])())
self.save_hyperparameters(self.config)
================================================
FILE: ba3l/plutils/__init__.py
================================================
================================================
FILE: ba3l/plutils/lr_monitor.py
================================================
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
r"""
Learning Rate Monitor
=====================
Monitor and logs learning rate for lr schedulers during training.
"""
from typing import Dict, List, Optional
from pytorch_lightning.callbacks.base import Callback
from pytorch_lightning.utilities import rank_zero_warn
from pytorch_lightning.utilities.exceptions import MisconfigurationException
class LearningRateMonitor(Callback):
r"""
Automatically monitor and logs learning rate for learning rate schedulers during training.
Args:
logging_interval: set to ``'epoch'`` or ``'step'`` to log ``lr`` of all optimizers
at the same interval, set to ``None`` to log at individual interval
according to the ``interval`` key of each scheduler. Defaults to ``None``.
log_momentum: option to also log the momentum values of the optimizer, if the optimizer
has the ``momentum`` or ``betas`` attribute. Defaults to ``False``.
Raises:
MisconfigurationException:
If ``logging_interval`` is none of ``"step"``, ``"epoch"``, or ``None``.
Example::
>>> from pytorch_lightning import Trainer
>>> from pytorch_lightning.callbacks import LearningRateMonitor
>>> lr_monitor = LearningRateMonitor(logging_interval='step')
>>> trainer = Trainer(callbacks=[lr_monitor])
Logging names are automatically determined based on optimizer class name.
In case of multiple optimizers of same type, they will be named ``Adam``,
``Adam-1`` etc. If a optimizer has multiple parameter groups they will
be named ``Adam/pg1``, ``Adam/pg2`` etc. To control naming, pass in a
``name`` keyword in the construction of the learning rate schdulers
Example::
def configure_optimizer(self):
optimizer = torch.optim.Adam(...)
lr_scheduler = {
'scheduler': torch.optim.lr_scheduler.LambdaLR(optimizer, ...)
'name': 'my_logging_name'
}
return [optimizer], [lr_scheduler]
"""
def __init__(self, logging_interval: Optional[str] = None, log_momentum: bool = False):
if logging_interval not in (None, 'step', 'epoch'):
raise MisconfigurationException('logging_interval should be `step` or `epoch` or `None`.')
self.logging_interval = logging_interval
self.log_momentum = log_momentum
self.lrs = None
self.lr_sch_names = []
def on_train_start(self, trainer, *args, **kwargs):
"""
Called before training, determines unique names for all lr
schedulers in the case of multiple of the same type or in
the case of multiple parameter groups
Raises:
MisconfigurationException:
If ``Trainer`` has no ``logger``.
"""
if not trainer.logger:
raise MisconfigurationException(
'Cannot use `LearningRateMonitor` callback with `Trainer` that has no logger.'
)
if not trainer.lr_schedulers:
rank_zero_warn(
'You are using `LearningRateMonitor` callback with models that'
' have no learning rate schedulers. Please see documentation'
' for `configure_optimizers` method.', RuntimeWarning
)
if self.log_momentum:
def _check_no_key(key):
return any(key not in sch['scheduler'].optimizer.defaults for sch in trainer.lr_schedulers)
if _check_no_key('momentum') and _check_no_key('betas'):
rank_zero_warn(
"You have set log_momentum=True, but some optimizers do not"
" have momentum. This will log a value 0 for the momentum.", RuntimeWarning
)
# Find names for schedulers
names = self._find_names(trainer.lr_schedulers)
# Initialize for storing values
self.lrs = {name: [] for name in names}
self.last_momentum_values = {name + "-momentum": None for name in names}
def on_train_batch_start(self, trainer, *args, **kwargs):
if not self._should_log(trainer):
return
if self.logging_interval != 'epoch':
interval = 'step' if self.logging_interval is None else 'any'
latest_stat = self._extract_stats(trainer, interval)
if latest_stat:
trainer.logger.log_metrics(latest_stat, step=trainer.global_step)
def on_train_epoch_start(self, trainer, *args, **kwargs):
if self.logging_interval != 'step':
interval = 'epoch' if self.logging_interval is None else 'any'
latest_stat = self._extract_stats(trainer, interval)
if latest_stat:
trainer.logger.log_metrics(latest_stat, step=trainer.current_epoch)
def _extract_stats(self, trainer, interval: str) -> Dict[str, float]:
latest_stat = {}
for name, scheduler in zip(self.lr_sch_names, trainer.lr_schedulers):
if scheduler['interval'] == interval or interval == 'any':
opt = scheduler['scheduler'].optimizer
param_groups = opt.param_groups
use_betas = 'betas' in opt.defaults
for i, pg in enumerate(param_groups):
suffix = f'/pg{i + 1}' if len(param_groups) > 1 else ''
lr = self._extract_lr(param_group=pg, name=f'{name}{suffix}')
latest_stat.update(lr)
momentum = self._extract_momentum(
param_group=pg, name=f'{name}-momentum{suffix}', use_betas=use_betas
)
latest_stat.update(momentum)
return latest_stat
def _extract_lr(self, param_group, name: str) -> Dict[str, float]:
lr = param_group.get('lr')
self.lrs[name].append(lr)
return {name: lr}
def _extract_momentum(self, param_group, name: str, use_betas: bool) -> Dict[str, float]:
if not self.log_momentum:
return {}
momentum = param_group.get('betas')[0] if use_betas else param_group.get('momentum', 0)
self.last_momentum_values[name] = momentum
return {name: momentum}
def _find_names(self, lr_schedulers) -> List[str]:
# Create uniqe names in the case we have multiple of the same learning
# rate schduler + multiple parameter groups
names = []
for scheduler in lr_schedulers:
sch = scheduler['scheduler']
if scheduler['name'] is not None:
name = scheduler['name']
else:
opt_name = 'lr-' + sch.optimizer.__class__.__name__
i, name = 1, opt_name
# Multiple schduler of the same type
while True:
if name not in names:
break
i, name = i + 1, f'{opt_name}-{i}'
# Multiple param groups for the same schduler
param_groups = sch.optimizer.param_groups
if len(param_groups) != 1:
for i, pg in enumerate(param_groups):
temp = f'{name}/pg{i + 1}'
names.append(temp)
else:
names.append(name)
self.lr_sch_names.append(name)
return names
@staticmethod
def _should_log(trainer) -> bool:
should_log = ((trainer.global_step + 1) % trainer.log_every_n_steps == 0 or trainer.should_stop)
return should_log
================================================
FILE: ba3l/plutils/progress_bar.py
================================================
import importlib
import io
import os
import sys
# check if ipywidgets is installed before importing tqdm.auto
# to ensure it won't fail and a progress bar is displayed
from typing import Optional, Union
from pytorch_lightning.callbacks import Callback,ProgressBarBase , ProgressBar as PlProgressBar
from pytorch_lightning.callbacks.progress import tqdm
class ProgressBar(PlProgressBar):
r"""
This is the default progress bar used by Lightning. It prints to `stdout` using the
:mod:`tqdm` package and shows up to four different bars:
- **sanity check progress:** the progress during the sanity check run
- **main progress:** shows training + validation progress combined. It also accounts for
multiple validation runs during training when
:paramref:`~pytorch_lightning.trainer.trainer.Trainer.val_check_interval` is used.
- **validation progress:** only visible during validation;
shows total progress over all validation datasets.
- **test progress:** only active when testing; shows total progress over all test datasets.
For infinite datasets, the progress bar never ends.
If you want to customize the default ``tqdm`` progress bars used by Lightning, you can override
specific methods of the callback class and pass your custom implementation to the
:class:`~pytorch_lightning.trainer.trainer.Trainer`:
Example::
class LitProgressBar(ProgressBar):
def init_validation_tqdm(self):
bar = super().init_validation_tqdm()
bar.set_description('running validation ...')
return bar
bar = LitProgressBar()
trainer = Trainer(callbacks=[bar])
Args:
refresh_rate:
Determines at which rate (in number of batches) the progress bars get updated.
Set it to ``0`` to disable the display. By default, the
:class:`~pytorch_lightning.trainer.trainer.Trainer` uses this implementation of the progress
bar and sets the refresh rate to the value provided to the
:paramref:`~pytorch_lightning.trainer.trainer.Trainer.progress_bar_refresh_rate` argument in the
:class:`~pytorch_lightning.trainer.trainer.Trainer`.
process_position:
Set this to a value greater than ``0`` to offset the progress bars by this many lines.
This is useful when you have progress bars defined elsewhere and want to show all of them
together. This corresponds to
:paramref:`~pytorch_lightning.trainer.trainer.Trainer.process_position` in the
:class:`~pytorch_lightning.trainer.trainer.Trainer`.
"""
def init_validation_tqdm(self) -> tqdm:
""" Override this to customize the tqdm bar for validation. """
# The main progress bar doesn't exist in `trainer.validate()`
has_main_bar = self.main_progress_bar is not None
has_main_bar = False
bar = tqdm(
desc='Validating',
position=(2 * self.process_position + has_main_bar),
disable=self.is_disabled,
leave=False,
dynamic_ncols=True,
file=sys.stdout
)
return bar
def on_epoch_start(self, trainer, pl_module):
ProgressBarBase.on_epoch_start(self, trainer, pl_module)
self.main_progress_bar = self.init_train_tqdm()
total_train_batches = self.total_train_batches
total_val_batches = self.total_val_batches
if total_train_batches != float('inf'):
# val can be checked multiple times per epoch
val_checks_per_epoch = total_train_batches // trainer.val_check_batch
total_val_batches = total_val_batches * val_checks_per_epoch
total_val_batches = 0
total_batches = total_train_batches + total_val_batches
reset(self.main_progress_bar, total_batches)
self.main_progress_bar.set_description(f'Epoch {trainer.current_epoch}')
def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx):
ProgressBarBase.on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx)
if self._should_update(self.train_batch_idx, self.total_train_batches):
self._update_bar(self.main_progress_bar)
self.main_progress_bar.set_postfix(trainer.progress_bar_dict)
def on_validation_start(self, trainer, pl_module):
ProgressBarBase.on_validation_start(self, trainer, pl_module)
if trainer.sanity_checking:
reset(self.val_progress_bar, sum(trainer.num_sanity_val_batches))
else:
if self.main_progress_bar is not None:
self.main_progress_bar.close()
self.val_progress_bar = self.init_validation_tqdm()
reset(self.val_progress_bar, self.total_val_batches)
def on_validation_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx):
ProgressBarBase.on_validation_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx)
if self._should_update(self.val_batch_idx, self.total_val_batches):
self._update_bar(self.val_progress_bar)
#self._update_bar(self.main_progress_bar)
def on_validation_end(self, trainer, pl_module):
ProgressBarBase.on_validation_end(self, trainer, pl_module)
if self.main_progress_bar is not None:
self.main_progress_bar.set_postfix(trainer.progress_bar_dict)
self.val_progress_bar.close()
def convert_inf(x: Optional[Union[int, float]]) -> Optional[Union[int, float]]:
""" The tqdm doesn't support inf values. We have to convert it to None. """
if x == float('inf'):
return None
return x
def reset(bar: tqdm, total: Optional[int] = None) -> None:
""" Resets the tqdm bar to 0 progress with a new total, unless it is disabled. """
if not bar.disable:
bar.reset(total=convert_inf(total))
================================================
FILE: ba3l/util/__init__.py
================================================
================================================
FILE: ba3l/util/functions.py
================================================
import inspect
from collections import OrderedDict
def get_default_kwargs_dict(f):
sig = inspect.signature(f)
return OrderedDict(
[
(p.name, p.default)
for p in sig.parameters.values()
if p.default != inspect._empty
]
)
================================================
FILE: ba3l/util/sacred_logger.py
================================================
from pytorch_lightning.utilities import rank_zero_only
try:
from pytorch_lightning.loggers import Logger as LightningLoggerBase
from pytorch_lightning.loggers.logger import rank_zero_experiment
except ImportError:
from pytorch_lightning.loggers import LightningLoggerBase
from pytorch_lightning.loggers.base import rank_zero_experiment
from warnings import warn
from logging import getLogger
try:
import sacred
except ImportError:
raise ImportError("Missing sacred package. Run `pip install sacred`")
logger = getLogger(__name__)
class SacredLogger(LightningLoggerBase):
def __init__(self, sacred_experiment):
"""Initialize a sacred logger.
:param sacred.experiment.Experiment sacred_experiment: Required. Experiment object with desired observers
already appended.
source: https://github.com/expectopatronum/pytorch-lightning/blob/9fcb238ec03e3f0b0378fd058119f1563a11650c/pytorch_lightning/logging/sacred.py
"""
super().__init__()
self.sacred_experiment = sacred_experiment
self.experiment_name = sacred_experiment.path
self._run_id = None
warn('SacredLogger is deprecated', DeprecationWarning, stacklevel=2)
@property
def experiment(self):
return self.sacred_experiment
@property
def run_id(self):
if self._run_id is not None:
return self._run_id
self._run_id = self.sacred_experiment.get_run_identifier()
return self._run_id
@rank_zero_only
def log_hyperparams(self, params):
# probably not needed bc. it is dealt with by sacred
pass
@rank_zero_only
def log_metrics(self, metrics, step=None):
for k, v in metrics.items():
if isinstance(v, str):
logger.warning(f"Discarding metric with string value {k}={v}")
continue
self.experiment.log_scalar(k, v, step)
@property
def name(self):
return self.experiment_name
@property
def version(self):
return self.run_id
@rank_zero_only
def save(self):
# Optional. Any code necessary to save logger data goes here
# If you implement this, remember to call `super().save()`
# at the start of the method (important for aggregation of metrics)
super().save()
@rank_zero_only
def finalize(self, status):
# Optional. Any code that needs to be run after training
# finishes goes here
pass
================================================
FILE: config_updates.py
================================================
from sacred.config_helpers import DynamicIngredient, CMD
def add_configs(ex):
'''
This functions add generic configuration for the experiments, such as mix-up, architectures, etc...
@param ex: Ba3l Experiment
@return:
'''
@ex.named_config
def nomixup():
'Don\'t apply mix-up (spectrogram level).'
use_mixup = False
mixup_alpha = 0.3
@ex.named_config
def mixup():
' Apply mix-up (spectrogram level).'
use_mixup = True
mixup_alpha = 0.3
@ex.named_config
def mini_train():
'limit training/validation to 5 batches for debbuging.'
trainer = dict(limit_train_batches=5, limit_val_batches=5)
@ex.named_config
def passt():
'use PaSST model'
models = {
"net": DynamicIngredient("models.passt.model_ing")
}
@ex.named_config
def passt_s_20sec():
'use PaSST model pretrained on Audioset (with SWA) ap=476; time encodings for up to 20 seconds'
# python ex_audioset.py evaluate_only with passt_s_ap476
models = {
"net": DynamicIngredient("models.passt.model_ing", arch="passt_s_f128_20sec_p16_s10_ap474", fstride=10,
tstride=10, input_tdim=2000)
}
basedataset = dict(clip_length=20)
@ex.named_config
def passt_s_30sec():
'use PaSST model pretrained on Audioset (with SWA) ap=476; time encodings for up to 30 seconds'
# python ex_audioset.py evaluate_only with passt_s_ap476
models = {
"net": DynamicIngredient("models.passt.model_ing", arch="passt_s_f128_30sec_p16_s10_ap473", fstride=10,
tstride=10, input_tdim=3000)
}
basedataset = dict(clip_length=20)
@ex.named_config
def passt_s_ap476():
'use PaSST model pretrained on Audioset (with SWA) ap=476'
# python ex_audioset.py evaluate_only with passt_s_ap476
models = {
"net": DynamicIngredient("models.passt.model_ing", arch="passt_s_swa_p16_128_ap476", fstride=10,
tstride=10)
}
@ex.named_config
def passt_s_ap4763():
'use PaSST model pretrained on Audioset (with SWA) ap=4763'
# test with: python ex_audioset.py evaluate_only with passt_s_ap4763
models = {
"net": DynamicIngredient("models.passt.model_ing", arch="passt_s_swa_p16_128_ap4763", fstride=10,
tstride=10)
}
@ex.named_config
def passt_s_ap472():
'use PaSST model pretrained on Audioset (no SWA) ap=472'
# test with: python ex_audioset.py evaluate_only with passt_s_ap472
models = {
"net": DynamicIngredient("models.passt.model_ing", arch="passt_s_p16_128_ap472", fstride=10,
tstride=10)
}
@ex.named_config
def passt_s_p16_s16_128_ap468():
'use PaSST model pretrained on Audioset (no SWA) ap=468 NO overlap'
# test with: python ex_audioset.py evaluate_only with passt_s_p16_s16_128_ap468
models = {
"net": DynamicIngredient("models.passt.model_ing", arch="passt_s_p16_s16_128_ap468", fstride=16,
tstride=16)
}
@ex.named_config
def passt_s_swa_p16_s16_128_ap473():
'use PaSST model pretrained on Audioset (SWA) ap=473 NO overlap'
# test with: python ex_audioset.py evaluate_only with passt_s_swa_p16_s16_128_ap473
models = {
"net": DynamicIngredient("models.passt.model_ing", arch="passt_s_swa_p16_s16_128_ap473", fstride=16,
tstride=16)
}
@ex.named_config
def passt_s_swa_p16_s14_128_ap471():
'use PaSST model pretrained on Audioset stride=14 (SWA) ap=471 '
# test with: python ex_audioset.py evaluate_only with passt_s_swa_p16_s14_128_ap471
models = {
"net": DynamicIngredient("models.passt.model_ing", arch="passt_s_swa_p16_s14_128_ap471", fstride=14,
tstride=14)
}
@ex.named_config
def passt_s_p16_s14_128_ap469():
'use PaSST model pretrained on Audioset stride=14 (No SWA) ap=469 '
# test with: python ex_audioset.py evaluate_only with passt_s_p16_s14_128_ap469
models = {
"net": DynamicIngredient("models.passt.model_ing", arch="passt_s_p16_s14_128_ap469", fstride=14,
tstride=14)
}
@ex.named_config
def passt_s_swa_p16_s12_128_ap473():
'use PaSST model pretrained on Audioset stride=12 (SWA) ap=473 '
# test with: python ex_audioset.py evaluate_only with passt_s_swa_p16_s12_128_ap473
models = {
"net": DynamicIngredient("models.passt.model_ing", arch="passt_s_swa_p16_s12_128_ap473", fstride=12,
tstride=12)
}
@ex.named_config
def passt_s_p16_s12_128_ap470():
'use PaSST model pretrained on Audioset stride=12 (No SWA) ap=4670 '
# test with: python ex_audioset.py evaluate_only with passt_s_p16_s12_128_ap470
models = {
"net": DynamicIngredient("models.passt.model_ing", arch="passt_s_p16_s12_128_ap470", fstride=12,
tstride=12)
}
@ex.named_config
def ensemble_s10():
'use ensemble of PaSST models pretrained on Audioset with S10 mAP=.4864'
# test with: python ex_audioset.py evaluate_only with trainer.precision=16 ensemble_s10
models = {
"net": DynamicIngredient("models.passt.model_ing", arch="ensemble_s10", fstride=None,
tstride=None, instance_cmd="get_ensemble_model",
# don't call get_model but rather get_ensemble_model
arch_list=[
("passt_s_swa_p16_128_ap476", 10, 10),
("passt_s_swa_p16_128_ap4761", 10, 10),
("passt_s_p16_128_ap472", 10, 10),
]
)
}
@ex.named_config
def ensemble_many():
'use ensemble of PaSST models pretrained on Audioset with different strides mAP=.4956'
# test with: python ex_audioset.py evaluate_only with trainer.precision=16 ensemble_many
models = {
"net": DynamicIngredient("models.passt.model_ing", arch="ensemble_many", fstride=None,
tstride=None, instance_cmd="get_ensemble_model",
# don't call get_model but rather get_ensemble_model
arch_list=[
("passt_s_swa_p16_128_ap476", 10, 10),
("passt_s_swa_p16_128_ap4761", 10, 10),
("passt_s_p16_128_ap472", 10, 10),
("passt_s_p16_s12_128_ap470", 12, 12),
("passt_s_swa_p16_s12_128_ap473", 12, 12),
("passt_s_p16_s14_128_ap469", 14, 14),
("passt_s_swa_p16_s14_128_ap471", 14, 14),
("passt_s_swa_p16_s16_128_ap473", 16, 16),
("passt_s_p16_s16_128_ap468", 16, 16),
]
)
}
@ex.named_config
def ensemble_4():
'use ensemble of PaSST models pretrained on Audioset with different strides mAP=.4926'
# test with: python ex_audioset.py evaluate_only with trainer.precision=16 ensemble_many
models = {
"net": DynamicIngredient("models.passt.model_ing", arch="ensemble_many", fstride=None,
tstride=None, instance_cmd="get_ensemble_model",
# don't call get_model but rather get_ensemble_model
arch_list=[
("passt_s_swa_p16_128_ap476", 10, 10),
("passt_s_swa_p16_s12_128_ap473", 12, 12),
("passt_s_swa_p16_s14_128_ap471", 14, 14),
("passt_s_swa_p16_s16_128_ap473", 16, 16),
]
)
}
@ex.named_config
def ensemble_5():
'use ensemble of PaSST models pretrained on Audioset with different strides mAP=.49459'
# test with: python ex_audioset.py evaluate_only with trainer.precision=16 ensemble_many
models = {
"net": DynamicIngredient("models.passt.model_ing", arch="ensemble_many", fstride=None,
tstride=None, instance_cmd="get_ensemble_model",
# don't call get_model but rather get_ensemble_model
arch_list=[
("passt_s_swa_p16_128_ap476", 10, 10),
("passt_s_swa_p16_128_ap4761", 10, 10),
("passt_s_swa_p16_s12_128_ap473", 12, 12),
("passt_s_swa_p16_s14_128_ap471", 14, 14),
("passt_s_swa_p16_s16_128_ap473", 16, 16),
]
)
}
@ex.named_config
def ensemble_s16_14():
'use ensemble of two PaSST models pretrained on Audioset with stride 16 and 14 mAP=.48579'
# test with: python ex_audioset.py evaluate_only with trainer.precision=16 ensemble_s16_14
models = {
"net": DynamicIngredient("models.passt.model_ing", arch="ensemble_s16", fstride=None,
tstride=None, instance_cmd="get_ensemble_model",
# don't call get_model but rather get_ensemble_model
arch_list=[
("passt_s_swa_p16_s14_128_ap471", 14, 14),
("passt_s_swa_p16_s16_128_ap473", 16, 16),
]
)
}
@ex.named_config
def dynamic_roll():
# dynamically roll the spectrograms/waveforms
# updates the dataset config
basedataset = dict(roll=True, roll_conf=dict(axis=1, shift_range=10000)
)
# extra commands
@ex.command
def test_loaders_train_speed():
# test how fast data is being loaded from the data loaders.
itr = ex.datasets.training.get_iter()
import time
start = time.time()
print("hello")
for i, b in enumerate(itr):
if i % 20 == 0:
print(f"{i}/{len(itr)}", end="\r")
end = time.time()
print("totoal time:", end - start)
start = time.time()
print("retry:")
for i, b in enumerate(itr):
if i % 20 == 0:
print(f"{i}/{len(itr)}", end="\r")
end = time.time()
print("totoal time:", end - start)
================================================
FILE: environment.yml
================================================
name: ba3l
channels:
- defaults
dependencies:
- _libgcc_mutex=0.1
- _openmp_mutex=5.1
- ca-certificates=2022.10.11
- certifi=2022.12.7
- ld_impl_linux-64=2.38
- libffi=3.4.2
- libgcc-ng=11.2.0
- libgomp=11.2.0
- libstdcxx-ng=11.2.0
- ncurses=6.3
- openssl=1.1.1s
- pip=22.3.1
- python=3.8.16
- readline=8.2
- setuptools=65.6.3
- sqlite=3.40.1
- tk=8.6.12
- wheel=0.37.1
- xz=5.2.10
- zlib=1.2.13
- pip:
- absl-py==1.4.0
- aiohttp==3.8.4
- aiosignal==1.3.1
- appdirs==1.4.4
- async-timeout==4.0.2
- attrs==22.2.0
- audioread==3.0.0
- av==10.0.0
- cachetools==5.3.0
- cffi==1.15.1
- charset-normalizer==3.0.1
- colorama==0.4.6
- decorator==5.1.1
- docopt==0.6.2
- frozenlist==1.3.3
- fsspec==2023.3.0
- future==0.18.3
- gitdb==4.0.10
- gitpython==3.1.31
- google-auth==2.17.0
- google-auth-oauthlib==0.4.6
- grpcio==1.53.0
- h5py==3.8.0
- idna==3.4
- imageio==2.27.0
- importlib-metadata==6.1.0
- joblib==1.2.0
- jsonpickle==3.0.1
- kk-sacred==0.8.4
- lazy-loader==0.2
- librosa==0.10.0.post2
- llvmlite==0.39.1
- markdown==3.4.3
- markupsafe==2.1.2
- msgpack==1.0.5
- multidict==6.0.4
- munch==2.5.0
- numba==0.56.4
- numpy==1.23.5
- nvidia-cublas-cu11==11.10.3.66
- nvidia-cuda-nvrtc-cu11==11.7.99
- nvidia-cuda-runtime-cu11==11.7.99
- nvidia-cudnn-cu11==8.5.0.96
- oauthlib==3.2.2
- packaging==23.0
- pandas==1.5.3
- pillow==9.4.0
- pooch==1.6.0
- protobuf==4.22.1
- py-cpuinfo==9.0.0
- pyasn1==0.4.8
- pyasn1-modules==0.2.8
- pycparser==2.21
- python-dateutil==2.8.2
- pytorch-lightning==1.3.0.dev0
- pytz==2023.3
- pyyaml==6.0
- requests==2.28.2
- requests-oauthlib==1.3.1
- rsa==4.9
- scikit-learn==1.2.2
- scipy==1.10.1
- six==1.16.0
- smmap==5.0.0
- soundfile==0.12.1
- soxr==0.3.4
- tensorboard==2.12.0
- tensorboard-data-server==0.7.0
- tensorboard-plugin-wit==1.8.1
- test-tube==0.7.5
- threadpoolctl==3.1.0
- timm==0.4.12
- torch==1.13.1
- torchaudio==0.13.1
- torchmetrics==0.2.0
- torchvision==0.14.1
- tqdm==4.65.0
- typing-extensions==4.4.0
- urllib3==1.26.14
- werkzeug==2.2.3
- wrapt==1.15.0
- yarl==1.8.2
- zipp==3.15.0
================================================
FILE: environment_old_2021.yml
================================================
name: ba3l
channels:
- conda-forge
- defaults
dependencies:
- _libgcc_mutex=0.1
- _openmp_mutex=4.5
- _pytorch_select=0.1
- appdirs=1.4.4
- audioread=2.1.9
- blas=1.0
- brotlipy=0.7.0
- bzip2=1.0.8
- c-ares=1.17.1
- ca-certificates=2020.12.5
- cached-property=1.5.2
- cached_property=1.5.2
- certifi=2020.12.5
- cffi=1.14.5
- chardet=4.0.0
- colorama=0.4.4
- cryptography=3.4.6
- cycler=0.10.0
- decorator=4.4.2
- docopt=0.6.2
- ffmpeg=4.3.1
- freetype=2.10.4
- gettext=0.19.8.1
- gitdb=4.0.5
- gitpython=3.1.14
- gmp=6.2.1
- gnutls=3.6.13
- h5py=3.1.0
- hdf5=1.10.6
- idna=2.10
- importlib-metadata=3.7.3
- importlib_metadata=3.7.3
- intel-openmp=2020.2
- joblib=1.0.1
- jpeg=9d
- jsonpickle=1.4.1
- kiwisolver=1.3.1
- krb5=1.17.2
- lame=3.100
- lcms2=2.12
- ld_impl_linux-64=2.35.1
- libblas=3.9.0
- libcblas=3.9.0
- libcurl=7.75.0
- libedit=3.1.20191231
- libev=4.33
- libffi=3.3
- libflac=1.3.3
- libgcc-ng=9.3.0
- libgfortran-ng=9.3.0
- libgfortran5=9.3.0
- libgomp=9.3.0
- liblapack=3.9.0
- libllvm10=10.0.1
- libnghttp2=1.43.0
- libogg=1.3.4
- libopenblas=0.3.12
- libopus=1.3.1
- libpng=1.6.37
- librosa=0.8.0
- libsndfile=1.0.31
- libssh2=1.9.0
- libstdcxx-ng=9.3.0
- libtiff=4.2.0
- libvorbis=1.3.7
- libwebp-base=1.2.0
- llvm-openmp=11.1.0
- llvmlite=0.36.0
- lz4-c=1.9.3
- matplotlib-base=3.3.4
- mkl=2020.2
- mkl-service=2.3.0
- munch=2.5.0
- ncurses=6.2
- nettle=3.6
- ninja=1.10.2
- numba=0.53.0
- numpy=1.20.1
- olefile=0.46
- openblas=0.3.12
- openh264=2.1.1
- openssl=1.1.1k
- packaging=20.9
- pandas=1.2.3
- pillow=8.1.2
- pip=21.0.1
- pooch=1.3.0
- py-cpuinfo=7.0.0
- pycparser=2.20
- pyopenssl=20.0.1
- pyparsing=2.4.7
- pysocks=1.7.1
- pysoundfile=0.10.3.post1
- python=3.7.10
- python-dateutil=2.8.1
- python_abi=3.7
- pytz=2021.1
- readline=8.0
- requests=2.25.1
- resampy=0.2.2
- scikit-learn=0.24.1
- scipy=1.6.1
- setuptools=49.6.0
- six=1.15.0
- smmap=3.0.5
- sqlite=3.34.0
- threadpoolctl=2.1.0
- tk=8.6.10
- tornado=6.1
- typing_extensions=3.7.4.3
- urllib3=1.26.4
- wrapt=1.12.1
- x264=1!161.3030
- xz=5.2.5
- zipp=3.4.1
- zlib=1.2.11
- zstd=1.4.9
- pip:
- absl-py==0.12.0
- aiohttp==3.7.4.post0
- async-timeout==3.0.1
- attrs==20.3.0
- av==8.0.3
- black==20.8b1
- cachetools==4.2.1
- click==7.1.2
- einops==0.3.0
- fsspec==0.8.7
- future==0.18.2
- google-auth==1.28.0
- google-auth-oauthlib==0.4.3
- gpuinfo==1.0.0a7
- grpcio==1.36.1
- imageio==2.9.0
- jedi==0.18.0
- libtmux==0.8.5
- markdown==3.3.4
- multidict==5.1.0
- mypy-extensions==0.4.3
- oauthlib==3.1.0
- parso==0.8.2
- pathspec==0.8.1
- prompt-toolkit==3.0.18
- protobuf==3.15.6
- ptpython==3.0.17
- pyasn1==0.4.8
- pyasn1-modules==0.2.8
- pydub==0.25.1
- pygments==2.8.1
- pymongo==3.11.3
- pyyaml==5.3.1
- regex==2021.4.4
- requests-oauthlib==1.3.0
- rsa==4.7.2
- tensorboard==2.4.1
- tensorboard-plugin-wit==1.8.0
- test-tube==0.7.5
- timm==0.4.12
- toml==0.10.2
- torch==1.8.1+cu111
- torchaudio==0.8.1
- torchmetrics==0.2.0
- torchvision==0.6.0
- tqdm==4.59.0
- typed-ast==1.4.3
- wcwidth==0.2.5
- werkzeug==1.0.1
- wheel==0.36.2
- yarl==1.6.3
================================================
FILE: esc50/README.md
================================================
# Experiments on ESC-50 Environmental Sound Classification
[ESC-50](https://github.com/karolpiczak/ESC-50) consist of 2000 5-second recordings to be classified to 50 semantical classes.
## Setting up the fine-tuning experiments
- Download the prerpocessed (resampled) dataset [esc50.zip](https://github.com/kkoutini/PaSST/releases/download/v.0.0.6/esc50.zip) from the [releases page](https://github.com/kkoutini/PaSST/releases/tag/v.0.0.6) and
unpack the zip file into a directory (the default path is `./audioset_hdf5s/`). The `base_dir` config in the dataset file ([here](https://github.com/kkoutini/PaSST/blob/main/esc50/dataset.py#L35)) should point to the extracted contents of the dataset zip file.
- Running the experiments using the common configurations (similar to Audioset)
```shell
python3 ex_esc50.py with models.net.s_patchout_t=10 models.net.s_patchout_f=5 basedataset.fold=1 -p
```
## Pre-trained models
Pre-trained models on ESC-50 can be found here [here](https://github.com/kkoutini/PaSST/releases/tag/v.0.0.6).
In order to use the pre-trained models, for fine-tuning or inference, using a minimal dependencies, refer to the [PaSST-HEAR](https://github.com/kkoutini/passt_hear21), as an example after installing passt_hear21 :
```python
from hear21passt.base import get_basic_model,get_model_passt
import torch
# model wrapper, includes Melspectrogram and the default transformer
model = get_basic_model(mode="logits")
# replace the transformer with one that outputs 50 classes
model.net = get_model_passt(arch="passt_s_swa_p16_128_ap476", n_classes=50)
# load the pre-trained model state dict
state_dict = torch.load('/home/khaled/esc50-passt-s-n-f128-p16-s10-fold1-acc.967.pt')
# load the weights into the transformer
model.net.load_state_dict(state_dict)
# example inference
model.eval()
model = model.cuda()
with torch.no_grad():
# audio_wave has the shape of [batch, seconds*32000] sampling rate is 32k
logits=model(audio_wave)
```
================================================
FILE: esc50/dataset.py
================================================
import io
import os
import pathlib
import random
import av
import librosa
import torchaudio
from torch.utils.data import Dataset as TorchDataset, ConcatDataset, DistributedSampler, WeightedRandomSampler
import torch
from ba3l.ingredients.datasets import Dataset
import pandas as pd
from sacred.config import DynamicIngredient, CMD
from scipy.signal import convolve
from sklearn import preprocessing
from torch.utils.data import Dataset as TorchDataset
import numpy as np
import h5py
from helpers.audiodatasets import PreprocessDataset
LMODE = os.environ.get("LMODE", False)
dataset = Dataset('Esc50')
@dataset.config
def default_config():
name = 'esc50' # dataset name
normalize = False # normalize dataset
subsample = False # subsample squares from the dataset
roll = True # apply roll augmentation
fold = 1
base_dir = "audioset_hdf5s/esc50/" # base directory of the dataset as downloaded
if LMODE:
base_dir = "/system/user/publicdata/CP/audioset/audioset_hdf5s/esc50/"
meta_csv = base_dir + "meta/esc50.csv"
audio_path = base_dir + "audio_32k/"
ir_path = base_dir + "irs/"
num_of_classes = 50
def decode_mp3(mp3_arr):
"""
decodes an array if uint8 representing an mp3 file
:rtype: np.array
"""
container = av.open(io.BytesIO(mp3_arr.tobytes()))
stream = next(s for s in container.streams if s.type == 'audio')
# print(stream)
a = []
for i, packet in enumerate(container.demux(stream)):
for frame in packet.decode():
a.append(frame.to_ndarray().reshape(-1))
waveform = np.concatenate(a)
if waveform.dtype != 'float32':
raise RuntimeError("Unexpected wave type")
return waveform
def pad_or_truncate(x, audio_length):
"""Pad all audio to specific length."""
if len(x) <= audio_length:
return np.concatenate((x, np.zeros(audio_length - len(x), dtype=np.float32)), axis=0)
else:
return x[0: audio_length]
irs_arr = None
@dataset.command
def get_ir_sample(ir_path, _run, ir_augment, cut_irs_offset=None):
if not ir_augment:
return
global irs_arr
if irs_arr is None:
all_paths = [path for path in pathlib.Path(os.path.expanduser(ir_path)).rglob('*.wav')]
all_paths = sorted(all_paths)
if cut_irs_offset is not None:
all_paths = all_paths[cut_irs_offset:cut_irs_offset + 10]
all_paths_name = [str(p).rsplit("/", 1)[-1] for p in all_paths]
print("will use these IRs:")
for i in range(len(all_paths_name)):
print(i, ": ", all_paths_name[i])
_run.info["ir_devices"] = all_paths_name
irs_arr = [librosa.load(p, sr=32000)[0] for p in all_paths]
return irs_arr[int(np.random.randint(0, len(irs_arr)))]
@dataset.command
def pydub_augment(waveform, gain_augment=7, ir_augment=0):
if ir_augment and torch.rand(1) < ir_augment:
ir = get_ir_sample()
waveform = convolve(waveform, ir, 'full')
if gain_augment:
gain = torch.randint(gain_augment * 2, (1,)).item() - gain_augment
amp = 10 ** (gain / 20)
waveform = waveform * amp
return waveform
class MixupDataset(TorchDataset):
""" Mixing Up wave forms
"""
def __init__(self, dataset, beta=2, rate=0.5):
self.beta = beta
self.rate = rate
self.dataset = dataset
print(f"Mixing up waveforms from dataset of len {len(dataset)}")
def __getitem__(self, index):
if torch.rand(1) < self.rate:
x1, f1, y1 = self.dataset[index]
idx2 = torch.randint(len(self.dataset), (1,)).item()
x2, f2, y2 = self.dataset[idx2]
l = np.random.beta(self.beta, self.beta)
l = max(l, 1. - l)
x1 = x1 - x1.mean()
x2 = x2 - x2.mean()
x = (x1 * l + x2 * (1. - l))
x = x - x.mean()
return x, f1, (y1 * l + y2 * (1. - l))
return self.dataset[index]
def __len__(self):
return len(self.dataset)
class AudioSetDataset(TorchDataset):
def __init__(self, meta_csv, audiopath, fold, train=False, sample_rate=32000, classes_num=527,
clip_length=5, augment=False):
"""
Reads the mp3 bytes from HDF file decodes using av and returns a fixed length audio wav
"""
self.sample_rate = sample_rate
self.meta_csv = meta_csv
self.df = pd.read_csv(meta_csv)
if train: # training all except this
print(f"Dataset training fold {fold} selection out of {len(self.df)}")
self.df = self.df[self.df.fold != fold]
print(f" for training remains {len(self.df)}")
else:
print(f"Dataset testing fold {fold} selection out of {len(self.df)}")
self.df = self.df[self.df.fold == fold]
print(f" for testing remains {len(self.df)}")
self.clip_length = clip_length * sample_rate
self.sr = sample_rate
self.classes_num = classes_num
self.augment = augment
self.audiopath=audiopath
if augment:
print(f"Will agument data from {meta_csv}")
def __len__(self):
return len(self.df)
def __getitem__(self, index):
"""Load waveform and target of an audio clip.
Args:
meta: {
'hdf5_path': str,
'index_in_hdf5': int}
Returns:
data_dict: {
'audio_name': str,
'waveform': (clip_samples,),
'target': (classes_num,)}
"""
row = self.df.iloc[index]
#waveform = decode_mp3(np.fromfile(self.audiopath + row.filename, dtype='uint8'))
waveform, _ = librosa.load(self.audiopath + row.filename, sr=self.sr, mono=True)
if self.augment:
waveform = pydub_augment(waveform)
waveform = pad_or_truncate(waveform, self.clip_length)
waveform = self.resample(waveform)
target = row.target
return waveform.reshape(1, -1), row.filename, target
def resample(self, waveform):
"""Resample.
Args:
waveform: (clip_samples,)
Returns:
(resampled_clip_samples,)
"""
if self.sample_rate == 32000:
return waveform
elif self.sample_rate == 16000:
return waveform[0:: 2]
elif self.sample_rate == 8000:
return waveform[0:: 4]
else:
raise Exception('Incorrect sample rate!')
@dataset.command
def get_base_training_set(meta_csv, audio_path, fold=1):
ds = AudioSetDataset(meta_csv, audio_path, fold, train=True, augment=True)
return ds
@dataset.command
def get_ft_weighted_sampler(samples_weights=CMD(".get_ft_cls_balanced_sample_weights"),
epoch_len=100000, sampler_replace=False):
num_nodes = int(os.environ.get('num_nodes', 1))
ddp = int(os.environ.get('DDP', 1))
num_nodes = max(ddp, num_nodes)
print("num_nodes= ", num_nodes)
rank = int(os.environ.get('NODE_RANK', 0))
return DistributedSamplerWrapper(sampler=WeightedRandomSampler(samples_weights,
num_samples=epoch_len, replacement=sampler_replace),
dataset=range(epoch_len),
num_replicas=num_nodes,
rank=rank,
)
@dataset.command
def get_base_test_set(meta_csv, audio_path, fold=1):
ds = AudioSetDataset(meta_csv, audio_path, fold, train=False)
return ds
@dataset.command(prefix='roll_conf')
def get_roll_func(axis=1, shift=None, shift_range=50):
print("rolling...")
def roll_func(b):
x, i, y = b
x = torch.as_tensor(x)
sf = shift
if shift is None:
sf = int(np.random.random_integers(-shift_range, shift_range))
global FirstTime
return x.roll(sf, axis), i, y
return roll_func
@dataset.command
def get_training_set(normalize, roll, wavmix=False):
ds = get_base_training_set()
get_ir_sample()
if normalize:
print("normalized train!")
fill_norms()
ds = PreprocessDataset(ds, norm_func)
if roll:
ds = PreprocessDataset(ds, get_roll_func())
if wavmix:
ds = MixupDataset(ds)
return ds
@dataset.command
def get_test_set(normalize):
ds = get_base_test_set()
if normalize:
print("normalized test!")
fill_norms()
ds = PreprocessDataset(ds, norm_func)
return ds
@dataset.command
def print_conf(_config):
print("Config of ", dataset.path, id(dataset))
print(_config)
print()
class DistributedSamplerWrapper(DistributedSampler):
def __init__(
self, sampler, dataset,
num_replicas=None,
rank=None,
shuffle: bool = True):
super(DistributedSamplerWrapper, self).__init__(
dataset, num_replicas, rank, shuffle)
# source: @awaelchli https://github.com/PyTorchLightning/pytorch-lightning/issues/3238
self.sampler = sampler
def __iter__(self):
if self.sampler.generator is None:
self.sampler.generator = torch.Generator()
self.sampler.generator.manual_seed(self.seed + self.epoch)
indices = list(self.sampler)
if self.epoch == 0:
print(f"\n DistributedSamplerWrapper : {indices[:10]} \n\n")
indices = indices[self.rank:self.total_size:self.num_replicas]
return iter(indices)
if __name__ == "__main__":
from sacred import Experiment
ex = Experiment("test_dataset", ingredients=[dataset])
@ex.automain
def default_command():
ex.current_run.get_command_function("print_config")()
get_base_training_set()
ds = get_test_set()
print(ds[0])
ds = get_training_set()
print(ds[0])
print("get_base_training_set", len(get_base_training_set()))
print("get_base_test_set", len(get_base_test_set()))
print("get_training_set", len(get_training_set()))
print("get_test_set", len(get_test_set()))
================================================
FILE: ex_audioset.py
================================================
import os
import sys
import PIL
import pytorch_lightning
import torch
from pytorch_lightning.callbacks import ModelCheckpoint
from sacred.config_helpers import DynamicIngredient, CMD
from torch.nn import functional as F
import numpy as np
import wandb
from ba3l.experiment import Experiment
from ba3l.module import Ba3lModule
from torch.utils.data import DataLoader
from config_updates import add_configs
from helpers.mixup import my_mixup
from helpers.models_size import count_non_zero_params
from helpers.ramp import exp_warmup_linear_down, cosine_cycle
from helpers.workersinit import worker_init_fn
from sklearn import metrics
from pytorch_lightning import Trainer as plTrainer
from pytorch_lightning.loggers import WandbLogger
from pytorch_lightning.callbacks import LearningRateMonitor
ex = Experiment("audioset")
# Example call with all the default config:
# python ex_audioset.py with trainer.precision=16 models.net.arch=passt_deit_bd_p16_384 -p -m mongodb_server:27000:audioset21_balanced -c "PaSST base"
# with 2 gpus:
# DDP=2 python ex_audioset.py with trainer.precision=16 models.net.arch=passt_deit_bd_p16_384 -p -m mongodb_server:27000:audioset21_balanced -c "PaSST base 2 GPU"
# capture the config of the trainer with the prefix "trainer", this allows to use sacred to update PL trainer config
# now you can use in the cmd trainer.precision=16 for example
get_trainer = ex.command(plTrainer, prefix="trainer")
# capture the WandbLogger and prefix it with "wandb", this allows to use sacred to update WandbLogger config from the command line
get_logger = ex.command(WandbLogger, prefix="wandb")
wandb_logger = None
# define datasets and loaders
get_train_loader = ex.datasets.training.iter(DataLoader, static_args=dict(worker_init_fn=worker_init_fn), train=True, batch_size=12,
num_workers=16, shuffle=None, dataset=CMD("/basedataset.get_full_training_set"),
sampler=CMD("/basedataset.get_ft_weighted_sampler"))
get_validate_loader = ex.datasets.test.iter(DataLoader, static_args=dict(worker_init_fn=worker_init_fn),
validate=True, batch_size=20, num_workers=16,
dataset=CMD("/basedataset.get_test_set"))
@ex.config
def default_conf():
cmd = " ".join(sys.argv) # command line arguments
saque_cmd = os.environ.get("SAQUE_CMD", "").strip()
saque_id = os.environ.get("SAQUE_ID", "").strip()
slurm_job_id = os.environ.get("SLURM_JOB_ID", "").strip()
if os.environ.get("SLURM_ARRAY_JOB_ID", False):
slurm_job_id = os.environ.get("SLURM_ARRAY_JOB_ID", "").strip() + "_" + os.environ.get("SLURM_ARRAY_TASK_ID",
"").strip()
process_id = os.getpid()
models = {
"net": DynamicIngredient("models.passt.model_ing", arch="passt_deit_bd_p16_384", n_classes=527, s_patchout_t=40,
s_patchout_f=4), # network config
"mel": DynamicIngredient("models.preprocess.model_ing",
instance_cmd="AugmentMelSTFT",
n_mels=128, sr=32000, win_length=800, hopsize=320, n_fft=1024, freqm=48,
timem=192,
htk=False, fmin=0.0, fmax=None, norm=1, fmin_aug_range=10,
fmax_aug_range=2000)
}
basedataset = DynamicIngredient("audioset.dataset.dataset", wavmix=1)
wandb = dict(project="passt_audioset2", log_model=True)
watch_model = True
trainer = dict(max_epochs=130, gpus=1, weights_summary='full', benchmark=True, num_sanity_val_steps=0, precision=16,
reload_dataloaders_every_epoch=True)
lr = 0.00002 # learning rate
use_mixup = True
mixup_alpha = 0.3
compile = True # compile the model, requires pytorch >= 2.0
# register extra possible configs
add_configs(ex)
@ex.command
def get_scheduler_lambda(warm_up_len=5, ramp_down_start=50, ramp_down_len=50, last_lr_value=0.01,
schedule_mode="exp_lin"):
if schedule_mode == "exp_lin":
return exp_warmup_linear_down(warm_up_len, ramp_down_len, ramp_down_start, last_lr_value)
if schedule_mode == "cos_cyc":
return cosine_cycle(warm_up_len, ramp_down_start, last_lr_value)
raise RuntimeError(
f"schedule_mode={schedule_mode} Unknown for a lambda funtion.")
@ex.command
def get_lr_scheduler(optimizer, schedule_mode):
if schedule_mode in {"exp_lin", "cos_cyc"}:
return torch.optim.lr_scheduler.LambdaLR(optimizer, get_scheduler_lambda())
raise RuntimeError(f"schedule_mode={schedule_mode} Unknown.")
@ex.command
def get_optimizer(params, lr, adamw=True, weight_decay=0.0001):
if adamw:
print(f"\nUsing adamw weight_decay={weight_decay}!\n")
return torch.optim.AdamW(params, lr=lr, weight_decay=weight_decay)
return torch.optim.Adam(params, lr=lr)
class M(Ba3lModule):
def __init__(self, experiment):
self.mel = None
self.da_net = None
super(M, self).__init__(experiment)
self.use_mixup = self.config.use_mixup or False
self.mixup_alpha = self.config.mixup_alpha
desc, sum_params, sum_non_zero = count_non_zero_params(self.net)
self.experiment.info["start_sum_params"] = sum_params
self.experiment.info["start_sum_params_non_zero"] = sum_non_zero
# in case we need embedings for the DA
self.net.return_embed = True
self.dyn_norm = self.config.dyn_norm
self.do_swa = False
self.distributed_mode = self.config.trainer.num_nodes > 1
if self.config.compile:
# pt 2 magic
print("\n\nCompiling the model pytorch 2... \n\n")
self.net = torch.compile(self.net)
# compile only the net, not the mel
#self.mel = torch.compile(self.mel)
def forward(self, x):
return self.net(x)
def mel_forward(self, x):
old_shape = x.size()
x = x.reshape(-1, old_shape[2])
x = self.mel(x)
x = x.reshape(old_shape[0], old_shape[1], x.shape[1], x.shape[2])
if self.dyn_norm:
if not hasattr(self, "tr_m") or not hasattr(self, "tr_std"):
tr_m, tr_std = get_dynamic_norm(self)
self.register_buffer('tr_m', tr_m)
self.register_buffer('tr_std', tr_std)
x = (x - self.tr_m) / self.tr_std
return x
def training_step(self, batch, batch_idx):
# REQUIRED
x, f, y = batch
if self.mel:
x = self.mel_forward(x)
if self.global_step < 5:
images = [ wandb.Image(
PIL.Image.fromarray((i * 255).astype(np.uint8)).convert("L"),
caption="spectrograms",
) for i in x[:, 0, :, :].cpu().numpy()]
wandb.log({"spectrograms": images})
# wandb_logger.log_image(key="spectrograms", images=[i for i in x[:,0,:,:].cpu().numpy()])
orig_x = x
batch_size = len(y)
rn_indices, lam = None, None
if self.use_mixup:
rn_indices, lam = my_mixup(batch_size, self.mixup_alpha)
lam = lam.to(x.device)
x = x * lam.reshape(batch_size, 1, 1, 1) + \
x[rn_indices] * (1. - lam.reshape(batch_size, 1, 1, 1))
y_hat, embed = self.forward(x)
if self.use_mixup:
y_mix = y * lam.reshape(batch_size, 1) + \
y[rn_indices] * (1. - lam.reshape(batch_size, 1))
samples_loss = F.binary_cross_entropy_with_logits(
y_hat, y_mix, reduction="none")
loss = samples_loss.mean()
samples_loss = samples_loss.detach()
else:
samples_loss = F.binary_cross_entropy_with_logits(
y_hat, y, reduction="none")
loss = samples_loss.mean()
samples_loss = samples_loss.detach()
results = {"loss": loss, }
self.log('trainer/lr', self.trainer.optimizers[0].param_groups[0]['lr'])
self.log('epoch', self.current_epoch)
self.log("training.loss", loss.detach())
return results
def training_epoch_end(self, outputs):
avg_loss = torch.stack([x['loss'] for x in outputs]).mean()
logs = {'train.loss': avg_loss, # 'step': self.current_epoch
}
self.log_dict(logs, sync_dist=True)
def predict(self, batch, batch_idx: int, dataloader_idx: int = None):
x, f, y = batch
if self.mel:
x = self.mel_forward(x)
y_hat, _ = self.forward(x)
return f, y_hat
def validation_step(self, batch, batch_idx):
x, f, y = batch
if self.mel:
x = self.mel_forward(x)
if self.global_step < 5:
images = [ wandb.Image(
PIL.Image.fromarray((i * 255).astype(np.uint8)).convert("L"),
caption="validation_spectrograms",
) for i in x[:, 0, :, :].cpu().numpy()]
wandb.log({"validation_spectrograms": images})
# wandb_logger.log_image(key="validation_spectrograms", images=[
# i for i in x[:, 0, :, :].cpu().numpy()])
results = {}
model_name = [("", self.net)]
if self.do_swa:
model_name = model_name + [("swa_", self.net_swa)]
for net_name, net in model_name:
y_hat, _ = net(x)
samples_loss = F.binary_cross_entropy_with_logits(y_hat, y)
loss = samples_loss.mean()
out = torch.sigmoid(y_hat.detach())
# self.log("validation.loss", loss, prog_bar=True, on_epoch=True, on_step=False)
results = {**results, net_name + "val_loss": loss,
net_name + "out": out, net_name + "target": y.detach()}
results = {k: v.cpu() for k, v in results.items()}
return results
def validation_epoch_end(self, outputs):
model_name = [("", self.net)]
if self.do_swa:
model_name = model_name + [("swa_", self.net_swa)]
for net_name, net in model_name:
avg_loss = torch.stack([x[net_name + 'val_loss']
for x in outputs]).mean()
out = torch.cat([x[net_name + 'out'] for x in outputs], dim=0)
target = torch.cat([x[net_name + 'target']
for x in outputs], dim=0)
try:
average_precision = metrics.average_precision_score(
target.float().numpy(), out.float().numpy(), average=None)
except ValueError:
average_precision = np.array([np.nan] * 527)
try:
roc = metrics.roc_auc_score(
target.numpy(), out.numpy(), average=None)
except ValueError:
roc = np.array([np.nan] * 527)
logs = {net_name + 'val.loss': torch.as_tensor(avg_loss).cuda(),
net_name + 'ap': torch.as_tensor(average_precision.mean()).cuda(),
net_name + 'roc': torch.as_tensor(roc.mean()).cuda(),}
#'step': torch.as_tensor(self.current_epoch).cuda()}
# torch.save(average_precision,
# f"ap_perclass_{average_precision.mean()}.pt")
# print(average_precision)
self.log_dict(logs, sync_dist=True)
if self.distributed_mode:
allout = self.all_gather(out)
alltarget = self.all_gather(target)
average_precision = metrics.average_precision_score(
alltarget.reshape(-1, alltarget.shape[-1]).cpu().numpy(),
allout.reshape(-1, allout.shape[-1]).cpu().numpy(), average=None)
if self.trainer.is_global_zero:
logs = {net_name + "allap": torch.as_tensor(average_precision.mean()).cuda(),
# 'step': torch.as_tensor(self.current_epoch).cuda()
}
self.log_dict(logs, sync_dist=False)
else:
self.log_dict(
{net_name + "allap": logs[net_name + 'ap'],
#'step': logs['step']
}
, sync_dist=True)
def configure_optimizers(self):
# REQUIRED
# can return multiple optimizers and learning_rate schedulers
# (LBFGS it is automatically supported, no need for closure function)
optimizer = get_optimizer(self.parameters())
# torch.optim.Adam(self.parameters(), lr=self.config.lr)
return {
'optimizer': optimizer,
'lr_scheduler': get_lr_scheduler(optimizer)
}
def configure_callbacks(self):
return get_extra_checkpoint_callback() + get_extra_swa_callback() + [LearningRateMonitor(logging_interval='epoch')]
@ex.command
def get_dynamic_norm(model, dyn_norm=False):
if not dyn_norm:
return None, None
raise RuntimeError('no dynamic norm supported yet.')
@ex.command
def get_extra_checkpoint_callback(save_last_n=None):
if save_last_n is None:
return []
return [ModelCheckpoint(monitor="step", verbose=True, save_top_k=save_last_n, mode='max')]
@ex.command
def get_extra_swa_callback(swa=True, swa_epoch_start=50,
swa_freq=5):
if not swa:
return []
print("\n Using swa!\n")
if pytorch_lightning.__version__ <= "1.5":
from helpers.swa_legacy import StochasticWeightAveraging
else:
from helpers.swa_callback import StochasticWeightAveraging
return [StochasticWeightAveraging(swa_epoch_start=swa_epoch_start, swa_freq=swa_freq)]
@ex.command
def main(_run, _config, _log, _rnd, _seed, watch_model=True):
global wandb_logger
wandb_logger = get_logger()
trainer = get_trainer(logger=wandb_logger)
train_loader = get_train_loader()
val_loader = get_validate_loader()
modul = M(ex)
if watch_model:
# change log frequency of gradients and parameters (100 steps by default)
wandb_logger.watch(modul, log_freq=1000, log="all" )
if pytorch_lightning.__version__ <= "1.5":
trainer.fit(
modul,
train_dataloader=train_loader,
val_dataloaders=val_loader,
)
else:
trainer.fit(
modul,
train_dataloaders=train_loader,
val_dataloaders=val_loader,
)
return {"done": True}
@ex.command
def model_speed_test(_run, _config, _log, _rnd, _seed, speed_test_batch_size=12):
'''
Test training speed of a model
@param _run:
@param _config:
@param _log:
@param _rnd:
@param _seed:
@param speed_test_batch_size: the batch size during the test
@return:
'''
modul = M(ex)
modul = modul.cuda()
batch_size = speed_test_batch_size
print(f"\nBATCH SIZE : {batch_size}\n")
test_length = 100
print(f"\ntest_length : {test_length}\n")
x = torch.ones([batch_size, 1, 128, 998]).cuda()
target = torch.ones([batch_size, 527]).cuda()
# one passe
net = modul.net
# net(x)
scaler = torch.cuda.amp.GradScaler()
torch.backends.cudnn.benchmark = True
# net = torch.jit.trace(net,(x,))
net = torch.compile(net)
optimizer = torch.optim.SGD(net.parameters(), lr=0.001)
print("warmup")
import time
torch.cuda.synchronize()
t1 = time.time()
for i in range(10):
with torch.cuda.amp.autocast():
y_hat, embed = net(x)
loss = F.binary_cross_entropy_with_logits(
y_hat, target, reduction="none").mean()
scaler.scale(loss).backward()
scaler.step(optimizer)
scaler.update()
torch.cuda.synchronize()
t2 = time.time()
print('warmup done:', (t2 - t1))
torch.cuda.synchronize()
t1 = time.time()
print("testing speed")
for i in range(test_length):
with torch.cuda.amp.autocast():
y_hat, embed = net(x)
loss = F.binary_cross_entropy_with_logits(
y_hat, target, reduction="none").mean()
scaler.scale(loss).backward()
scaler.step(optimizer)
scaler.update()
torch.cuda.synchronize()
t2 = time.time()
print('test done:', (t2 - t1))
print("average speed: ", (test_length * batch_size) /
(t2 - t1), " specs/second")
@ex.command
def evaluate_only(_run, _config, _log, _rnd, _seed):
# force overriding the config, not logged = not recommended
trainer = get_trainer(logger=get_logger())
val_loader = get_validate_loader()
modul = M(ex)
modul.val_dataloader = None
#trainer.val_dataloaders = None
print(f"\n\nValidation len={len(val_loader)}\n")
res = trainer.validate(modul, dataloaders=val_loader)
print("\n\n Validtaion:")
print(res)
@ex.command
def test_loaders():
'''
get one sample from each loader for debbuging
@return:
'''
for i, b in enumerate(ex.datasets.training.get_iter()):
print(b)
break
for i, b in enumerate(ex.datasets.test.get_iter()):
print(b)
break
def set_default_json_pickle(obj):
if isinstance(obj, set):
return list(obj)
raise TypeError
@ex.command
def preload_mp3(all_y=CMD("/basedataset.preload_mp3")):
'''
read the dataset sequentially, useful if you have a network cache
@param all_y: the dataset preload command
@return:
'''
print(all_y.shape)
def multiprocessing_run(rank, word_size):
print("rank ", rank, os.getpid())
print("word_size ", word_size)
os.environ['NODE_RANK'] = str(rank)
os.environ['CUDA_VISIBLE_DEVICES'] = os.environ['CUDA_VISIBLE_DEVICES'].split(",")[
rank]
argv = sys.argv
if rank != 0:
print(f"Unobserved {os.getpid()} with rank {rank}")
argv = argv + ["-u"] # only rank 0 is observed
if "with" not in argv:
argv = argv + ["with"]
argv = argv + \
[f"trainer.num_nodes={word_size}", f"trainer.accelerator=ddp"]
print(argv)
@ex.main
def default_command():
return main()
ex.run_commandline(argv)
if __name__ == '__main__':
# set DDP=2 forks two processes to run on two GPUs
# the environment variable "DDP" define the number of processes to fork
# With two 2x 2080ti you can train the full model to .47 in around 24 hours
# you may need to set NCCL_P2P_DISABLE=1
word_size = os.environ.get("DDP", None)
if word_size:
import random
word_size = int(word_size)
print(f"\n\nDDP TRAINING WITH WORD_SIZE={word_size}\n\n")
os.environ['MASTER_ADDR'] = '127.0.0.1'
# plz no collisions
os.environ['MASTER_PORT'] = f"{9999 + random.randint(0, 9999)}"
os.environ['PL_IN_DDP_SUBPROCESS'] = '1'
for rank in range(word_size):
pid = os.fork()
if pid == 0:
print("Child Forked ")
multiprocessing_run(rank, word_size)
exit(0)
pid, exit_code = os.wait()
print(pid, exit_code)
exit(0)
print("__main__ is running pid", os.getpid(), "in module main: ", __name__)
@ex.automain
def default_command():
return main()
================================================
FILE: ex_esc50.py
================================================
import os
import sys
import torch
from pytorch_lightning.callbacks import ModelCheckpoint
from sacred.config_helpers import DynamicIngredient, CMD
from torch.nn import functional as F
import numpy as np
from ba3l.experiment import Experiment
from ba3l.module import Ba3lModule
from torch.utils.data import DataLoader
from config_updates import add_configs
from helpers.mixup import my_mixup
from helpers.models_size import count_non_zero_params
from helpers.ramp import exp_warmup_linear_down, cosine_cycle
from helpers.workersinit import worker_init_fn
from sklearn import metrics
from pytorch_lightning import Trainer as plTrainer
from pytorch_lightning.loggers import WandbLogger
ex = Experiment("passt_esc50")
# Example call with all the default config:
# python ex_esc50.py with trainer.precision=16 -p -m mongodb_server:27000:audioset21_balanced -c "ESC50 PaSST base"
# with 2 gpus:
# DDP=2 python ex_esc50.py with trainer.precision=16 -p -m mongodb_server:27000:audioset21_balanced -c "ESC50 PaSST base"
# capture the config of the trainer with the prefix "trainer", this allows to use sacred to update PL trainer config
get_trainer = ex.command(plTrainer, prefix="trainer")
# capture the WandbLogger and prefix it with "wandb", this allows to use sacred to update WandbLogger config from the command line
get_logger = ex.command(WandbLogger, prefix="wandb")
# define datasets and loaders
get_train_loader = ex.datasets.training.iter(DataLoader, static_args=dict(worker_init_fn=worker_init_fn), train=True, batch_size=12,
num_workers=16, shuffle=None, dataset=CMD("/basedataset.get_training_set"),
)
get_validate_loader = ex.datasets.test.iter(DataLoader, static_args=dict(worker_init_fn=worker_init_fn),
validate=True, batch_size=20, num_workers=16,
dataset=CMD("/basedataset.get_test_set"))
@ex.config
def default_conf():
cmd = " ".join(sys.argv)
saque_cmd = os.environ.get("SAQUE_CMD", "").strip()
saque_id = os.environ.get("SAQUE_ID", "").strip()
slurm_job_id = os.environ.get("SLURM_JOB_ID", "").strip()
if os.environ.get("SLURM_ARRAY_JOB_ID", False):
slurm_job_id = os.environ.get("SLURM_ARRAY_JOB_ID", "").strip() + "_" + os.environ.get("SLURM_ARRAY_TASK_ID",
"").strip()
process_id = os.getpid()
models = {
"net": DynamicIngredient("models.passt.model_ing", n_classes=50, s_patchout_t=10, s_patchout_f=3),
"mel": DynamicIngredient("models.preprocess.model_ing",
instance_cmd="AugmentMelSTFT",
n_mels=128, sr=32000, win_length=800, hopsize=320, n_fft=1024, freqm=48,
timem=80,
htk=False, fmin=0.0, fmax=None, norm=1, fmin_aug_range=10,
fmax_aug_range=2000)
}
basedataset = DynamicIngredient("esc50.dataset.dataset")
trainer = dict(max_epochs=10, gpus=1, weights_summary='full', benchmark=True, num_sanity_val_steps=0,
reload_dataloaders_every_epoch=True)
wandb = dict(project="passt_esc50", log_model=True)
lr = 0.00001
use_mixup = True
mixup_alpha = 0.3
# register extra possible configs
add_configs(ex)
@ex.command
def get_scheduler_lambda(warm_up_len=5, ramp_down_start=50, ramp_down_len=50, last_lr_value=0.01,
schedule_mode="exp_lin"):
if schedule_mode == "exp_lin":
return exp_warmup_linear_down(warm_up_len, ramp_down_len, ramp_down_start, last_lr_value)
if schedule_mode == "cos_cyc":
return cosine_cycle(warm_up_len, ramp_down_start, last_lr_value)
raise RuntimeError(f"schedule_mode={schedule_mode} Unknown for a lambda funtion.")
@ex.command
def get_lr_scheduler(optimizer, schedule_mode):
if schedule_mode in {"exp_lin", "cos_cyc"}:
return torch.optim.lr_scheduler.LambdaLR(optimizer, get_scheduler_lambda())
raise RuntimeError(f"schedule_mode={schedule_mode} Unknown.")
@ex.command
def get_optimizer(params, lr, adamw=True, weight_decay=0.0001):
if adamw:
print(f"\nUsing adamw weight_decay={weight_decay}!\n")
return torch.optim.AdamW(params, lr=lr, weight_decay=weight_decay)
return torch.optim.Adam(params, lr=lr)
class M(Ba3lModule):
def __init__(self, experiment):
self.mel = None
self.da_net = None
super(M, self).__init__(experiment)
self.use_mixup = self.config.use_mixup or False
self.mixup_alpha = self.config.mixup_alpha
desc, sum_params, sum_non_zero = count_non_zero_params(self.net)
self.experiment.info["start_sum_params"] = sum_params
self.experiment.info["start_sum_params_non_zero"] = sum_non_zero
# in case we need embedings for the DA
self.net.return_embed = True
self.dyn_norm = self.config.dyn_norm
self.do_swa = False
self.distributed_mode = self.config.trainer.num_nodes > 1
def forward(self, x):
return self.net(x)
def mel_forward(self, x):
old_shape = x.size()
x = x.reshape(-1, old_shape[2])
x = self.mel(x)
x = x.reshape(old_shape[0], old_shape[1], x.shape[1], x.shape[2])
if self.dyn_norm:
if not hasattr(self, "tr_m") or not hasattr(self, "tr_std"):
tr_m, tr_std = get_dynamic_norm(self)
self.register_buffer('tr_m', tr_m)
self.register_buffer('tr_std', tr_std)
x = (x - self.tr_m) / self.tr_std
return x
def training_step(self, batch, batch_idx):
# REQUIRED
x, f, y = batch
if self.mel:
x = self.mel_forward(x)
orig_x = x
batch_size = len(y)
rn_indices, lam = None, None
if self.use_mixup:
rn_indices, lam = my_mixup(batch_size, self.mixup_alpha)
lam = lam.to(x.device)
x = x * lam.reshape(batch_size, 1, 1, 1) + x[rn_indices] * (1. - lam.reshape(batch_size, 1, 1, 1))
y_hat, embed = self.forward(x)
if self.use_mixup:
# y_mix = y * lam.reshape(batch_size, 1) + y[rn_indices] * (1. - lam.reshape(batch_size, 1))
samples_loss = (F.cross_entropy(y_hat, y, reduction="none") * lam.reshape(batch_size) +
F.cross_entropy(y_hat, y[rn_indices], reduction="none") * (1. - lam.reshape(batch_size)))
loss = samples_loss.mean()
loss = samples_loss.mean()
samples_loss = samples_loss.detach()
else:
samples_loss = F.cross_entropy(y_hat, y, reduction="none")
loss = samples_loss.mean()
samples_loss = samples_loss.detach()
results = {"loss": loss, }
return results
def training_epoch_end(self, outputs):
avg_loss = torch.stack([x['loss'] for x in outputs]).mean()
logs = {'train.loss': avg_loss, 'step': self.current_epoch}
self.log_dict(logs, sync_dist=True)
def predict(self, batch, batch_idx: int, dataloader_idx: int = None):
x, f, y = batch
if self.mel:
x = self.mel_forward(x)
y_hat, _ = self.forward(x)
return f, y_hat
def validation_step(self, batch, batch_idx):
x, f, y = batch
if self.mel:
x = self.mel_forward(x)
results = {}
model_name = [("", self.net)]
if self.do_swa:
model_name = model_name + [("swa_", self.net_swa)]
for net_name, net in model_name:
y_hat, _ = net(x)
samples_loss = F.cross_entropy(y_hat, y, reduction="none")
loss = samples_loss.mean()
_, preds = torch.max(y_hat, dim=1)
n_correct_pred_per_sample = (preds == y)
n_correct_pred = n_correct_pred_per_sample.sum()
# self.log("validation.loss", loss, prog_bar=True, on_epoch=True, on_step=False)
results = {**results, net_name + "val_loss": loss,
net_name + "n_correct_pred": torch.as_tensor(n_correct_pred), net_name + "n_pred":torch.as_tensor( len(y)) }
results = {k: v.cpu() for k, v in results.items()}
return results
def validation_epoch_end(self, outputs):
model_name = [("", self.net)]
if self.do_swa:
model_name = model_name + [("swa_", self.net_swa)]
for net_name, net in model_name:
avg_loss = torch.stack([x[net_name + 'val_loss'] for x in outputs]).mean()
val_acc = sum([x['n_correct_pred'] for x in outputs]) * 1.0 / sum(x['n_pred'] for x in outputs)
logs = {net_name + 'val.loss': torch.as_tensor(avg_loss).cuda(),
net_name + 'acc': torch.as_tensor(val_acc).cuda(),
'step': torch.as_tensor(self.current_epoch).cuda()}
self.log_dict(logs, sync_dist=True)
def configure_optimizers(self):
# REQUIRED
# can return multiple optimizers and learning_rate schedulers
# (LBFGS it is automatically supported, no need for closure function)
optimizer = get_optimizer(self.parameters())
# torch.optim.Adam(self.parameters(), lr=self.config.lr)
return {
'optimizer': optimizer,
'lr_scheduler': get_lr_scheduler(optimizer)
}
def configure_callbacks(self):
return get_extra_checkpoint_callback() + get_extra_swa_callback()
@ex.command
def get_dynamic_norm(model, dyn_norm=False):
if not dyn_norm:
return None, None
raise RuntimeError('no dynamic norm supported yet.')
@ex.command
def get_extra_checkpoint_callback(save_last_n=None):
if save_last_n is None:
return []
return [ModelCheckpoint(monitor="step", verbose=True, save_top_k=save_last_n, mode='max')]
@ex.command
def get_extra_swa_callback(swa=True, swa_epoch_start=2,
swa_freq=1):
if not swa:
return []
print("\n Using swa!\n")
from helpers.swa_callback import StochasticWeightAveraging
return [StochasticWeightAveraging(swa_epoch_start=swa_epoch_start, swa_freq=swa_freq)]
@ex.command
def main(_run, _config, _log, _rnd, _seed):
trainer = get_trainer()
train_loader = get_train_loader()
val_loader = get_validate_loader()
modul = M(ex)
trainer.fit(
modul,
train_dataloaders=train_loader,
val_dataloaders=val_loader,
)
return {"done": True}
@ex.command
def model_speed_test(_run, _config, _log, _rnd, _seed, speed_test_batch_size=100):
'''
Test training speed of a model
@param _run:
@param _config:
@param _log:
@param _rnd:
@param _seed:
@param speed_test_batch_size: the batch size during the test
@return:
'''
modul = M(ex)
modul = modul.cuda()
batch_size = speed_test_batch_size
print(f"\nBATCH SIZE : {batch_size}\n")
test_length = 100
print(f"\ntest_length : {test_length}\n")
x = torch.ones([batch_size, 1, 128, 998]).cuda()
target = torch.ones([batch_size, 527]).cuda()
# one passe
net = modul.net
# net(x)
scaler = torch.cuda.amp.GradScaler()
torch.backends.cudnn.benchmark = True
# net = torch.jit.trace(net,(x,))
optimizer = torch.optim.SGD(net.parameters(), lr=0.001)
print("warmup")
import time
torch.cuda.synchronize()
t1 = time.time()
for i in range(10):
with torch.cuda.amp.autocast():
y_hat, embed = net(x)
loss = F.binary_cross_entropy_with_logits(y_hat, target, reduction="none").mean()
scaler.scale(loss).backward()
scaler.step(optimizer)
scaler.update()
torch.cuda.synchronize()
t2 = time.time()
print('warmup done:', (t2 - t1))
torch.cuda.synchronize()
t1 = time.time()
print("testing speed")
for i in range(test_length):
with torch.cuda.amp.autocast():
y_hat, embed = net(x)
loss = F.binary_cross_entropy_with_logits(y_hat, target, reduction="none").mean()
scaler.scale(loss).backward()
scaler.step(optimizer)
scaler.update()
torch.cuda.synchronize()
t2 = time.time()
print('test done:', (t2 - t1))
print("average speed: ", (test_length * batch_size) / (t2 - t1), " specs/second")
@ex.command
def evaluate_only(_run, _config, _log, _rnd, _seed):
# force overriding the config, not logged = not recommended
trainer = get_trainer()
train_loader = get_train_loader()
val_loader = get_validate_loader()
modul = M(ex)
modul.val_dataloader = None
#trainer.val_dataloaders = None
print(f"\n\nValidation len={len(val_loader)}\n")
res = trainer.validate(modul, dataloaders=val_loader)
print("\n\n Validtaion:")
print(res)
@ex.command
def test_loaders():
'''
get one sample from each loader for debbuging
@return:
'''
for i, b in enumerate(ex.datasets.training.get_iter()):
print(b)
break
for i, b in enumerate(ex.datasets.test.get_iter()):
print(b)
break
def set_default_json_pickle(obj):
if isinstance(obj, set):
return list(obj)
raise TypeError
def multiprocessing_run(rank, word_size):
print("rank ", rank, os.getpid())
print("word_size ", word_size)
os.environ['NODE_RANK'] = str(rank)
os.environ['CUDA_VISIBLE_DEVICES'] = os.environ['CUDA_VISIBLE_DEVICES'].split(",")[rank]
argv = sys.argv
if rank != 0:
print(f"Unobserved {os.getpid()} with rank {rank}")
argv = argv + ["-u"] # only rank 0 is observed
if "with" not in argv:
argv = argv + ["with"]
argv = argv + [f"trainer.num_nodes={word_size}", f"trainer.accelerator=ddp"]
print(argv)
@ex.main
def default_command():
return main()
ex.run_commandline(argv)
if __name__ == '__main__':
# set DDP=2 forks two processes to run on two GPUs
# the environment variable "DDP" define the number of processes to fork
# With two 2x 2080ti you can train the full model to .47 in around 24 hours
# you may need to set NCCL_P2P_DISABLE=1
word_size = os.environ.get("DDP", None)
if word_size:
import random
word_size = int(word_size)
print(f"\n\nDDP TRAINING WITH WORD_SIZE={word_size}\n\n")
os.environ['MASTER_ADDR'] = '127.0.0.1'
os.environ['MASTER_PORT'] = f"{9999 + random.randint(0, 9999)}" # plz no collisions
os.environ['PL_IN_DDP_SUBPROCESS'] = '1'
for rank in range(word_size):
pid = os.fork()
if pid == 0:
print("Child Forked ")
multiprocessing_run(rank, word_size)
exit(0)
pid, exit_code = os.wait()
print(pid, exit_code)
exit(0)
print("__main__ is running pid", os.getpid(), "in module main: ", __name__)
@ex.automain
def default_command():
return main()
================================================
FILE: ex_fsd50k.py
================================================
import os
import sys
import torch
from pytorch_lightning.callbacks import ModelCheckpoint
from sacred.config_helpers import DynamicIngredient, CMD
from torch.nn import functional as F
import numpy as np
from ba3l.experiment import Experiment
from ba3l.module import Ba3lModule
from torch.utils.data import DataLoader
from config_updates import add_configs
from helpers.mixup import my_mixup
from helpers.models_size import count_non_zero_params
from helpers.ramp import exp_warmup_linear_down, cosine_cycle
from helpers.workersinit import worker_init_fn
from sklearn import metrics
from pytorch_lightning import Trainer as plTrainer
from pytorch_lightning.loggers import WandbLogger
ex = Experiment("fsd50k")
# capture the config of the trainer with the prefix "trainer", this allows to use sacred to update PL trainer config
get_trainer = ex.command(plTrainer, prefix="trainer")
# capture the WandbLogger and prefix it with "wandb", this allows to use sacred to update WandbLogger config from the command line
get_logger = ex.command(WandbLogger, prefix="wandb")
# Example call with all the default config:
# python ex_audioset.py with trainer.precision=16 models.net.arch=passt_deit_bd_p16_384 -p -m mongodb_server:27000:audioset21_balanced -c "PaSST base"
# with 2 gpus:
# DDP=2 python ex_audioset.py with trainer.precision=16 models.net.arch=passt_deit_bd_p16_384 -p -m mongodb_server:27000:audioset21_balanced -c "PaSST base 2 GPU"
# define datasets and loaders
get_train_loader = ex.datasets.training.iter(DataLoader, static_args=dict(worker_init_fn=worker_init_fn), train=True, batch_size=12,
num_workers=16, shuffle=True, dataset=CMD("/basedataset.get_training_set"),
)
get_validate_loader = ex.datasets.valid.iter(DataLoader, static_args=dict(worker_init_fn=worker_init_fn),
validate=True, batch_size=10, num_workers=16,
dataset=CMD("/basedataset.get_valid_set"))
get_eval_loader = ex.datasets.eval.iter(DataLoader, static_args=dict(worker_init_fn=worker_init_fn),
validate=True, batch_size=10, num_workers=16,
dataset=CMD("/basedataset.get_eval_set"))
@ex.named_config
def variable_eval():
basedataset = dict(variable_eval=True)
datasets = dict(valid=dict(batch_size=1), eval=dict(batch_size=1))
@ex.config
def default_conf():
cmd = " ".join(sys.argv) # command line arguments
saque_cmd = os.environ.get("SAQUE_CMD", "").strip()
saque_id = os.environ.get("SAQUE_ID", "").strip()
slurm_job_id = os.environ.get("SLURM_JOB_ID", "").strip()
if os.environ.get("SLURM_ARRAY_JOB_ID", False):
slurm_job_id = os.environ.get("SLURM_ARRAY_JOB_ID", "").strip() + "_" + os.environ.get("SLURM_ARRAY_TASK_ID",
"").strip()
process_id = os.getpid()
models = {
"net": DynamicIngredient("models.passt.model_ing",
n_classes=200, s_patchout_t=10, s_patchout_f=4), # network config
"mel": DynamicIngredient("models.preprocess.model_ing",
instance_cmd="AugmentMelSTFT",
n_mels=128, sr=32000, win_length=800, hopsize=320, n_fft=1024, freqm=0,
timem=0,
htk=False, fmin=0.0, fmax=None, norm=1, fmin_aug_range=10,
fmax_aug_range=2000)
}
# set the default name for wandb logger
wandb = dict(project="passt_fsd50k", log_model=True)
basedataset = DynamicIngredient("fsd50k.dataset.dataset", wavmix=1)
trainer = dict(max_epochs=50, gpus=1, weights_summary='full', benchmark=True, num_sanity_val_steps=0,
reload_dataloaders_every_epoch=True)
lr = 0.00001 # learning rate
use_mixup = True
mixup_alpha = 0.3
# register extra possible configs
add_configs(ex)
@ex.command
def get_scheduler_lambda(warm_up_len=5, ramp_down_start=10, ramp_down_len=10, last_lr_value=0.01,
schedule_mode="exp_lin"):
if schedule_mode == "exp_lin":
return exp_warmup_linear_down(warm_up_len, ramp_down_len, ramp_down_start, last_lr_value)
if schedule_mode == "cos_cyc":
return cosine_cycle(warm_up_len, ramp_down_start, last_lr_value)
raise RuntimeError(f"schedule_mode={schedule_mode} Unknown for a lambda funtion.")
@ex.command
def get_lr_scheduler(optimizer, schedule_mode):
if schedule_mode in {"exp_lin", "cos_cyc"}:
return torch.optim.lr_scheduler.LambdaLR(optimizer, get_scheduler_lambda())
raise RuntimeError(f"schedule_mode={schedule_mode} Unknown.")
@ex.command
def get_optimizer(params, lr, adamw=True, weight_decay=0.0001):
if adamw:
print(f"\nUsing adamw weight_decay={weight_decay}!\n")
return torch.optim.AdamW(params, lr=lr, weight_decay=weight_decay)
return torch.optim.Adam(params, lr=lr)
class M(Ba3lModule):
def __init__(self, experiment):
self.mel = None
self.da_net = None
super(M, self).__init__(experiment)
self.use_mixup = self.config.use_mixup or False
self.mixup_alpha = self.config.mixup_alpha
desc, sum_params, sum_non_zero = count_non_zero_params(self.net)
self.experiment.info["start_sum_params"] = sum_params
self.experiment.info["start_sum_params_non_zero"] = sum_non_zero
# in case we need embedings for the DA
self.net.return_embed = True
self.dyn_norm = self.config.dyn_norm
self.do_swa = False
self.valid_names = ["valid", "eval"]
self.distributed_mode = self.config.trainer.num_nodes > 1
def forward(self, x):
return self.net(x)
def mel_forward(self, x):
old_shape = x.size()
x = x.reshape(-1, old_shape[2])
x = self.mel(x)
x = x.reshape(old_shape[0], old_shape[1], x.shape[1], x.shape[2])
if self.dyn_norm:
if not hasattr(self, "tr_m") or not hasattr(self, "tr_std"):
tr_m, tr_std = get_dynamic_norm(self)
self.register_buffer('tr_m', tr_m)
self.register_buffer('tr_std', tr_std)
x = (x - self.tr_m) / self.tr_std
return x
def training_step(self, batch, batch_idx):
# REQUIRED
x, f, y = batch
if self.mel:
x = self.mel_forward(x)
orig_x = x
batch_size = len(y)
rn_indices, lam = None, None
if self.use_mixup:
rn_indices, lam = my_mixup(batch_size, self.mixup_alpha)
lam = lam.to(x.device)
x = x * lam.reshape(batch_size, 1, 1, 1) + x[rn_indices] * (1. - lam.reshape(batch_size, 1, 1, 1))
y_hat, embed = self.forward(x)
if self.use_mixup:
y_mix = y * lam.reshape(batch_size, 1) + y[rn_indices] * (1. - lam.reshape(batch_size, 1))
samples_loss = F.binary_cross_entropy_with_logits(
y_hat, y_mix, reduction="none")
loss = samples_loss.mean()
samples_loss = samples_loss.detach()
else:
samples_loss = F.binary_cross_entropy_with_logits(y_hat, y, reduction="none")
loss = samples_loss.mean()
samples_loss = samples_loss.detach()
results = {"loss": loss, }
return results
def training_epoch_end(self, outputs):
avg_loss = torch.stack([x['loss'] for x in outputs]).mean()
logs = {'train.loss': avg_loss, 'step': self.current_epoch}
self.log_dict(logs, sync_dist=True)
def predict(self, batch, batch_idx: int, dataloader_idx: int = None):
x, f, y = batch
if self.mel:
x = self.mel_forward(x)
y_hat, _ = self.forward(x)
return f, y_hat
def validation_step(self, batch, batch_idx, dataloader_idx):
x, f, y = batch
if self.mel:
x = self.mel_forward(x)
results = {}
model_name = [("", self.net)]
if self.do_swa:
model_name = model_name + [("swa_", self.net_swa)]
for net_name, net in model_name:
y_hat, _ = net(x)
samples_loss = F.binary_cross_entropy_with_logits(y_hat, y)
loss = samples_loss.mean()
out = torch.sigmoid(y_hat.detach())
# self.log("validation.loss", loss, prog_bar=True, on_epoch=True, on_step=False)
results = {**results, net_name + "val_loss": loss, net_name + "out": out, net_name + "target": y.detach()}
results = {k: v.cpu() for k, v in results.items()}
return results
def validation_epoch_end(self, outputs):
for idx, one_outputs in enumerate(outputs):
set_name = self.valid_names[idx] + "_"
model_name = [("", self.net)]
if self.do_swa:
model_name = model_name + [("swa_", self.net_swa)]
for net_name, net in model_name:
avg_loss = torch.stack([x[net_name + 'val_loss'] for x in one_outputs]).mean()
out = torch.cat([x[net_name + 'out'] for x in one_outputs], dim=0)
target = torch.cat([x[net_name + 'target'] for x in one_outputs], dim=0)
try:
average_precision = metrics.average_precision_score(
target.float().numpy(), out.float().numpy(), average=None)
except ValueError:
average_precision = np.array([np.nan] * 200)
try:
roc = metrics.roc_auc_score(target.numpy(), out.numpy(), average=None)
except ValueError:
roc = np.array([np.nan] * 200)
logs = {set_name + net_name + 'val.loss': torch.as_tensor(avg_loss).cuda(),
set_name + net_name + 'ap': torch.as_tensor(average_precision.mean()).cuda(),
set_name + net_name + 'roc': torch.as_tensor(roc.mean()).cuda(),
'step': torch.as_tensor(self.current_epoch).cuda()}
# torch.save(average_precision, f"ap_perclass_{average_precision.mean()}.pt")
# print(average_precision)
self.log_dict(logs, sync_dist=True)
if self.distributed_mode:
allout = self.all_gather(out)
alltarget = self.all_gather(target)
average_precision = metrics.average_precision_score(
alltarget.reshape(-1, alltarget.shape[-1]).cpu().numpy(),
allout.reshape(-1, allout.shape[-1]).cpu().numpy(), average=None)
if self.trainer.is_global_zero:
logs = {set_name + net_name + "allap": torch.as_tensor(average_precision.mean()).cuda(),
'step': torch.as_tensor(self.current_epoch).cuda()}
self.log_dict(logs, sync_dist=False)
else:
self.log_dict(
{set_name + net_name + "allap": logs[set_name + net_name + 'ap'], 'step': logs['step']},
sync_dist=True)
def configure_optimizers(self):
# REQUIRED
# can return multiple optimizers and learning_rate schedulers
# (LBFGS it is automatically supported, no need for closure function)
optimizer = get_optimizer(self.parameters())
# torch.optim.Adam(self.parameters(), lr=self.config.lr)
return {
'optimizer': optimizer,
'lr_scheduler': get_lr_scheduler(optimizer)
}
def configure_callbacks(self):
return get_extra_checkpoint_callback() + get_extra_swa_callback()
@ex.command
def get_dynamic_norm(model, dyn_norm=False):
if not dyn_norm:
return None, None
raise RuntimeError('no dynamic norm supported yet.')
model_checkpoint_callback = None
@ex.command
def get_extra_checkpoint_callback(save_best=None):
if save_best is None:
return []
global model_checkpoint_callback
model_checkpoint_callback = ModelCheckpoint(monitor="allap", verbose=True, save_top_k=save_best, mode='max',
every_n_val_epochs=1, every_n_train_steps=0)
return [model_checkpoint_callback]
@ex.command
def get_extra_swa_callback(swa=True, swa_epoch_start=10,
swa_freq=3):
if not swa:
return []
print("\n Using swa!\n")
from helpers.swa_callback import StochasticWeightAveraging
return [StochasticWeightAveraging(swa_epoch_start=swa_epoch_start, swa_freq=swa_freq)]
@ex.command
def main(_run, _config, _log, _rnd, _seed):
trainer = get_trainer(logger=get_logger())
train_loader = get_train_loader()
val_loader = get_validate_loader()
eval_loader = get_eval_loader()
# eval_loader = get_eval_loader()
modul = M(ex)
trainer.fit(
modul,
train_dataloader=train_loader,
val_dataloaders=[val_loader, eval_loader],
)
## evaluate best model on eval set
#trainer.val_dataloaders = None
modul.val_dataloaders = None
return {"done": True}
@ex.command
def model_speed_test(_run, _config, _log, _rnd, _seed, speed_test_batch_size=100):
'''
Test training speed of a model
@param _run:
@param _config:
@param _log:
@param _rnd:
@param _seed:
@param speed_test_batch_size: the batch size during the test
@return:
'''
modul = M(ex)
modul = modul.cuda()
batch_size = speed_test_batch_size
print(f"\nBATCH SIZE : {batch_size}\n")
test_length = 100
print(f"\ntest_length : {test_length}\n")
x = torch.ones([batch_size, 1, 128, 998]).cuda()
target = torch.ones([batch_size, 527]).cuda()
# one passe
net = modul.net
# net(x)
scaler = torch.cuda.amp.GradScaler()
torch.backends.cudnn.benchmark = True
# net = torch.jit.trace(net,(x,))
optimizer = torch.optim.SGD(net.parameters(), lr=0.001)
print("warmup")
import time
torch.cuda.synchronize()
t1 = time.time()
for i in range(10):
with torch.cuda.amp.autocast():
y_hat, embed = net(x)
loss = F.binary_cross_entropy_with_logits(y_hat, target, reduction="none").mean()
scaler.scale(loss).backward()
scaler.step(optimizer)
scaler.update()
torch.cuda.synchronize()
t2 = time.time()
print('warmup done:', (t2 - t1))
torch.cuda.synchronize()
t1 = time.time()
print("testing speed")
for i in range(test_length):
with torch.cuda.amp.autocast():
y_hat, embed = net(x)
loss = F.binary_cross_entropy_with_logits(y_hat, target, reduction="none").mean()
scaler.scale(loss).backward()
scaler.step(optimizer)
scaler.update()
torch.cuda.synchronize()
t2 = time.time()
print('test done:', (t2 - t1))
print("average speed: ", (test_length * batch_size) / (t2 - t1), " specs/second")
@ex.command
def evaluate_only(_run, _config, _log, _rnd, _seed):
# force overriding the config, not logged = not recommended
trainer = get_trainer()
train_loader = get_train_loader()
val_loader = get_validate_loader()
modul = M(ex)
modul.val_dataloader = None
#trainer.val_dataloaders = None
print(f"\n\nValidation len={len(val_loader)}\n")
res = trainer.validate(modul, dataloaders=val_loader)
print("\n\n Validtaion:")
print(res)
@ex.command
def test_loaders():
'''
get one sample from each loader for debbuging
@return:
'''
for i, b in enumerate(ex.datasets.training.get_iter()):
print(b)
break
for i, b in enumerate(ex.datasets.test.get_iter()):
print(b)
break
def set_default_json_pickle(obj):
if isinstance(obj, set):
return list(obj)
raise TypeError
@ex.command
def preload_mp3(all_y=CMD("/basedataset.preload_mp3")):
'''
read the dataset sequentially, useful if you have a network cache
@param all_y: the dataset preload command
@return:
'''
print(all_y.shape)
def multiprocessing_run(rank, word_size):
print("rank ", rank, os.getpid())
print("word_size ", word_size)
os.environ['NODE_RANK'] = str(rank)
os.environ['CUDA_VISIBLE_DEVICES'] = os.environ['CUDA_VISIBLE_DEVICES'].split(",")[rank]
argv = sys.argv
if rank != 0:
print(f"Unobserved {os.getpid()} with rank {rank}")
argv = argv + ["-u"] # only rank 0 is observed
if "with" not in argv:
argv = argv + ["with"]
argv = argv + [f"trainer.num_nodes={word_size}", f"trainer.accelerator=ddp"]
print(argv)
@ex.main
def default_command():
return main()
ex.run_commandline(argv)
if __name__ == '__main__':
# set DDP=2 forks two processes to run on two GPUs
# the environment variable "DDP" define the number of processes to fork
# With two 2x 2080ti you can train the full model to .47 in around 24 hours
# you may need to set NCCL_P2P_DISABLE=1
word_size = os.environ.get("DDP", None)
if word_size:
import random
word_size = int(word_size)
print(f"\n\nDDP TRAINING WITH WORD_SIZE={word_size}\n\n")
os.environ['MASTER_ADDR'] = '127.0.0.1'
os.environ['MASTER_PORT'] = f"{9999 + random.randint(0, 9999)}" # plz no collisions
os.environ['PL_IN_DDP_SUBPROCESS'] = '1'
for rank in range(word_size):
pid = os.fork()
if pid == 0:
print("Child Forked ")
multiprocessing_run(rank, word_size)
exit(0)
pid, exit_code = os.wait()
print(pid, exit_code)
exit(0)
print("__main__ is running pid", os.getpid(), "in module main: ", __name__)
@ex.automain
def default_command():
return main()
================================================
FILE: ex_openmic.py
================================================
import os
import sys
import torch
from pytorch_lightning.callbacks import ModelCheckpoint
from sacred.config_helpers import DynamicIngredient, CMD
from torch.nn import functional as F
import numpy as np
from ba3l.experiment import Experiment
from ba3l.module import Ba3lModule
from torch.utils.data import DataLoader
from config_updates import add_configs
from helpers.mixup import my_mixup
from helpers.models_size import count_non_zero_params
from helpers.ramp import exp_warmup_linear_down, cosine_cycle
from helpers.workersinit import worker_init_fn
from sklearn import metrics
from pytorch_lightning import Trainer as plTrainer
from pytorch_lightning.loggers import WandbLogger
ex = Experiment("openmic")
# capture the config of the trainer with the prefix "trainer", this allows to use sacred to update PL trainer config
get_trainer = ex.command(plTrainer, prefix="trainer")
# capture the WandbLogger and prefix it with "wandb", this allows to use sacred to update WandbLogger config from the command line
get_logger = ex.command(WandbLogger, prefix="wandb")
# Example call with all the default config:
# python ex_openmic.py with trainer.precision=16 -p -m mongodb_server:27000:audioset21_balanced -c "OpenMIC PaSST base"
# with 2 gpus:
# DDP=2 python ex_openmic.py with trainer.precision=16 -p -m mongodb_server:27000:audioset21_balanced -c "OpenMIC PaSST base"
# define datasets and loaders
get_train_loader = ex.datasets.training.iter(DataLoader, static_args=dict(worker_init_fn=worker_init_fn), train=True, batch_size=6,
num_workers=16, shuffle=None, dataset=CMD("/basedataset.get_training_set"),
)
get_validate_loader = ex.datasets.test.iter(DataLoader, static_args=dict(worker_init_fn=worker_init_fn),
validate=True, batch_size=20, num_workers=16,
dataset=CMD("/basedataset.get_test_set"))
@ex.config
def default_conf():
cmd = " ".join(sys.argv)
saque_cmd = os.environ.get("SAQUE_CMD", "").strip()
saque_id = os.environ.get("SAQUE_ID", "").strip()
slurm_job_id = os.environ.get("SLURM_JOB_ID", "").strip()
if os.environ.get("SLURM_ARRAY_JOB_ID", False):
slurm_job_id = os.environ.get("SLURM_ARRAY_JOB_ID", "").strip() + "_" + os.environ.get("SLURM_ARRAY_TASK_ID",
"").strip()
process_id = os.getpid()
models = {
"net": DynamicIngredient("models.passt.model_ing", n_classes=20, s_patchout_t=40, s_patchout_f=4),
"mel": DynamicIngredient("models.preprocess.model_ing",
instance_cmd="AugmentMelSTFT",
n_mels=128, sr=32000, win_length=800, hopsize=320, n_fft=1024, freqm=48,
timem=192,
htk=False, fmin=0.0, fmax=None, norm=1, fmin_aug_range=10,
fmax_aug_range=2000)
}
wandb = dict(project="passt_openmic", log_model=True)
basedataset = DynamicIngredient("openmic.dataset.dataset", wavmix=1)
# set the default for the trainer
trainer = dict(max_epochs=10, gpus=1, weights_summary='full', benchmark=True, num_sanity_val_steps=0,
reload_dataloaders_every_epoch=True)
lr = 0.00001
use_mixup = True
mixup_alpha = 0.3
# register extra possible configs
add_configs(ex)
@ex.command
def get_scheduler_lambda(warm_up_len=5, ramp_down_start=50, ramp_down_len=50, last_lr_value=0.01,
schedule_mode="exp_lin"):
if schedule_mode == "exp_lin":
return exp_warmup_linear_down(warm_up_len, ramp_down_len, ramp_down_start, last_lr_value)
if schedule_mode == "cos_cyc":
return cosine_cycle(warm_up_len, ramp_down_start, last_lr_value)
raise RuntimeError(f"schedule_mode={schedule_mode} Unknown for a lambda funtion.")
@ex.command
def get_lr_scheduler(optimizer, schedule_mode):
if schedule_mode in {"exp_lin", "cos_cyc"}:
return torch.optim.lr_scheduler.LambdaLR(optimizer, get_scheduler_lambda())
raise RuntimeError(f"schedule_mode={schedule_mode} Unknown.")
@ex.command
def get_optimizer(params, lr, adamw=True, weight_decay=0.0001):
if adamw:
print(f"\nUsing adamw weight_decay={weight_decay}!\n")
return torch.optim.AdamW(params, lr=lr, weight_decay=weight_decay)
return torch.optim.Adam(params, lr=lr)
class M(Ba3lModule):
def __init__(self, experiment):
self.mel = None
self.da_net = None
super(M, self).__init__(experiment)
self.use_mixup = self.config.use_mixup or False
self.mixup_alpha = self.config.mixup_alpha
desc, sum_params, sum_non_zero = count_non_zero_params(self.net)
self.experiment.info["start_sum_params"] = sum_params
self.experiment.info["start_sum_params_non_zero"] = sum_non_zero
# in case we need embedings for the DA
self.net.return_embed = True
self.dyn_norm = self.config.dyn_norm
self.do_swa = False
self.distributed_mode = self.config.trainer.num_nodes > 1
def forward(self, x):
return self.net(x)
def mel_forward(self, x):
old_shape = x.size()
x = x.reshape(-1, old_shape[2])
x = self.mel(x)
x = x.reshape(old_shape[0], old_shape[1], x.shape[1], x.shape[2])
if self.dyn_norm:
if not hasattr(self, "tr_m") or not hasattr(self, "tr_std"):
tr_m, tr_std = get_dynamic_norm(self)
self.register_buffer('tr_m', tr_m)
self.register_buffer('tr_std', tr_std)
x = (x - self.tr_m) / self.tr_std
return x
def training_step(self, batch, batch_idx):
# REQUIRED
x, f, y = batch
if self.mel:
x = self.mel_forward(x)
y_mask = y[:, 20:]
y = y[:, :20] > 0.5
y = y.float()
orig_x = x
batch_size = len(y)
rn_indices, lam = None, None
if self.use_mixup:
rn_indices, lam = my_mixup(batch_size, self.mixup_alpha)
lam = lam.to(x.device)
x = x * lam.reshape(batch_size, 1, 1, 1) + x[rn_indices] * (1. - lam.reshape(batch_size, 1, 1, 1))
y_hat, embed = self.forward(x)
if self.use_mixup:
y_mix = y * lam.reshape(batch_size, 1) + y[rn_indices] * (1. - lam.reshape(batch_size, 1))
samples_loss = F.binary_cross_entropy_with_logits(
y_hat, y_mix, reduction="none")
y_mix_mask = ((y_mask > 0.5) | (y_mask[rn_indices] > 0.5)).float()
samples_loss = y_mask.float() * samples_loss
loss = samples_loss.mean()
samples_loss = samples_loss.detach()
else:
samples_loss = F.binary_cross_entropy_with_logits(y_hat, y, reduction="none")
samples_loss = y_mask.float() * samples_loss
loss = samples_loss.mean()
samples_loss = samples_loss.detach()
results = {"loss": loss, }
return results
def training_epoch_end(self, outputs):
avg_loss = torch.stack([x['loss'] for x in outputs]).mean()
logs = {'train.loss': avg_loss, 'step': self.current_epoch}
self.log_dict(logs, sync_dist=True)
def predict(self, batch, batch_idx: int, dataloader_idx: int = None):
x, f, y = batch
if self.mel:
x = self.mel_forward(x)
y_hat, _ = self.forward(x)
return f, y_hat
def validation_step(self, batch, batch_idx):
x, f, y = batch
if self.mel:
x = self.mel_forward(x)
y_mask = y[:, 20:]
y = y[:, :20] > 0.5
y = y.float()
results = {}
model_name = [("", self.net)]
if self.do_swa:
model_name = model_name + [("swa_", self.net_swa)]
for net_name, net in model_name:
y_hat, _ = net(x)
samples_loss = F.binary_cross_entropy_with_logits(y_hat, y)
samples_loss = y_mask.float() * samples_loss
loss = samples_loss.mean()
out = torch.sigmoid(y_hat.detach())
# self.log("validation.loss", loss, prog_bar=True, on_epoch=True, on_step=False)
results = {**results, net_name + "val_loss": loss, net_name + "out": out, net_name + "target": y.detach(),
net_name + "mask": y_mask.detach()}
results = {k: v.cpu() for k, v in results.items()}
return results
def validation_epoch_end(self, outputs):
model_name = [("", self.net)]
if self.do_swa:
model_name = model_name + [("swa_", self.net_swa)]
for net_name, net in model_name:
avg_loss = torch.stack([x[net_name + 'val_loss'] for x in outputs]).mean()
out = torch.cat([x[net_name + 'out'] for x in outputs], dim=0)
target = torch.cat([x[net_name + 'target'] for x in outputs], dim=0)
mask = torch.cat([x[net_name + 'mask'] for x in outputs], dim=0)
try:
y_true = target.float().numpy()
y_pred = out.float().numpy()
y_mask = mask.float().numpy()
average_precision = np.array([metrics.average_precision_score(
y_true[:, i], y_pred[:, i], sample_weight=y_mask[:, i]) for i in range(y_true.shape[1])])
except ValueError:
average_precision = np.array([np.nan] * y_true.shape[1])
#torch.save(average_precision, f"ap_openmic_perclass_{average_precision.mean()}.pt")
try:
roc = np.array([metrics.roc_auc_score(
y_true[:, i], y_pred[:, i], sample_weight=y_mask[:, i]) for i in range(y_true.shape[1])])
except ValueError:
roc = np.array([np.nan] * y_true.shape[1])
logs = {net_name + 'val.loss': torch.as_tensor(avg_loss).cuda(),
net_name + 'ap': torch.as_tensor(average_precision.mean()).cuda(),
net_name + 'roc': torch.as_tensor(roc.mean()).cuda(),
'step': torch.as_tensor(self.current_epoch).cuda()}
self.log_dict(logs, sync_dist=True)
if self.distributed_mode:
allout = self.all_gather(out)
alltarget = self.all_gather(target)
all_mask = self.all_gather(mask)
y_true = alltarget.reshape(-1, alltarget.shape[-1]).cpu().float().numpy()
y_pred = allout.reshape(-1, alltarget.shape[-1]).cpu().float().numpy()
y_mask = all_mask.reshape(-1, alltarget.shape[-1]).cpu().float().numpy()
average_precision = np.array([metrics.average_precision_score(
y_true[:, i], y_pred[:, i], sample_weight=y_mask[:, i]) for i in range(y_true.shape[1])])
if self.trainer.is_global_zero:
logs = {net_name + "allap": torch.as_tensor(average_precision.mean()).cuda(),
'step': torch.as_tensor(self.current_epoch).cuda()}
self.log_dict(logs, sync_dist=False)
else:
self.log_dict({net_name + "allap": logs[net_name + 'ap'], 'step': logs['step']}, sync_dist=True)
def configure_optimizers(self):
# REQUIRED
# can return multiple optimizers and learning_rate schedulers
# (LBFGS it is automatically supported, no need for closure function)
optimizer = get_optimizer(self.parameters())
# torch.optim.Adam(self.parameters(), lr=self.config.lr)
return {
'optimizer': optimizer,
'lr_scheduler': get_lr_scheduler(optimizer)
}
def configure_callbacks(self):
return get_extra_checkpoint_callback() + get_extra_swa_callback()
@ex.command
def get_dynamic_norm(model, dyn_norm=False):
if not dyn_norm:
return None, None
raise RuntimeError('no dynamic norm supported yet.')
@ex.command
def get_extra_checkpoint_callback(save_last_n=None):
if save_last_n is None:
return []
return [ModelCheckpoint(monitor="step", verbose=True, save_top_k=save_last_n, mode='max')]
@ex.command
def get_extra_swa_callback(swa=True, swa_epoch_start=2,
swa_freq=1):
if not swa:
return []
print("\n Using swa!\n")
from helpers.swa_callback import StochasticWeightAveraging
return [StochasticWeightAveraging(swa_epoch_start=swa_epoch_start, swa_freq=swa_freq)]
@ex.command
def main(_run, _config, _log, _rnd, _seed):
trainer = get_trainer(logger=get_logger())
train_loader = get_train_loader()
val_loader = get_validate_loader()
modul = M(ex)
trainer.fit(
modul,
train_dataloaders=train_loader,
val_dataloaders=val_loader,
)
return {"done": True}
@ex.command
def model_speed_test(_run, _config, _log, _rnd, _seed, speed_test_batch_size=100):
'''
Test training speed of a model
@param _run:
@param _config:
@param _log:
@param _rnd:
@param _seed:
@param speed_test_batch_size: the batch size during the test
@return:
'''
modul = M(ex)
modul = modul.cuda()
batch_size = speed_test_batch_size
print(f"\nBATCH SIZE : {batch_size}\n")
test_length = 100
print(f"\ntest_length : {test_length}\n")
x = torch.ones([batch_size, 1, 128, 998]).cuda()
target = torch.ones([batch_size, 527]).cuda()
# one passe
net = modul.net
# net(x)
scaler = torch.cuda.amp.GradScaler()
torch.backends.cudnn.benchmark = True
# net = torch.jit.trace(net,(x,))
optimizer = torch.optim.SGD(net.parameters(), lr=0.001)
print("warmup")
import time
torch.cuda.synchronize()
t1 = time.time()
for i in range(10):
with torch.cuda.amp.autocast():
y_hat, embed = net(x)
loss = F.binary_cross_entropy_with_logits(y_hat, target, reduction="none").mean()
scaler.scale(loss).backward()
scaler.step(optimizer)
scaler.update()
torch.cuda.synchronize()
t2 = time.time()
print('warmup done:', (t2 - t1))
torch.cuda.synchronize()
t1 = time.time()
print("testing speed")
for i in range(test_length):
with torch.cuda.amp.autocast():
y_hat, embed = net(x)
loss = F.binary_cross_entropy_with_logits(y_hat, target, reduction="none").mean()
scaler.scale(loss).backward()
scaler.step(optimizer)
scaler.update()
torch.cuda.synchronize()
t2 = time.time()
print('test done:', (t2 - t1))
print("average speed: ", (test_length * batch_size) / (t2 - t1), " specs/second")
@ex.command
def evaluate_only(_run, _config, _log, _rnd, _seed):
# force overriding the config, not logged = not recommended
trainer = get_trainer()
train_loader = get_train_loader()
val_loader = get_validate_loader()
modul = M(ex)
modul.val_dataloader = None
#trainer.val_dataloaders = None
print(f"\n\nValidation len={len(val_loader)}\n")
res = trainer.validate(modul, dataloaders=val_loader)
print("\n\n Validtaion:")
print(res)
@ex.command
def test_loaders():
'''
get one sample from each loader for debbuging
@return:
'''
for i, b in enumerate(ex.datasets.training.get_iter()):
print(b)
break
for i, b in enumerate(ex.datasets.test.get_iter()):
print(b)
break
def set_default_json_pickle(obj):
if isinstance(obj, set):
return list(obj)
raise TypeError
@ex.command
def preload_mp3(all_y=CMD("/basedataset.preload_mp3")):
'''
read the dataset sequentially, useful if you have a network cache
@param all_y: the dataset preload command
@return:
'''
print(all_y.shape)
def multiprocessing_run(rank, word_size):
print("rank ", rank, os.getpid())
print("word_size ", word_size)
os.environ['NODE_RANK'] = str(rank)
os.environ['CUDA_VISIBLE_DEVICES'] = os.environ['CUDA_VISIBLE_DEVICES'].split(",")[rank]
argv = sys.argv
if rank != 0:
print(f"Unobserved {os.getpid()} with rank {rank}")
argv = argv + ["-u"] # only rank 0 is observed
if "with" not in argv:
argv = argv + ["with"]
argv = argv + [f"trainer.num_nodes={word_size}", f"trainer.accelerator=ddp"]
print(argv)
@ex.main
def default_command():
return main()
ex.run_commandline(argv)
if __name__ == '__main__':
# set DDP=2 forks two processes to run on two GPUs
# the environment variable "DDP" define the number of processes to fork
# With two 2x 2080ti you can train the full model to .47 in around 24 hours
# you may need to set NCCL_P2P_DISABLE=1
word_size = os.environ.get("DDP", None)
if word_size:
import random
word_size = int(word_size)
print(f"\n\nDDP TRAINING WITH WORD_SIZE={word_size}\n\n")
os.environ['MASTER_ADDR'] = '127.0.0.1'
os.environ['MASTER_PORT'] = f"{9999 + random.randint(0, 9999)}" # plz no collisions
os.environ['PL_IN_DDP_SUBPROCESS'] = '1'
for rank in range(word_size):
pid = os.fork()
if pid == 0:
print("Child Forked ")
multiprocessing_run(rank, word_size)
exit(0)
pid, exit_code = os.wait()
print(pid, exit_code)
exit(0)
print("__main__ is running pid", os.getpid(), "in module main: ", __name__)
@ex.automain
def default_command():
return main()
================================================
FILE: fsd50k/README.md
================================================
# Experiments on FSD50K
The FSD50K dataset ([Zenodo](https://zenodo.org/record/4060432)) consists of 51K audio clips annotated
with 200 sound event classes taken from the Audioset ontology. The dataset contains 100 hours of audio and is the
second largest publicly available general purpose sound event
recognition dataset after Audioset. Furthermore, the FSD50K
evaluation set is of high quality, with each evaluation label being double-checked and assessed by two to five independent annotators
# Setup
1. Download the dataset from [Zenodo](https://zenodo.org/record/4060432) and unzip it.
2. Convert wav files to mp3s:
```shell
cd fsd50k/prepare_scripts/
python convert_to_mp3.py path/to/fsd50k
```
this will create a folder inside the FSD50K directory with the mp3 files.
3. Pack the mp3 to HDF5 files:
```shell
cd fsd50k/prepare_scripts/
python create_h5pymp3_dataset.py path/to/fsd50k
```
Now you should have inside `../../audioset_hdf5s/mp3/` three new files: `FSD50K.eval_mp3.hdf`, `FSD50K.val_mp3.hdf`, `FSD50K.train_mp3.hdf`.
# Runing Experiments
Similar to the runs on Audioset, PaSST-S:
```shell
# Example call with all the default config:
python ex_fsd50k.py with trainer.precision=16 -p
```
```shell
# Example call without overlap:
python ex_fsd50k.py with passt_s_swa_p16_s16_128_ap473 models.net.s_patchout_t=10 models.net.s_patchout_f=1 trainer.precision=16 -p
```
# Pre-trained models
Pre-trained models on FSD50K can be found here [here](https://github.com/kkoutini/PaSST/releases/tag/v0.0.5).
In order to use the pre-trained models, for fine-tuning or inference, using a minimal dependencies, refer to the [PaSST-HEAR](https://github.com/kkoutini/passt_hear21), as an example after installing passt_hear21 :
```python
from hear21passt.base import get_basic_model,get_model_passt
import torch
# model wrapper, includes Melspectrogram and the default pre-trained transformer
model = get_basic_model(mode="logits")
# replace the transformer with one that outputs 200 classes
model.net = get_model_passt(arch="passt_s_swa_p16_128_ap476", n_classes=200)
# load the pre-trained model state dict with mAP of .655 on FSD50K
state_dict = torch.load('/home/khaled/fsd50k-passt-s-f128-p16-s10-ap.655.pt')
# load the weights into the transformer
model.net.load_state_dict(state_dict)
# example inference
model.eval()
model = model.cuda()
with torch.no_grad():
# audio_wave has the shape of [batch, seconds*32000] sampling rate is 32k
logits=model(audio_wave)
```
Using the model with no patch overlap PaSST-S-N `fsd50k-passt-s-n-f128-p16-s16-ap.642.pt`:
```python
# replace the transformer with one that outputs 200 classes
model.net = get_model_passt(arch="passt_s_p16_s16_128_ap468", fstride=16,
tstride=16, n_classes=200)
# load the pre-trained model state dict with mAP of .642 on FSD50K with no patch overlap
state_dict = torch.load('/home/khaled/fsd50k-passt-s-n-f128-p16-s16-ap.642.pt')
# load the weights into the transformer
model.net.load_state_dict(state_dict)
```
================================================
FILE: fsd50k/dataset.py
================================================
import io
import os
import pathlib
import random
import av
import librosa
import torchaudio
from torch.utils.data import Dataset as TorchDataset, ConcatDataset, DistributedSampler, WeightedRandomSampler
import torch
from ba3l.ingredients.datasets import Dataset
from sacred.config import DynamicIngredient, CMD
from scipy.signal import convolve
import numpy as np
from helpers.audiodatasets import PreprocessDataset
import h5py
LMODE = os.environ.get("LMODE", False)
# $TMPDIR
dataset = Dataset('audiodataset')
@dataset.config
def default_config():
name = 'audioset' # dataset name
normalize = False # normalize dataset
subsample = False # subsample squares from the dataset
roll = True # apply roll augmentation
fold = 1
base_dir = "audioset_hdf5s/" # base directory of the dataset, change it or make a link
if LMODE:
base_dir = "/system/user/publicdata/CP/audioset/audioset_hdf5s/"
balanced_train_hdf5 = base_dir + "mp3/FSD50K.train_mp3.hdf"
valid_hdf5 = base_dir + "mp3/FSD50K.val_mp3.hdf"
eval_hdf5 = base_dir + "mp3/FSD50K.eval_mp3.hdf"
if LMODE:
balanced_train_hdf5 = balanced_train_hdf5.replace(base_dir, os.environ.get("TMPDIR", base_dir) + "/")
eval_hdf5 = eval_hdf5.replace(base_dir, os.environ.get("TMPDIR", base_dir) + "/")
valid_hdf5 = valid_hdf5.replace(base_dir, os.environ.get("TMPDIR", base_dir) + "/")
ir_path = base_dir + "irs/"
num_of_classes = 200
if LMODE:
@dataset.config
def LMODE_default_config():
cache_root_path = "/system/user/publicdata/CP/DCASE/cached_datasets/"
def decode_mp3(mp3_arr):
"""
decodes an array if uint8 representing an mp3 file
:rtype: np.array
"""
container = av.open(io.BytesIO(mp3_arr.tobytes()))
stream = next(s for s in container.streams if s.type == 'audio')
# print(stream)
a = []
for i, packet in enumerate(container.demux(stream)):
for frame in packet.decode():
a.append(frame.to_ndarray().reshape(-1))
waveform = np.concatenate(a)
if waveform.dtype != 'float32':
raise RuntimeError("Unexpected wave type")
return waveform
def pad_or_truncate(x, audio_length):
"""Pad all audio to specific length."""
if audio_length is None:
# audio_length not specified don't do anything.
return x
if len(x) <= audio_length:
return np.concatenate((x, np.zeros(audio_length - len(x), dtype=np.float32)), axis=0)
else:
offset = torch.randint(0, len(x) - audio_length + 1, (1,)).item()
return x[offset:offset + audio_length]
irs_arr = None
@dataset.command
def get_ir_sample(ir_path, _run, ir_augment, cut_irs_offset=None):
if not ir_augment:
return
global irs_arr
if irs_arr is None:
all_paths = [path for path in pathlib.Path(os.path.expanduser(ir_path)).rglob('*.wav')]
all_paths = sorted(all_paths)
if cut_irs_offset is not None:
all_paths = all_paths[cut_irs_offset:cut_irs_offset + 10]
all_paths_name = [str(p).rsplit("/", 1)[-1] for p in all_paths]
print("will use these IRs:")
for i in range(len(all_paths_name)):
print(i, ": ", all_paths_name[i])
_run.info["ir_devices"] = all_paths_name
irs_arr = [librosa.load(p, sr=32000)[0] for p in all_paths]
return irs_arr[int(np.random.randint(0, len(irs_arr)))]
@dataset.command
def pydub_augment(waveform, gain_augment=7, ir_augment=0):
if ir_augment and torch.rand(1) < ir_augment:
ir = get_ir_sample()
waveform = convolve(waveform, ir, 'full')
if gain_augment:
gain = torch.randint(gain_augment * 2, (1,)).item() - gain_augment
amp = 10 ** (gain / 20)
waveform = waveform * amp
return waveform
class MixupDataset(TorchDataset):
""" Mixing Up wave forms
"""
def __init__(self, dataset, beta=2, rate=0.5):
self.beta = beta
self.rate = rate
self.dataset = dataset
print(f"Mixing up waveforms from dataset of len {len(dataset)}")
def __getitem__(self, index):
if torch.rand(1) < self.rate:
x1, f1, y1 = self.dataset[index]
idx2 = torch.randint(len(self.dataset), (1,)).item()
x2, f2, y2 = self.dataset[idx2]
l = np.random.beta(self.beta, self.beta)
l = max(l, 1. - l)
x1 = x1 - x1.mean()
x2 = x2 - x2.mean()
x = (x1 * l + x2 * (1. - l))
x = x - x.mean()
return x, f1, (y1 * l + y2 * (1. - l))
return self.dataset[index]
def __len__(self):
return len(self.dataset)
class AudioSetDataset(TorchDataset):
def __init__(self, hdf5_file, sample_rate=32000, classes_num=200, clip_length=10, augment=False, in_mem=False):
"""
Reads the mp3 bytes from HDF file decodes using av and returns a fixed length audio wav
"""
self.sample_rate = sample_rate
self.hdf5_file = hdf5_file
if in_mem:
print("\nPreloading in memory\n")
with open(hdf5_file, 'rb') as f:
self.hdf5_file = io.BytesIO(f.read())
with h5py.File(hdf5_file, 'r') as f:
self.length = len(f['audio_name'])
print(f"Dataset from {hdf5_file} with length {self.length}.")
self.dataset_file = None # lazy init
self.clip_length = clip_length
if clip_length is not None:
self.clip_length = clip_length * sample_rate
self.classes_num = classes_num
self.augment = augment
if augment:
print(f"Will agument data from {hdf5_file}")
def open_hdf5(self):
self.dataset_file = h5py.File(self.hdf5_file, 'r')
def __len__(self):
return self.length
def __del__(self):
if self.dataset_file is not None:
self.dataset_file.close()
self.dataset_file = None
def __getitem__(self, index):
"""Load waveform and target of an audio clip.
Args:
meta: {
'hdf5_path': str,
'index_in_hdf5': int}
Returns:
data_dict: {
'audio_name': str,
'waveform': (clip_samples,),
'target': (classes_num,)}
"""
if self.dataset_file is None:
self.open_hdf5()
audio_name = self.dataset_file['audio_name'][index].decode()
waveform = decode_mp3(self.dataset_file['mp3'][index])
if self.augment:
waveform = pydub_augment(waveform)
waveform = pad_or_truncate(waveform, self.clip_length)
waveform = self.resample(waveform)
target = self.dataset_file['target'][index]
target = np.unpackbits(target, axis=-1,
count=self.classes_num).astype(np.float32)
return waveform.reshape(1, -1), audio_name, target
def resample(self, waveform):
"""Resample.
Args:
waveform: (clip_samples,)
Returns:
(resampled_clip_samples,)
"""
if self.sample_rate == 32000:
return waveform
elif self.sample_rate == 16000:
return waveform[0:: 2]
elif self.sample_rate == 8000:
return waveform[0:: 4]
else:
raise Exception('Incorrect sample rate!')
@dataset.command
def get_base_training_set(balanced_train_hdf5, clip_length=10):
ds = AudioSetDataset(balanced_train_hdf5, augment=True, clip_length=clip_length)
return ds
@dataset.command
def preload_mp3(balanced_train_hdf5, unbalanced_train_hdf5, num_of_classes):
# Preload mp3 sequential from disk, OS will cache the chunks in memory.
# Useful if the hdf file is on a NFS mount, saving the random access.
for hdf5_file in [balanced_train_hdf5, unbalanced_train_hdf5]:
print(f"\n \n will now preload {hdf5_file} \n\n ")
with h5py.File(hdf5_file, 'r') as dataset_file:
target = dataset_file['mp3'][:]
print(len(target))
print(f"\n \n done with {hdf5_file} \n\n ")
return target[1000]
@dataset.command
def get_ft_weighted_sampler(samples_weights=CMD(".get_ft_cls_balanced_sample_weights"),
epoch_len=100000, sampler_replace=False):
num_nodes = int(os.environ.get('num_nodes', 1))
ddp = int(os.environ.get('DDP', 1))
num_nodes = max(ddp, num_nodes)
print("num_nodes= ", num_nodes)
rank = int(os.environ.get('NODE_RANK', 0))
return DistributedSamplerWrapper(sampler=WeightedRandomSampler(samples_weights,
num_samples=epoch_len, replacement=sampler_replace),
dataset=range(epoch_len),
num_replicas=num_nodes,
rank=rank,
)
@dataset.command
def get_base_eval_set(eval_hdf5, variable_eval=None):
if variable_eval:
print("Variable length eval!!")
ds = AudioSetDataset(eval_hdf5, clip_length=None)
else:
ds = AudioSetDataset(eval_hdf5)
return ds
@dataset.command
def get_base_valid_set(valid_hdf5, variable_eval=None):
if variable_eval:
print("Variable length valid_set !!")
ds = AudioSetDataset(valid_hdf5, clip_length=None)
else:
ds = AudioSetDataset(valid_hdf5)
return ds
@dataset.command(prefix='roll_conf')
def get_roll_func(axis=1, shift=None, shift_range=50):
print("rolling...")
def roll_func(b):
x, i, y = b
x = torch.as_tensor(x)
sf = shift
if shift is None:
sf = int(np.random.random_integers(-shift_range, shift_range))
global FirstTime
return x.roll(sf, axis), i, y
return roll_func
@dataset.command
def get_training_set(normalize, roll, wavmix=False):
ds = get_base_training_set()
get_ir_sample()
if normalize:
print("normalized train!")
fill_norms()
ds = PreprocessDataset(ds, norm_func)
if roll:
ds = PreprocessDataset(ds, get_roll_func())
if wavmix:
ds = MixupDataset(ds)
return ds
@dataset.command
def get_valid_set(normalize):
ds = get_base_valid_set()
if normalize:
print("normalized test!")
fill_norms()
ds = PreprocessDataset(ds, norm_func)
return ds
@dataset.command
def get_eval_set(normalize):
ds = get_base_eval_set()
if normalize:
print("normalized test!")
fill_norms()
ds = PreprocessDataset(ds, norm_func)
return ds
@dataset.command
def print_conf(_config):
print("Config of ", dataset.path, id(dataset))
print(_config)
print()
class DistributedSamplerWrapper(DistributedSampler):
def __init__(
self, sampler, dataset,
num_replicas=None,
rank=None,
shuffle: bool = True):
super(DistributedSamplerWrapper, self).__init__(
dataset, num_replicas, rank, shuffle)
# source: @awaelchli https://github.com/PyTorchLightning/pytorch-lightning/issues/3238
self.sampler = sampler
def __iter__(self):
if self.sampler.generator is None:
self.sampler.generator = torch.Generator()
self.sampler.generator.manual_seed(self.seed + self.epoch)
indices = list(self.sampler)
if self.epoch == 0:
print(f"\n DistributedSamplerWrapper : {indices[:10]} \n\n")
indices = indices[self.rank:self.total_size:self.num_replicas]
return iter(indices)
if __name__ == "__main__":
from sacred import Experiment
ex = Experiment("test_dataset", ingredients=[dataset])
@ex.automain
def default_command():
ex.current_run.get_command_function("print_config")()
get_base_training_set()
ds = get_test_set()
print(ds[0])
ds = get_training_set()
print(ds[0])
print("get_base_training_set", len(get_base_training_set()))
print("get_base_test_set", len(get_base_test_set()))
print("get_training_set", len(get_training_set()))
print("get_test_set", len(get_test_set()))
================================================
FILE: fsd50k/prepare_scripts/convert_to_mp3.py
================================================
import multiprocessing
import glob
import os
import sys
if len(sys.argv) > 1:
FSD50K_base = sys.argv[1] # the path to of FSD50K base as downloaded from zalando.
else:
FSD50K_base = "/home/khaled/shared/FSD50K/" # the path to of FSD50K base as downloaded from zalando.
print("Pass the path to FSD50K: python convert_to_mp3.py path/to/fsd50k")
outputp = FSD50K_base + "/mp3/" # the path to the output mp3.
all_num = 0
def process_folder(fol="balanced_train_segments"):
print("now working on ", fol)
os.makedirs(outputp + fol, exist_ok=True)
all_files = list(glob.glob(FSD50K_base + fol + "/*.wav"))
print(f"it has {len(all_files)}")
global all_num
all_num = len(all_files)
cmds = [(i, file, outputp + fol + "/" + os.path.basename(file)[:-3]) for i, file in enumerate(all_files)]
print(cmds[0])
with multiprocessing.Pool(processes=20) as pool:
pool.starmap(process_one, cmds)
def process_one(i, f1, f2):
if i % 100 == 0:
print(f"{i}/{all_num} \t", f1)
os.system(f"ffmpeg -hide_banner -nostats -loglevel error -n -i {f1} -codec:a mp3 -ar 32000 {f2}mp3")
print("We will convert the following folders to mp3: ")
folders = ['FSD50K.eval_audio', 'FSD50K.dev_audio']
print(folders)
for fol in folders:
process_folder(fol)
================================================
FILE: fsd50k/prepare_scripts/create_h5pymp3_dataset.py
================================================
# %%
import sys
import h5py
import pandas as pd
import numpy as np
import csv
import os
# %%
from numpy import dtype
if len(sys.argv) > 1:
FSD50K_base = sys.argv[1] # the path to of FSD50K base as downloaded from zalando.
else:
FSD50K_base = "/home/khaled/shared/FSD50K/" # the path to of FSD50K base as downloaded from zalando.
print("Pass the path to FSD50K: python convert_to_mp3.py path/to/fsd50k")
base_dir = "../../audioset_hdf5s/" # the path to store hdf file.
####
balanced_csv = FSD50K_base + "FSD50K.ground_truth/dev.csv"
eval_csv = FSD50K_base + "FSD50K.ground_truth/eval.csv"
class_idx_csv = FSD50K_base + "FSD50K.ground_truth/vocabulary.csv"
mp3_path = "/home/khaled/shared/FSD50K/mp3/"
# %%
df = pd.read_csv(class_idx_csv, header=None, index_col=0)
classes_list = list(df[1].values)
assert sorted(classes_list) == classes_list
id_to_ix = {id: i for i, id in enumerate(classes_list)}
ix_to_id = {i: id for i, id in enumerate(classes_list)}
# %%
# Load labels
df = pd.read_csv(balanced_csv)
train = df[df.split == "train"]
val = df[df.split == "val"]
eval = pd.read_csv(eval_csv)
# %%
def get_labels(df):
y = np.zeros((len(df), 200), dtype=np.int32)
for i, target in enumerate(df.labels.values):
for t in target.split(","):
y[i, id_to_ix[t]] = 1
return df.fname.values, y
# %%
for set_name, df, prefix in [("train", train, "FSD50K.dev_audio/"), ("val", val, "FSD50K.dev_audio/"),
("eval", eval, "FSD50K.eval_audio/")]:
print("now working on ", set_name, prefix, "len=", len(df))
# files, y = torch.load(read_file+".pth")
files, y = get_labels(df)
y = np.packbits(y, axis=-1)
packed_len = y.shape[1]
print(files[0], "classes: ", packed_len, y.dtype)
available_size = len(files)
dt = h5py.vlen_dtype(np.dtype('uint8'))
save_file = "FSD50K." + set_name
if os.path.isfile(base_dir + "mp3/" + save_file + "_mp3.hdf"):
print(base_dir + "mp3/" + save_file + "_mp3.hdf", "exists!\n\n\n contiue")
continue
with h5py.File(base_dir + "mp3/" + save_file + "_mp3.hdf", 'w') as hf:
audio_name = hf.create_dataset('audio_name', shape=((available_size,)), dtype='S20')
waveform = hf.create_dataset('mp3', shape=((available_size,)), dtype=dt)
target = hf.create_dataset('target', shape=((available_size, packed_len)), dtype=y.dtype)
for i, file in enumerate(files):
if i % 1000 == 0:
print(f"{i}/{available_size}")
f = f"{file}.mp3"
a = np.fromfile(mp3_path + prefix + f, dtype='uint8')
audio_name[i] = f
waveform[i] = a
target[i] = y[i]
print(a.shape)
print("Done!", prefix)
================================================
FILE: helpers/audiodatasets.py
================================================
import hashlib
import os
import time
import torch
from torch.utils.data import Dataset
from os.path import expanduser
import logging
def h6(w):
return hashlib.md5(w.encode('utf-8')).hexdigest()[:6]
class AudioPreprocessDataset(Dataset):
"""A bases preprocessing dataset representing a Dataset of files that are loaded and preprossessed on the fly.
Access elements via __getitem__ to return: preprocessor(x),sample_id,label
supporting integer indexing in range from 0 to len(self) exclusive.
"""
def __init__(self, files, labels, label_encoder, base_dir, preprocessor, return_tensor=True, ordered_ids=None):
self.files = files
if ordered_ids is None:
ordered_ids = files
else:
print("AudioPreprocessDataset: ordered_ids is not None using it instead of files !!!")
self.ordered_ids = ordered_ids
self.labels = labels
self.label_encoder = label_encoder
self.base_dir = base_dir
self.preprocessor = preprocessor
self.return_tensor = return_tensor
def __getitem__(self, index):
x = self.preprocessor(self.base_dir + self.files[index])
if self.return_tensor and not isinstance(x, torch.Tensor):
x = torch.from_numpy(x)
return x, self.ordered_ids[index], self.labels[index]
def get_ordered_ids(self):
return self.ordered_ids
def get_ordered_labels(self):
return self.labels
def __len__(self):
return len(self.ordered_ids)
class ObjectCacher:
def __init__(self, get_obj_func, dataset_name, obj_name="",
cache_path="~/shared/kofta_cached_datasets/", verbose=True):
self.dataset_name = dataset_name
self.obj_name = obj_name
cache_path = expanduser(cache_path)
self.cache_path = os.path.join(cache_path, dataset_name)
try:
startTime = time.time()
xpath = self.get_obj_cache_path()
if verbose:
logging.info(
"attempting to load x from cache at " + xpath + "...")
self.obj = torch.load(xpath)
if verbose:
endTime = time.time()
logging.info(
"loaded " + xpath + " from cache in %s " % (endTime - startTime))
except IOError:
if verbose:
logging.info(
"Invalid cache " + xpath + " , recomputing")
self.obj = get_obj_func()
saveStartTime = time.time()
dirpath=os.path.dirname(xpath)
try:
original_umask = os.umask(0)
os.makedirs(dirpath, exist_ok=True)
finally:
os.umask(original_umask)
torch.save(self.obj, xpath)
if verbose:
endTime = time.time()
logging.info(
"loaded " + obj_name + " in %s, and cached in %s, total %s seconds " % (
(saveStartTime - startTime),
(endTime - saveStartTime), (endTime - startTime)))
def get_obj_cache_path(self):
return os.path.join(self.cache_path, self.obj_name + "_obj.pt")
def get(self):
return self.obj
class PreprocessDataset(Dataset):
"""A bases preprocessing dataset representing a preprocessing step of a Dataset preprossessed on the fly.
supporting integer indexing in range from 0 to len(self) exclusive.
"""
def __init__(self, dataset, preprocessor):
self.dataset = dataset
if not callable(preprocessor):
print("preprocessor: ", preprocessor)
raise ValueError('preprocessor should be callable')
self.preprocessor = preprocessor
def __getitem__(self, index):
return self.preprocessor(self.dataset[index])
def __len__(self):
return len(self.dataset)
class FilesCachedDataset(Dataset):
def __init__(self, get_dataset_func, dataset_name, x_name="",
cache_path="~/shared/kofta_cached_datasets/",
):
"""
Cached the dataset in small torch.save files (1 file per sample).
The dataset is suitable for SSDs being used bcache from a slow harddrive with a small
@param get_dataset_func: fuction gets called if the file cache is invalid
@param dataset_name: the folder containing the dataset
@param x_name: tag for the version
@param cache_path: cache_path
"""
self.dataset = None
def getDataset():
if self.dataset == None:
self.dataset = get_dataset_func()
return self.dataset
self.get_dataset_func = getDataset
self.x_name = x_name
cache_path = expanduser(cache_path)
self.cache_path = os.path.join(cache_path, dataset_name, "files_cache", self.x_name)
try:
original_umask = os.umask(0)
os.makedirs(self.cache_path, exist_ok=True)
finally:
os.umask(original_umask)
def __getitem__(self, index):
cpath = os.path.join(self.cache_path, str(index) + ".pt")
try:
return torch.load(cpath)
except FileNotFoundError:
tup = self.get_dataset_func()[index]
torch.save(tup, cpath)
return tup
def get_ordered_labels(self):
return self.get_dataset_func().get_ordered_labels()
def get_ordered_ids(self):
return self.get_dataset_func().get_ordered_ids()
def get_xcache_path(self):
return os.path.join(self.cache_path, self.x_name + "_x.pt")
def get_ycache_path(self):
return os.path.join(self.cache_path, self.y_name + "_y.pt")
def get_sidcache_path(self):
return os.path.join(self.cache_path, self.y_name + "_sid.pt")
def __len__(self):
return len(self.get_dataset_func())
class SelectionDataset(Dataset):
"""A dataset that selects a subsample from a dataset based on a set of sample ids.
supporting integer indexing in range from 0 to len(self) exclusive.
"""
def __init__(self, dataset, sample_ids):
self.available_indexes = []
self.dataset = dataset
self.reselect(sample_ids)
self.sample_ids = sample_ids
def reselect(self, sample_ids):
reverse_dict = dict([(sid, i) for i, sid in enumerate(self.dataset.get_ordered_ids())])
self.available_indexes = [reverse_dict[sid] for sid in sample_ids]
def get_ordered_ids(self):
return self.sample_ids
def get_ordered_labels(self):
labels=self.dataset.get_ordered_labels()
return [labels[i] for i in self.available_indexes]
#raise NotImplementedError("Maybe reconsider caching only a selection Dataset. why not select after cache?")
def __getitem__(self, index):
return self.dataset[self.available_indexes[index]]
def __len__(self):
return len(self.available_indexes)
class SimpleSelectionDataset(Dataset):
"""A dataset that selects a subsample from a dataset based on a set of sample ids.
supporting integer indexing in range from 0 to len(self) exclusive.
"""
def __init__(self, dataset, available_indexes ):
self.available_indexes = available_indexes
self.dataset = dataset
def __getitem__(self, index):
return self.dataset[self.available_indexes[index]]
def __len__(self):
return len(self.available_indexes)
================================================
FILE: helpers/mixup.py
================================================
import numpy as np
import torch
def my_mixup(size, alpha):
rn_indices = torch.randperm(size)
lambd = np.random.beta(alpha, alpha, size).astype(np.float32)
lambd = np.concatenate([lambd[:, None], 1 - lambd[:, None]], 1).max(1)
lam = torch.FloatTensor(lambd)
# data = data * lam + data2 * (1 - lam)
# targets = targets * lam + targets2 * (1 - lam)
return rn_indices, lam
================================================
FILE: helpers/models_size.py
================================================
def count_non_zero_params(model):
sum_params = 0
sum_non_zero = 0
desc = ""
def calc_params(model):
nonlocal desc, sum_params, sum_non_zero
skip = ""
if "batchnorm" in type(model).__name__.lower():
for k,p in [("running_mean", model.running_mean), ("running_var", model.running_var)]:
nonzero = p[p != 0].numel()
total = p.numel()
desc += f"type {type(model).__name__}, {k}, {total}, {nonzero}, {p.dtype}, {skip} " + "\n"
if skip != "skip":
sum_params += total
sum_non_zero += nonzero
for k, p in model.named_parameters(recurse=False):
nonzero = p[p != 0].numel()
total = p.numel()
desc += f"type {type(model).__name__}, {k}, {total}, {nonzero}, {p.dtype}, {skip} " + "\n"
if skip != "skip":
sum_params += total
sum_non_zero += nonzero
model.apply(calc_params)
return desc, sum_params, sum_non_zero
================================================
FILE: helpers/ramp.py
================================================
import numpy as np
from ba3l.ingredients.ingredient import Ingredient
# credit: https://github.com/iBelieveCJM/Tricks-of-Semi-supervisedDeepLeanring-Pytorch/blob/master/utils/ramps.py
def pseudo_rampup(T1, T2):
def warpper(epoch):
if epoch > T1:
alpha = (epoch - T1) / (T2 - T1)
if epoch > T2:
alpha = 1.0
else:
alpha = 0.0
return alpha
return warpper
def exp_rampup(rampup_length):
"""Exponential rampup from https://arxiv.org/abs/1610.02242"""
def warpper(epoch):
if epoch < rampup_length:
epoch = np.clip(epoch, 0.5, rampup_length)
phase = 1.0 - epoch / rampup_length
return float(np.exp(-5.0 * phase * phase))
else:
return 1.0
return warpper
def linear_rampup(rampup_length):
"""Linear rampup"""
def warpper(epoch):
if epoch < rampup_length:
return epoch / rampup_length
else:
return 1.0
return warpper
def linear_rampdown(rampdown_length, start=0, last_value=0):
"""Linear rampup -(start)- (rampdown_length) \ _(for the rest) """
def warpper(epoch):
if epoch <= start:
return 1.
elif epoch - start < rampdown_length:
return last_value + (1. - last_value) * (rampdown_length - epoch + start) / rampdown_length
else:
return last_value
return warpper
def exp_rampdown(rampdown_length, num_epochs):
"""Exponential rampdown from https://arxiv.org/abs/1610.02242"""
def warpper(epoch):
if epoch >= (num_epochs - rampdown_length):
ep = .5 * (epoch - (num_epochs - rampdown_length))
return float(np.exp(-(ep * ep) / rampdown_length))
else:
return 1.0
return warpper
def cosine_rampdown(rampdown_length, num_epochs):
"""Cosine rampdown from https://arxiv.org/abs/1608.03983"""
def warpper(epoch):
if epoch >= (num_epochs - rampdown_length):
ep = .5 * (epoch - (num_epochs - rampdown_length))
return float(.5 * (np.cos(np.pi * ep / rampdown_length) + 1))
else:
return 1.0
return warpper
def exp_warmup(rampup_length, rampdown_length, num_epochs):
rampup = exp_rampup(rampup_length)
rampdown = exp_rampdown(rampdown_length, num_epochs)
def warpper(epoch):
return rampup(epoch) * rampdown(epoch)
return warpper
def exp_warmup_linear_down(warmup, rampdown_length, start_rampdown, last_value):
rampup = exp_rampup(warmup)
rampdown = linear_rampdown(rampdown_length, start_rampdown, last_value)
def warpper(epoch):
return rampup(epoch) * rampdown(epoch)
return warpper
def test_warmup():
warmup = exp_warmup(20, 100, 150)
for ep in range(500):
print(warmup(ep))
def test_warmupl():
warmup = exp_warmup_linear_down(20, 100, 50, 0.001)
for ep in range(500):
print(warmup(ep))
def cosine_cycle(cycle_len=20,ramp_down_start=100,last_lr_value=0.01):
"""Cosine rampdown from https://arxiv.org/abs/1608.03983"""
ramp_down_start = cycle_len+ (ramp_down_start-1)//cycle_len*(cycle_len)
print("adjusted ramp_down_start:",ramp_down_start)
def warpper(epoch):
ep = (epoch+cycle_len//2.)/(1.*cycle_len)
if epoch>ramp_down_start:
return last_lr_value
return float(last_lr_value + (1.-last_lr_value)* .5 * (np.cos(2.*np.pi * ep) + 1))
return warpper
if __name__ == '__main__':
test= exp_warmup_linear_down(20, 100, 50, 150)
for i in range(250):
print(test(i))
================================================
FILE: helpers/swa_callback.py
================================================
# Adapted from PyTorch Lightning so that it only does the averaging
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
r"""
Stochastic Weight Averaging Callback
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
"""
from copy import deepcopy
from typing import Callable, Optional, Union
import torch
from torch import nn
import pytorch_lightning as pl
from pytorch_lightning.callbacks import Callback
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from torch.optim.swa_utils import SWALR
_AVG_FN = Callable[[torch.Tensor, torch.Tensor, torch.LongTensor], torch.FloatTensor]
class StochasticWeightAveraging(Callback):
def __init__(
self,
swa_epoch_start: Union[int, float] = 0.8,
swa_freq: Union[int, float] = 3,
swa_lrs: Optional[Union[float, list]] = None,
annealing_epochs: int = 10,
annealing_strategy: str = "cos",
avg_fn: Optional[_AVG_FN] = None,
device: Optional[Union[torch.device, str]] = None,
):
r"""
Implements the Stochastic Weight Averaging (SWA) Callback to average a model.
Stochastic Weight Averaging was proposed in ``Averaging Weights Leads to
Wider Optima and Better Generalization`` by Pavel Izmailov, Dmitrii
Podoprikhin, Timur Garipov, Dmitry Vetrov and Andrew Gordon Wilson
(UAI 2018).
This documentation is highly inspired by PyTorch's work on SWA.
The callback arguments follow the scheme defined in PyTorch's ``swa_utils`` package.
For a SWA explanation, please take a look
`here `_.
.. warning:: ``StochasticWeightAveraging`` is in beta and subject to change.
.. warning:: ``StochasticWeightAveraging`` is currently not supported for multiple optimizers/schedulers.
SWA can easily be activated directly from the Trainer as follow:
.. code-block:: python
Trainer(stochastic_weight_avg=True)
Arguments:
swa_epoch_start: If provided as int, the procedure will start from
the ``swa_epoch_start``-th epoch. If provided as float between 0 and 1,
the procedure will start from ``int(swa_epoch_start * max_epochs)`` epoch
swa_lrs: the learning rate value for all param groups together or separately for each group.
annealing_epochs: number of epochs in the annealing phase (default: 10)
annealing_strategy: Specifies the annealing strategy (default: "cos"):
- ``"cos"``. For cosine annealing.
- ``"linear"`` For linear annealing
avg_fn: the averaging function used to update the parameters;
the function must take in the current value of the
:class:`AveragedModel` parameter, the current value of :attr:`model`
parameter and the number of models already averaged; if None,
equally weighted average is used (default: ``None``)
device: if provided, the averaged model will be stored on the ``device``.
When None is provided, it will infer the `device` from ``pl_module``.
(default: ``"cpu"``)
"""
err_msg = "swa_epoch_start should be a >0 integer or a float between 0 and 1."
if isinstance(swa_epoch_start, int) and swa_epoch_start < 1:
raise MisconfigurationException(err_msg)
if isinstance(swa_epoch_start, float) and not (0 <= swa_epoch_start <= 1):
raise MisconfigurationException(err_msg)
wrong_type = not isinstance(swa_lrs, (float, list))
wrong_float = isinstance(swa_lrs, float) and swa_lrs <= 0
wrong_list = isinstance(swa_lrs, list) and not all(lr > 0 and isinstance(lr, float) for lr in swa_lrs)
if (swa_lrs is not None and (wrong_type or wrong_float or wrong_list)):
raise MisconfigurationException("The `swa_lrs` should be a positive float or a list of positive float.")
if avg_fn is not None and not isinstance(avg_fn, Callable):
raise MisconfigurationException("The `avg_fn` should be callable.")
if device is not None and not isinstance(device, (torch.device, str)):
raise MisconfigurationException(f"device is expected to be a torch.device or a str. Found {device}")
self.swa_freq = swa_freq
self._swa_epoch_start = swa_epoch_start
self._swa_lrs = swa_lrs
self._annealing_epochs = annealing_epochs
self._annealing_strategy = annealing_strategy
self._avg_fn = avg_fn or self.avg_fn
self._device = device
self._model_contains_batch_norm = None
self._average_model = None
@property
def swa_start(self) -> int:
return max(self._swa_epoch_start - 1, 0) # 0-based
@property
def swa_end(self) -> int:
return self._max_epochs - 1 # 0-based
@staticmethod
def pl_module_contains_batch_norm(pl_module: 'pl.LightningModule'):
return any(isinstance(module, nn.modules.batchnorm._BatchNorm) for module in pl_module.modules())
def setup(self, trainer: 'pl.Trainer', pl_module: 'pl.LightningModule', stage:str):
# copy the model before moving it to accelerator device.
self._average_model = deepcopy(pl_module.net)
def on_fit_start(self, trainer: 'pl.Trainer', pl_module: 'pl.LightningModule'):
if len(trainer.optimizers) != 1:
raise MisconfigurationException("SWA currently works with 1 `optimizer`.")
if len(trainer.lr_scheduler_configs) > 1:
raise MisconfigurationException("SWA currently not supported for more than 1 `lr_scheduler`.")
if isinstance(self._swa_epoch_start, float):
self._swa_epoch_start = int(trainer.max_epochs * self._swa_epoch_start)
self._model_contains_batch_norm = self.pl_module_contains_batch_norm(pl_module)
self._max_epochs = trainer.max_epochs
if self._model_contains_batch_norm:
print("\n\n_model_contains_batch_norm\n\n")
# virtually increase max_epochs to perform batch norm update on latest epoch.
trainer.max_epochs += 1
def on_train_epoch_start(self, trainer: 'pl.Trainer', pl_module: 'pl.LightningModule'):
if trainer.current_epoch == self.swa_start:
print(f"\n\n SWA START at {trainer.current_epoch}\n\n")
# move average model to request device.
self._average_model = self._average_model.to(self._device or pl_module.device)
optimizers = trainer.optimizers
for param_group in optimizers[0].param_groups:
if self._swa_lrs is None:
initial_lr = param_group["lr"]
elif isinstance(self._swa_lrs, float):
initial_lr = self._swa_lrs
else:
initial_lr = self._swa_lrs[0]
param_group["initial_lr"] = initial_lr
self._swa_lrs = initial_lr
self._swa_scheduler = SWALR(
optimizers[0],
swa_lr=initial_lr,
anneal_epochs=self._annealing_epochs,
anneal_strategy=self._annealing_strategy,
last_epoch=trainer.max_epochs if self._annealing_strategy == "cos" else -1
)
self.n_averaged = torch.tensor(0, dtype=torch.long, device=pl_module.device)
pl_module.net_swa = self._average_model
if (self.swa_start <= trainer.current_epoch <= self.swa_end) and ((trainer.current_epoch - self.swa_start) % self.swa_freq == 0):
self.update_parameters(self._average_model, pl_module.net, self.n_averaged, self.avg_fn)
pl_module.net_swa = self._average_model
def on_train_epoch_end(self, trainer: 'pl.Trainer',pl_module: 'pl.LightningModule', *args):
trainer.fit_loop._skip_backward = False
def on_train_end(self, trainer: 'pl.Trainer', pl_module: 'pl.LightningModule'):
pass
def on_validation_epoch_start(self, trainer, pl_module) -> None:
"""Called when the val epoch begins."""
if (self.swa_start <= trainer.current_epoch <= self.swa_end) and ((trainer.current_epoch - self.swa_start) % self.swa_freq == 0):
pl_module.do_swa= True
else:
pl_module.do_swa = False
@staticmethod
def transfer_weights(src_pl_module: 'pl.LightningModule', dst_pl_module: 'pl.LightningModule'):
for src_param, dst_param in zip(src_pl_module.parameters(), dst_pl_module.parameters()):
dst_param.detach().copy_(src_param.to(dst_param.device))
def reset_batch_norm_and_save_state(self, pl_module):
"""
Adapted from https://github.com/pytorch/pytorch/blob/v1.7.1/torch/optim/swa_utils.py#L140-L154
"""
self.momenta = {}
for module in pl_module.modules():
if not isinstance(module, nn.modules.batchnorm._BatchNorm):
continue
module.running_mean = torch.zeros_like(
module.running_mean, device=pl_module.device, dtype=module.running_mean.dtype
)
module.running_var = torch.ones_like(
module.running_var, device=pl_module.device, dtype=module.running_var.dtype
)
self.momenta[module] = module.momentum
module.momentum = None
module.num_batches_tracked *= 0
def reset_momenta(self):
"""
Adapted from https://github.com/pytorch/pytorch/blob/v1.7.1/torch/optim/swa_utils.py#L164-L165
"""
for bn_module in self.momenta.keys():
bn_module.momentum = self.momenta[bn_module]
@staticmethod
def update_parameters(
average_model, model, n_averaged: torch.LongTensor, avg_fn: _AVG_FN
):
"""
Adapted from https://github.com/pytorch/pytorch/blob/v1.7.1/torch/optim/swa_utils.py#L104-L112
"""
for p_swa, p_model in zip(average_model.parameters(), model.parameters()):
device = p_swa.device
p_swa_ = p_swa.detach()
p_model_ = p_model.detach().to(device)
src = p_model_ if n_averaged == 0 else avg_fn(p_swa_, p_model_, n_averaged.to(device))
p_swa_.copy_(src)
n_averaged += 1
@staticmethod
def avg_fn(
averaged_model_parameter: torch.Tensor, model_parameter: torch.Tensor, num_averaged: torch.LongTensor
) -> torch.FloatTensor:
"""
Adapted from https://github.com/pytorch/pytorch/blob/v1.7.1/torch/optim/swa_utils.py#L95-L97
"""
return averaged_model_parameter + (model_parameter - averaged_model_parameter) / (num_averaged + 1)
================================================
FILE: helpers/swa_legacy.py
================================================
# Adapted from PyTorch Lightning so that it only does the averaging
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
r"""
Stochastic Weight Averaging Callback
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
"""
from copy import deepcopy
from typing import Callable, Optional, Union
import torch
from torch import nn
import pytorch_lightning as pl
from pytorch_lightning.callbacks.base import Callback
from pytorch_lightning.trainer.optimizers import _get_default_scheduler_config
from pytorch_lightning.utilities import rank_zero_warn
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from torch.optim.swa_utils import SWALR
_AVG_FN = Callable[[torch.Tensor, torch.Tensor,
torch.LongTensor], torch.FloatTensor]
class StochasticWeightAveraging(Callback):
def __init__(
self,
swa_epoch_start: Union[int, float] = 0.8,
swa_freq: Union[int, float] = 3,
swa_lrs: Optional[Union[float, list]] = None,
annealing_epochs: int = 10,
annealing_strategy: str = "cos",
avg_fn: Optional[_AVG_FN] = None,
device: Optional[Union[torch.device, str]] = None,
):
r"""
Implements the Stochastic Weight Averaging (SWA) Callback to average a model.
Stochastic Weight Averaging was proposed in ``Averaging Weights Leads to
Wider Optima and Better Generalization`` by Pavel Izmailov, Dmitrii
Podoprikhin, Timur Garipov, Dmitry Vetrov and Andrew Gordon Wilson
(UAI 2018).
This documentation is highly inspired by PyTorch's work on SWA.
The callback arguments follow the scheme defined in PyTorch's ``swa_utils`` package.
For a SWA explanation, please take a look
`here `_.
.. warning:: ``StochasticWeightAveraging`` is in beta and subject to change.
.. warning:: ``StochasticWeightAveraging`` is currently not supported for multiple optimizers/schedulers.
SWA can easily be activated directly from the Trainer as follow:
.. code-block:: python
Trainer(stochastic_weight_avg=True)
Arguments:
swa_epoch_start: If provided as int, the procedure will start from
the ``swa_epoch_start``-th epoch. If provided as float between 0 and 1,
the procedure will start from ``int(swa_epoch_start * max_epochs)`` epoch
swa_lrs: the learning rate value for all param groups together or separately for each group.
annealing_epochs: number of epochs in the annealing phase (default: 10)
annealing_strategy: Specifies the annealing strategy (default: "cos"):
- ``"cos"``. For cosine annealing.
- ``"linear"`` For linear annealing
avg_fn: the averaging function used to update the parameters;
the function must take in the current value of the
:class:`AveragedModel` parameter, the current value of :attr:`model`
parameter and the number of models already averaged; if None,
equally weighted average is used (default: ``None``)
device: if provided, the averaged model will be stored on the ``device``.
When None is provided, it will infer the `device` from ``pl_module``.
(default: ``"cpu"``)
"""
err_msg = "swa_epoch_start should be a >0 integer or a float between 0 and 1."
if isinstance(swa_epoch_start, int) and swa_epoch_start < 1:
raise MisconfigurationException(err_msg)
if isinstance(swa_epoch_start, float) and not (0 <= swa_epoch_start <= 1):
raise MisconfigurationException(err_msg)
wrong_type = not isinstance(swa_lrs, (float, list))
wrong_float = isinstance(swa_lrs, float) and swa_lrs <= 0
wrong_list = isinstance(swa_lrs, list) and not all(
lr > 0 and isinstance(lr, float) for lr in swa_lrs)
if (swa_lrs is not None and (wrong_type or wrong_float or wrong_list)):
raise MisconfigurationException(
"The `swa_lrs` should be a positive float or a list of positive float.")
if avg_fn is not None and not isinstance(avg_fn, Callable):
raise MisconfigurationException("The `avg_fn` should be callable.")
if device is not None and not isinstance(device, (torch.device, str)):
raise MisconfigurationException(
f"device is expected to be a torch.device or a str. Found {device}")
self.swa_freq = swa_freq
self._swa_epoch_start = swa_epoch_start
self._swa_lrs = swa_lrs
self._annealing_epochs = annealing_epochs
self._annealing_strategy = annealing_strategy
self._avg_fn = avg_fn or self.avg_fn
self._device = device
self._model_contains_batch_norm = None
self._average_model = None
@property
def swa_start(self) -> int:
return max(self._swa_epoch_start - 1, 0) # 0-based
@property
def swa_end(self) -> int:
return self._max_epochs - 1 # 0-based
@staticmethod
def pl_module_contains_batch_norm(pl_module: 'pl.LightningModule'):
return any(isinstance(module, nn.modules.batchnorm._BatchNorm) for module in pl_module.modules())
def on_before_accelerator_backend_setup(self, trainer: 'pl.Trainer', pl_module: 'pl.LightningModule'):
# copy the model before moving it to accelerator device.
self._average_model = deepcopy(pl_module.net)
def on_fit_start(self, trainer: 'pl.Trainer', pl_module: 'pl.LightningModule'):
optimizers = trainer.optimizers
lr_schedulers = trainer.lr_schedulers
if len(optimizers) != 1:
raise MisconfigurationException(
"SWA currently works with 1 `optimizer`.")
if len(lr_schedulers) > 1:
raise MisconfigurationException(
"SWA currently not supported for more than 1 `lr_scheduler`.")
if isinstance(self._swa_epoch_start, float):
self._swa_epoch_start = int(
trainer.max_epochs * self._swa_epoch_start)
self._model_contains_batch_norm = self.pl_module_contains_batch_norm(
pl_module)
self._max_epochs = trainer.max_epochs
if self._model_contains_batch_norm:
print("\n\n_model_contains_batch_norm\n\n")
# virtually increase max_epochs to perform batch norm update on latest epoch.
trainer.max_epochs += 1
def on_train_epoch_start(self, trainer: 'pl.Trainer', pl_module: 'pl.LightningModule'):
if trainer.current_epoch == self.swa_start:
print(f"\n\n SWA START at {trainer.current_epoch}\n\n")
# move average model to request device.
self._average_model = self._average_model.to(
self._device or pl_module.device)
optimizers = trainer.optimizers
for param_group in optimizers[0].param_groups:
if self._swa_lrs is None:
initial_lr = param_group["lr"]
elif isinstance(self._swa_lrs, float):
initial_lr = self._swa_lrs
else:
initial_lr = self._swa_lrs[0]
param_group["initial_lr"] = initial_lr
self._swa_lrs = initial_lr
self._swa_scheduler = SWALR(
optimizers[0],
swa_lr=initial_lr,
anneal_epochs=self._annealing_epochs,
anneal_strategy=self._annealing_strategy,
last_epoch=trainer.max_epochs if self._annealing_strategy == "cos" else -1
)
self.n_averaged = torch.tensor(
0, dtype=torch.long, device=pl_module.device)
pl_module.net_swa = self._average_model
if (self.swa_start <= trainer.current_epoch <= self.swa_end) and ((trainer.current_epoch - self.swa_start) % self.swa_freq == 0):
self.update_parameters(self._average_model,
pl_module.net, self.n_averaged, self.avg_fn)
pl_module.net_swa = self._average_model
def on_train_epoch_end(self, trainer: 'pl.Trainer', pl_module: 'pl.LightningModule', *args):
trainer.train_loop._skip_backward = False
def on_train_end(self, trainer: 'pl.Trainer', pl_module: 'pl.LightningModule'):
pass
def on_validation_epoch_start(self, trainer, pl_module) -> None:
"""Called when the val epoch begins."""
if (self.swa_start <= trainer.current_epoch <= self.swa_end) and ((trainer.current_epoch - self.swa_start) % self.swa_freq == 0):
pl_module.do_swa = True
else:
pl_module.do_swa = False
@staticmethod
def transfer_weights(src_pl_module: 'pl.LightningModule', dst_pl_module: 'pl.LightningModule'):
for src_param, dst_param in zip(src_pl_module.parameters(), dst_pl_module.parameters()):
dst_param.detach().copy_(src_param.to(dst_param.device))
def reset_batch_norm_and_save_state(self, pl_module):
"""
Adapted from https://github.com/pytorch/pytorch/blob/v1.7.1/torch/optim/swa_utils.py#L140-L154
"""
self.momenta = {}
for module in pl_module.modules():
if not isinstance(module, nn.modules.batchnorm._BatchNorm):
continue
module.running_mean = torch.zeros_like(
module.running_mean, device=pl_module.device, dtype=module.running_mean.dtype
)
module.running_var = torch.ones_like(
module.running_var, device=pl_module.device, dtype=module.running_var.dtype
)
self.momenta[module] = module.momentum
module.momentum = None
module.num_batches_tracked *= 0
def reset_momenta(self):
"""
Adapted from https://github.com/pytorch/pytorch/blob/v1.7.1/torch/optim/swa_utils.py#L164-L165
"""
for bn_module in self.momenta.keys():
bn_module.momentum = self.momenta[bn_module]
@staticmethod
def update_parameters(
average_model, model, n_averaged: torch.LongTensor, avg_fn: _AVG_FN
):
"""
Adapted from https://github.com/pytorch/pytorch/blob/v1.7.1/torch/optim/swa_utils.py#L104-L112
"""
for p_swa, p_model in zip(average_model.parameters(), model.parameters()):
device = p_swa.device
p_swa_ = p_swa.detach()
p_model_ = p_model.detach().to(device)
src = p_model_ if n_averaged == 0 else avg_fn(
p_swa_, p_model_, n_averaged.to(device))
p_swa_.copy_(src)
n_averaged += 1
@staticmethod
def avg_fn(
averaged_model_parameter: torch.Tensor, model_parameter: torch.Tensor, num_averaged: torch.LongTensor
) -> torch.FloatTensor:
"""
Adapted from https://github.com/pytorch/pytorch/blob/v1.7.1/torch/optim/swa_utils.py#L95-L97
"""
return averaged_model_parameter + (model_parameter - averaged_model_parameter) / (num_averaged + 1)
================================================
FILE: helpers/workersinit.py
================================================
import torch
import numpy as np
import random
def worker_init_fn(x):
seed = (torch.initial_seed() + x * 1000) % 2 ** 31 # problem with nearly seeded randoms
np.random.seed(seed)
random.seed(seed)
torch.manual_seed(seed)
return
================================================
FILE: models/helpers/vit_helpers.py
================================================
"""
Adapted from https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py
Credit to @leo19941227 for remove timm dependencies here : https://github.com/s3prl/passt_hear21/blob/48a0dc1b824641ca59884ced53f5b86053fed141/hear21passt/models/helpers/vit_helpers.py
"""
import math
import logging
import warnings
from copy import deepcopy
import torch
from torch import nn
try:
from timm.models._hub import download_cached_file
except ModuleNotFoundError:
from timm.models.hub import download_cached_file
# Global variables for rarely used pretrained checkpoint download progress and hash check.
# Use set_pretrained_download_progress / set_pretrained_check_hash functions to toggle.
_DOWNLOAD_PROGRESS = True
_CHECK_HASH = False
_logger = logging.getLogger(__name__)
def adapt_input_conv(in_chans, conv_weight):
conv_type = conv_weight.dtype
conv_weight = (
conv_weight.float()
) # Some weights are in torch.half, ensure it's float for sum on CPU
O, I, J, K = conv_weight.shape
if in_chans == 1:
if I > 3:
assert conv_weight.shape[1] % 3 == 0
# For models with space2depth stems
conv_weight = conv_weight.reshape(O, I // 3, 3, J, K)
conv_weight = conv_weight.sum(dim=2, keepdim=False)
else:
conv_weight = conv_weight.sum(dim=1, keepdim=True)
elif in_chans != 3:
if I != 3:
raise NotImplementedError("Weight format not supported by conversion.")
else:
# NOTE this strategy should be better than random init, but there could be other combinations of
# the original RGB input layer weights that'd work better for specific cases.
repeat = int(math.ceil(in_chans / 3))
conv_weight = conv_weight.repeat(1, repeat, 1, 1)[:, :in_chans, :, :]
conv_weight *= 3 / float(in_chans)
conv_weight = conv_weight.to(conv_type)
return conv_weight
def load_pretrained(
model,
default_cfg=None,
num_classes=1000,
in_chans=3,
filter_fn=None,
strict=True,
progress=False,
):
"""Load pretrained checkpoint
Args:
model (nn.Module) : PyTorch model module
default_cfg (Optional[Dict]): default configuration for pretrained weights / target dataset
num_classes (int): num_classes for model
in_chans (int): in_chans for model
filter_fn (Optional[Callable]): state_dict filter fn for load (takes state_dict, model as args)
strict (bool): strict load of checkpoint
progress (bool): enable progress bar for weight download
"""
default_cfg = default_cfg or getattr(model, "default_cfg", None) or {}
pretrained_url = default_cfg.get("url", None)
if not pretrained_url:
_logger.warning(
"No pretrained weights exist for this model. Using random initialization."
)
return
_logger.info(f"Loading pretrained weights from url ({pretrained_url})")
pretrained_loc = download_cached_file(
pretrained_url,
check_hash=_CHECK_HASH,
progress=_DOWNLOAD_PROGRESS,
)
state_dict = torch.load(pretrained_loc, map_location="cpu")
if filter_fn is not None:
# for backwards compat with filter fn that take one arg, try one first, the two
try:
state_dict = filter_fn(state_dict)
except TypeError:
state_dict = filter_fn(state_dict, model)
input_convs = default_cfg.get("first_conv", None)
if input_convs is not None and in_chans != 3:
if isinstance(input_convs, str):
input_convs = (input_convs,)
for input_conv_name in input_convs:
weight_name = input_conv_name + ".weight"
try:
state_dict[weight_name] = adapt_input_conv(
in_chans, state_dict[weight_name]
)
_logger.info(
f"Converted input conv {input_conv_name} pretrained weights from 3 to {in_chans} channel(s)"
)
except NotImplementedError as e:
del state_dict[weight_name]
strict = False
_logger.warning(
f"Unable to convert pretrained {input_conv_name} weights, using random init for this layer."
)
classifiers = default_cfg.get("classifier", None)
label_offset = default_cfg.get("label_offset", 0)
if classifiers is not None:
if isinstance(classifiers, str):
classifiers = (classifiers,)
if num_classes != default_cfg["num_classes"]:
for classifier_name in classifiers:
# completely discard fully connected if model num_classes doesn't match pretrained weights
del state_dict[classifier_name + ".weight"]
del state_dict[classifier_name + ".bias"]
strict = False
elif label_offset > 0:
for classifier_name in classifiers:
# special case for pretrained weights with an extra background class in pretrained weights
classifier_weight = state_dict[classifier_name + ".weight"]
state_dict[classifier_name + ".weight"] = classifier_weight[
label_offset:
]
classifier_bias = state_dict[classifier_name + ".bias"]
state_dict[classifier_name + ".bias"] = classifier_bias[label_offset:]
model.load_state_dict(state_dict, strict=strict)
def overlay_external_default_cfg(default_cfg, kwargs):
"""Overlay 'external_default_cfg' in kwargs on top of default_cfg arg."""
external_default_cfg = kwargs.pop("external_default_cfg", None)
if external_default_cfg:
default_cfg.pop("url", None) # url should come from external cfg
default_cfg.pop("hf_hub", None) # hf hub id should come from external cfg
default_cfg.update(external_default_cfg)
def filter_kwargs(kwargs, names):
if not kwargs or not names:
return
for n in names:
kwargs.pop(n, None)
def set_default_kwargs(kwargs, names, default_cfg):
for n in names:
# for legacy reasons, model __init__args uses img_size + in_chans as separate args while
# default_cfg has one input_size=(C, H ,W) entry
if n == "img_size":
input_size = default_cfg.get("input_size", None)
if input_size is not None:
assert len(input_size) == 3
kwargs.setdefault(n, input_size[-2:])
elif n == "in_chans":
input_size = default_cfg.get("input_size", None)
if input_size is not None:
assert len(input_size) == 3
kwargs.setdefault(n, input_size[0])
else:
default_val = default_cfg.get(n, None)
if default_val is not None:
kwargs.setdefault(n, default_cfg[n])
def update_default_cfg_and_kwargs(default_cfg, kwargs, kwargs_filter):
"""Update the default_cfg and kwargs before passing to model
FIXME this sequence of overlay default_cfg, set default kwargs, filter kwargs
could/should be replaced by an improved configuration mechanism
Args:
default_cfg: input default_cfg (updated in-place)
kwargs: keyword args passed to model build fn (updated in-place)
kwargs_filter: keyword arg keys that must be removed before model __init__
"""
# Overlay default cfg values from `external_default_cfg` if it exists in kwargs
overlay_external_default_cfg(default_cfg, kwargs)
# Set model __init__ args that can be determined by default_cfg (if not already passed as kwargs)
default_kwarg_names = ("num_classes", "global_pool", "in_chans")
if default_cfg.get("fixed_input_size", False):
# if fixed_input_size exists and is True, model takes an img_size arg that fixes its input size
default_kwarg_names += ("img_size",)
set_default_kwargs(kwargs, names=default_kwarg_names, default_cfg=default_cfg)
# Filter keyword args for task specific model variants (some 'features only' models, etc.)
filter_kwargs(kwargs, names=kwargs_filter)
def drop_path(x, drop_prob: float = 0.0, training: bool = False):
"""Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
This is the same as the DropConnect impl I created for EfficientNet, etc networks, however,
the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper...
See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for
changing the layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use
'survival rate' as the argument.
"""
if drop_prob == 0.0 or not training:
return x
keep_prob = 1 - drop_prob
shape = (x.shape[0],) + (1,) * (
x.ndim - 1
) # work with diff dim tensors, not just 2D ConvNets
random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device)
random_tensor.floor_() # binarize
output = x.div(keep_prob) * random_tensor
return output
class DropPath(nn.Module):
"""Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks)."""
def __init__(self, drop_prob=None):
super(DropPath, self).__init__()
self.drop_prob = drop_prob
def forward(self, x):
return drop_path(x, self.drop_prob, self.training)
from torch.nn.init import _calculate_fan_in_and_fan_out
def _no_grad_trunc_normal_(tensor, mean, std, a, b):
# Cut & paste from PyTorch official master until it's in a few official releases - RW
# Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf
def norm_cdf(x):
# Computes standard normal cumulative distribution function
return (1.0 + math.erf(x / math.sqrt(2.0))) / 2.0
if (mean < a - 2 * std) or (mean > b + 2 * std):
warnings.warn(
"mean is more than 2 std from [a, b] in nn.init.trunc_normal_. "
"The distribution of values may be incorrect.",
stacklevel=2,
)
with torch.no_grad():
# Values are generated by using a truncated uniform distribution and
# then using the inverse CDF for the normal distribution.
# Get upper and lower cdf values
l = norm_cdf((a - mean) / std)
u = norm_cdf((b - mean) / std)
# Uniformly fill tensor with values from [l, u], then translate to
# [2l-1, 2u-1].
tensor.uniform_(2 * l - 1, 2 * u - 1)
# Use inverse cdf transform for normal distribution to get truncated
# standard normal
tensor.erfinv_()
# Transform to proper mean, std
tensor.mul_(std * math.sqrt(2.0))
tensor.add_(mean)
# Clamp to ensure it's in the proper range
tensor.clamp_(min=a, max=b)
return tensor
def trunc_normal_(tensor, mean=0.0, std=1.0, a=-2.0, b=2.0):
r"""Fills the input Tensor with values drawn from a truncated
normal distribution. The values are effectively drawn from the
normal distribution :math:`\mathcal{N}(\text{mean}, \text{std}^2)`
with values outside :math:`[a, b]` redrawn until they are within
the bounds. The method used for generating the random values works
best when :math:`a \leq \text{mean} \leq b`.
Args:
tensor: an n-dimensional `torch.Tensor`
mean: the mean of the normal distribution
std: the standard deviation of the normal distribution
a: the minimum cutoff value
b: the maximum cutoff value
Examples:
>>> w = torch.empty(3, 5)
>>> nn.init.trunc_normal_(w)
"""
return _no_grad_trunc_normal_(tensor, mean, std, a, b)
def variance_scaling_(tensor, scale=1.0, mode="fan_in", distribution="normal"):
fan_in, fan_out = _calculate_fan_in_and_fan_out(tensor)
if mode == "fan_in":
denom = fan_in
elif mode == "fan_out":
denom = fan_out
elif mode == "fan_avg":
denom = (fan_in + fan_out) / 2
variance = scale / denom
if distribution == "truncated_normal":
# constant is stddev of standard normal truncated to (-2, 2)
trunc_normal_(tensor, std=math.sqrt(variance) / 0.87962566103423978)
elif distribution == "normal":
tensor.normal_(std=math.sqrt(variance))
elif distribution == "uniform":
bound = math.sqrt(3 * variance)
tensor.uniform_(-bound, bound)
else:
raise ValueError(f"invalid distribution {distribution}")
def lecun_normal_(tensor):
variance_scaling_(tensor, mode="fan_in", distribution="truncated_normal")
def build_model_with_cfg(
model_cls,
variant: str,
pretrained: bool,
default_cfg: dict,
model_cfg=None,
feature_cfg=None,
pretrained_strict: bool = True,
pretrained_filter_fn=None,
pretrained_custom_load=False,
kwargs_filter=None,
**kwargs,
):
"""Build model with specified default_cfg and optional model_cfg
This helper fn aids in the construction of a model including:
* handling default_cfg and associated pretained weight loading
* passing through optional model_cfg for models with config based arch spec
* features_only model adaptation
* pruning config / model adaptation
Args:
model_cls (nn.Module): model class
variant (str): model variant name
pretrained (bool): load pretrained weights
default_cfg (dict): model's default pretrained/task config
model_cfg (Optional[Dict]): model's architecture config
feature_cfg (Optional[Dict]: feature extraction adapter config
pretrained_strict (bool): load pretrained weights strictly
pretrained_filter_fn (Optional[Callable]): filter callable for pretrained weights
pretrained_custom_load (bool): use custom load fn, to load numpy or other non PyTorch weights
kwargs_filter (Optional[Tuple]): kwargs to filter before passing to model
**kwargs: model args passed through to model __init__
"""
pruned = kwargs.pop("pruned", False)
features = False
feature_cfg = feature_cfg or {}
default_cfg = deepcopy(default_cfg) if default_cfg else {}
update_default_cfg_and_kwargs(default_cfg, kwargs, kwargs_filter)
default_cfg.setdefault("architecture", variant)
# Setup for feature extraction wrapper done at end of this fn
if kwargs.pop("features_only", False):
features = True
feature_cfg.setdefault("out_indices", (0, 1, 2, 3, 4))
if "out_indices" in kwargs:
feature_cfg["out_indices"] = kwargs.pop("out_indices")
# Build the model
model = (
model_cls(**kwargs) if model_cfg is None else model_cls(cfg=model_cfg, **kwargs)
)
model.default_cfg = default_cfg
# For classification models, check class attr, then kwargs, then default to 1k, otherwise 0 for feats
num_classes_pretrained = (
0
if features
else getattr(model, "num_classes", kwargs.get("num_classes", 1000))
)
if pretrained:
assert not pretrained_custom_load, "URL should not contain npz for PASST models"
load_pretrained(
model,
num_classes=num_classes_pretrained,
in_chans=kwargs.get("in_chans", 3),
filter_fn=pretrained_filter_fn,
strict=pretrained_strict,
)
return model
================================================
FILE: models/passt.py
================================================
"""
Most of this code comes from the timm library.
We tried to disentangle from the timm library version.
Adapted from https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py
"""
import math
import logging
import warnings
from functools import partial
import collections
from collections import OrderedDict
from copy import deepcopy
from itertools import repeat
import torch
import torch.nn as nn
import torch.nn.functional as F
from .helpers.vit_helpers import update_default_cfg_and_kwargs, DropPath, trunc_normal_, build_model_with_cfg
_logger = logging.getLogger()
# From PyTorch internals
def _ntuple(n):
def parse(x):
if isinstance(x, collections.abc.Iterable) and not isinstance(x, str):
return tuple(x)
return tuple(repeat(x, n))
return parse
to_2tuple = _ntuple(2)
IMAGENET_DEFAULT_MEAN = (0.485, 0.456, 0.406)
IMAGENET_DEFAULT_STD = (0.229, 0.224, 0.225)
IMAGENET_INCEPTION_MEAN = (0.5, 0.5, 0.5)
IMAGENET_INCEPTION_STD = (0.5, 0.5, 0.5)
def _cfg(url='', **kwargs):
return {
'url': url,
'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': None,
'crop_pct': .9, 'interpolation': 'bicubic', 'fixed_input_size': True,
'mean': IMAGENET_INCEPTION_MEAN, 'std': IMAGENET_INCEPTION_STD,
'first_conv': 'patch_embed.proj', 'classifier': 'head',
**kwargs
}
default_cfgs = {
# patch models (weights from official Google JAX impl)
'vit_tiny_patch16_224': _cfg(
url='https://storage.googleapis.com/vit_models/augreg/'
'Ti_16-i21k-300ep-lr_0.001-aug_none-wd_0.03-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.03-res_224.npz'),
'vit_tiny_patch16_384': _cfg(
url='https://storage.googleapis.com/vit_models/augreg/'
'Ti_16-i21k-300ep-lr_0.001-aug_none-wd_0.03-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.03-res_384.npz',
input_size=(3, 384, 384), crop_pct=1.0),
'vit_small_patch32_224': _cfg(
url='https://storage.googleapis.com/vit_models/augreg/'
'S_32-i21k-300ep-lr_0.001-aug_light1-wd_0.03-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.03-res_224.npz'),
'vit_small_patch32_384': _cfg(
url='https://storage.googleapis.com/vit_models/augreg/'
'S_32-i21k-300ep-lr_0.001-aug_light1-wd_0.03-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.03-res_384.npz',
input_size=(3, 384, 384), crop_pct=1.0),
'vit_small_patch16_224': _cfg(
url='https://storage.googleapis.com/vit_models/augreg/'
'S_16-i21k-300ep-lr_0.001-aug_light1-wd_0.03-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.03-res_224.npz'),
'vit_small_patch16_384': _cfg(
url='https://storage.googleapis.com/vit_models/augreg/'
'S_16-i21k-300ep-lr_0.001-aug_light1-wd_0.03-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.03-res_384.npz',
input_size=(3, 384, 384), crop_pct=1.0),
'vit_base_patch32_224': _cfg(
url='https://storage.googleapis.com/vit_models/augreg/'
'B_32-i21k-300ep-lr_0.001-aug_medium1-wd_0.03-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.03-res_224.npz'),
'vit_base_patch32_384': _cfg(
url='https://storage.googleapis.com/vit_models/augreg/'
'B_32-i21k-300ep-lr_0.001-aug_light1-wd_0.1-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.03-res_384.npz',
input_size=(3, 384, 384), crop_pct=1.0),
'vit_base_patch16_224': _cfg(
url='https://storage.googleapis.com/vit_models/augreg/'
'B_16-i21k-300ep-lr_0.001-aug_medium1-wd_0.1-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.01-res_224.npz'),
'vit_base_patch16_384': _cfg(
url='https://storage.googleapis.com/vit_models/augreg/'
'B_16-i21k-300ep-lr_0.001-aug_medium1-wd_0.1-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.01-res_384.npz',
input_size=(3, 384, 384), crop_pct=1.0),
'vit_large_patch32_224': _cfg(
url='', # no official model weights for this combo, only for in21k
),
'vit_large_patch32_384': _cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_large_p32_384-9b920ba8.pth',
input_size=(3, 384, 384), crop_pct=1.0),
'vit_large_patch16_224': _cfg(
url='https://storage.googleapis.com/vit_models/augreg/'
'L_16-i21k-300ep-lr_0.001-aug_medium1-wd_0.1-do_0.1-sd_0.1--imagenet2012-steps_20k-lr_0.01-res_224.npz'),
'vit_large_patch16_384': _cfg(
url='https://storage.googleapis.com/vit_models/augreg/'
'L_16-i21k-300ep-lr_0.001-aug_medium1-wd_0.1-do_0.1-sd_0.1--imagenet2012-steps_20k-lr_0.01-res_384.npz',
input_size=(3, 384, 384), crop_pct=1.0),
# patch models, imagenet21k (weights from official Google JAX impl)
'vit_tiny_patch16_224_in21k': _cfg(
url='https://storage.googleapis.com/vit_models/augreg/Ti_16-i21k-300ep-lr_0.001-aug_none-wd_0.03-do_0.0-sd_0.0.npz',
num_classes=21843),
'vit_small_patch32_224_in21k': _cfg(
url='https://storage.googleapis.com/vit_models/augreg/S_32-i21k-300ep-lr_0.001-aug_light1-wd_0.03-do_0.0-sd_0.0.npz',
num_classes=21843),
'vit_small_patch16_224_in21k': _cfg(
url='https://storage.googleapis.com/vit_models/augreg/S_16-i21k-300ep-lr_0.001-aug_light1-wd_0.03-do_0.0-sd_0.0.npz',
num_classes=21843),
'vit_base_patch32_224_in21k': _cfg(
url='https://storage.googleapis.com/vit_models/augreg/B_32-i21k-300ep-lr_0.001-aug_medium1-wd_0.03-do_0.0-sd_0.0.npz',
num_classes=21843),
'vit_base_patch16_224_in21k': _cfg(
url='https://storage.googleapis.com/vit_models/augreg/B_16-i21k-300ep-lr_0.001-aug_medium1-wd_0.1-do_0.0-sd_0.0.npz',
num_classes=21843),
'vit_large_patch32_224_in21k': _cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_large_patch32_224_in21k-9046d2e7.pth',
num_classes=21843),
'vit_large_patch16_224_in21k': _cfg(
url='https://storage.googleapis.com/vit_models/augreg/L_16-i21k-300ep-lr_0.001-aug_medium1-wd_0.1-do_0.1-sd_0.1.npz',
num_classes=21843),
'vit_huge_patch14_224_in21k': _cfg(
url='https://storage.googleapis.com/vit_models/imagenet21k/ViT-H_14.npz',
hf_hub='timm/vit_huge_patch14_224_in21k',
num_classes=21843),
# SAM trained models (https://arxiv.org/abs/2106.01548)
'vit_base_patch32_sam_224': _cfg(
url='https://storage.googleapis.com/vit_models/sam/ViT-B_32.npz'),
'vit_base_patch16_sam_224': _cfg(
url='https://storage.googleapis.com/vit_models/sam/ViT-B_16.npz'),
# deit models (FB weights)
'deit_tiny_patch16_224': _cfg(
url='https://dl.fbaipublicfiles.com/deit/deit_tiny_patch16_224-a1311bcf.pth',
mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD),
'deit_small_patch16_224': _cfg(
url='https://dl.fbaipublicfiles.com/deit/deit_small_patch16_224-cd65a155.pth',
mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD),
'deit_base_patch16_224': _cfg(
url='https://dl.fbaipublicfiles.com/deit/deit_base_patch16_224-b5f2ef4d.pth',
mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD),
'deit_base_patch16_384': _cfg(
url='https://dl.fbaipublicfiles.com/deit/deit_base_patch16_384-8de9b5d1.pth',
mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, input_size=(3, 384, 384), crop_pct=1.0),
'deit_tiny_distilled_patch16_224': _cfg(
url='https://dl.fbaipublicfiles.com/deit/deit_tiny_distilled_patch16_224-b40b3cf7.pth',
mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, classifier=('head', 'head_dist')),
'deit_small_distilled_patch16_224': _cfg(
url='https://dl.fbaipublicfiles.com/deit/deit_small_distilled_patch16_224-649709d9.pth',
mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, classifier=('head', 'head_dist')),
'deit_base_distilled_patch16_224': _cfg(
url='https://dl.fbaipublicfiles.com/deit/deit_base_distilled_patch16_224-df68dfff.pth',
mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, classifier=('head', 'head_dist')),
'deit_base_distilled_patch16_384': _cfg(
url='https://dl.fbaipublicfiles.com/deit/deit_base_distilled_patch16_384-d0272ac0.pth',
mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, input_size=(3, 384, 384), crop_pct=1.0,
classifier=('head', 'head_dist')),
# ViT ImageNet-21K-P pretraining by MILL
'vit_base_patch16_224_miil_in21k': _cfg(
url='https://miil-public-eu.oss-eu-central-1.aliyuncs.com/model-zoo/ImageNet_21K_P/models/timm/vit_base_patch16_224_in21k_miil.pth',
mean=(0, 0, 0), std=(1, 1, 1), crop_pct=0.875, interpolation='bilinear', num_classes=11221,
),
'vit_base_patch16_224_miil': _cfg(
url='https://miil-public-eu.oss-eu-central-1.aliyuncs.com/model-zoo/ImageNet_21K_P/models/timm'
'/vit_base_patch16_224_1k_miil_84_4.pth',
mean=(0, 0, 0), std=(1, 1, 1), crop_pct=0.875, interpolation='bilinear',
),
# PaSST
'passt_s_swa_p16_128_ap476': _cfg(
url='https://github.com/kkoutini/PaSST/releases/download/v0.0.1-audioset/passt-s-f128-p16-s10-ap.476-swa.pt',
mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, input_size=(1, 128, 998), crop_pct=1.0,
classifier=('head.1', 'head_dist'), num_classes=527),
'passt_s_kd_p16_128_ap486': _cfg(
url='https://github.com/kkoutini/PaSST/releases/download/v.0.0.9/passt-s-kd-ap.486.pt',
mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, input_size=(1, 128, 998), crop_pct=1.0,
classifier=('head.1', 'head_dist'), num_classes=527),
'passt_l_kd_p16_128_ap47': _cfg(
url='https://github.com/kkoutini/PaSST/releases/download/v.0.0.10/passt-l-kd-ap.47.pt',
mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, input_size=(1, 128, 998), crop_pct=1.0,
classifier=('head.1', 'head_dist'), num_classes=527),
'passt_s_swa_p16_128_ap4761': _cfg(
url='https://github.com/kkoutini/PaSST/releases/download/v0.0.2-audioset/passt-s-f128-p16-s10-ap.4761-swa.pt',
mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, input_size=(1, 128, 998), crop_pct=1.0,
classifier=('head.1', 'head_dist'), num_classes=527),
'passt_s_p16_128_ap472': _cfg(
url='https://github.com/kkoutini/PaSST/releases/download/v0.0.2-audioset/passt-s-f128-p16-s10-ap.472.pt',
mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, input_size=(1, 128, 998), crop_pct=1.0,
classifier=('head.1', 'head_dist'), num_classes=527),
'passt_s_p16_s16_128_ap468': _cfg(
url='https://github.com/kkoutini/PaSST/releases/download/v0.0.2-audioset/passt-s-f128-p16-s16-ap.468.pt',
mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, input_size=(1, 128, 998), crop_pct=1.0,
classifier=('head.1', 'head_dist'), num_classes=527),
'passt_s_swa_p16_s16_128_ap473': _cfg(
url='https://github.com/kkoutini/PaSST/releases/download/v0.0.2-audioset/passt-s-f128-p16-s16-ap.473-swa.pt',
mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, input_size=(1, 128, 998), crop_pct=1.0,
classifier=('head.1', 'head_dist'), num_classes=527),
'passt_s_swa_p16_s14_128_ap471': _cfg(
url='https://github.com/kkoutini/PaSST/releases/download/v0.0.2-audioset/passt-s-f128-p16-s14-ap.471-swa.pt',
mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, input_size=(1, 128, 998), crop_pct=1.0,
classifier=('head.1', 'head_dist'), num_classes=527),
'passt_s_p16_s14_128_ap469': _cfg(
url='https://github.com/kkoutini/PaSST/releases/download/v0.0.2-audioset/passt-s-f128-p16-s14-ap.469.pt',
mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, input_size=(1, 128, 998), crop_pct=1.0,
classifier=('head.1', 'head_dist'), num_classes=527),
'passt_s_swa_p16_s12_128_ap473': _cfg(
url='https://github.com/kkoutini/PaSST/releases/download/v0.0.2-audioset/passt-s-f128-p16-s12-ap.473-swa.pt',
mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, input_size=(1, 128, 998), crop_pct=1.0,
classifier=('head.1', 'head_dist'), num_classes=527),
'passt_s_p16_s12_128_ap470': _cfg(
url='https://github.com/kkoutini/PaSST/releases/download/v0.0.2-audioset/passt-s-f128-p16-s12-ap.470.pt',
mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, input_size=(1, 128, 998), crop_pct=1.0,
classifier=('head.1', 'head_dist'), num_classes=527),
'passt_s_swa_f128_stfthop100_p16_s10_ap473': _cfg(
url='https://github.com/kkoutini/PaSST/releases/download/v0.0.3-audioset/passt-s-f128-stfthop100-p16-s10-ap.473-swa.pt',
mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, input_size=(1, 128, 3200), crop_pct=1.0,
classifier=('head.1', 'head_dist'), num_classes=527),
'passt_s_swa_f128_stfthop160_p16_s10_ap473': _cfg(
url='https://github.com/kkoutini/PaSST/releases/download/v0.0.3-audioset/passt-s-f128-stfthop160-p16-s10-ap.473-swa.pt',
mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, input_size=(1, 128, 2000), crop_pct=1.0,
classifier=('head.1', 'head_dist'), num_classes=527),
'passt-s-f128-20sec-p16-s10-ap474-swa': _cfg(
url='https://github.com/kkoutini/PaSST/releases/download/v0.0.5/passt-s-f128-20sec-p16-s10-ap.474-swa.pt',
mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, input_size=(1, 128, 2000), crop_pct=1.0,
classifier=('head.1', 'head_dist'), num_classes=527),
'passt-s-f128-30sec-p16-s10-ap473-swa': _cfg(
url='https://github.com/kkoutini/PaSST/releases/download/v0.0.5/passt-s-f128-30sec-p16-s10-ap.473-swa.pt',
mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, input_size=(1, 128, 3000), crop_pct=1.0,
classifier=('head.1', 'head_dist'), num_classes=527),
'openmic2008_passt_u_f128_p16_s10_ap85_swa': _cfg(
url='https://github.com/kkoutini/PaSST/releases/download/v0.0.4-openmic/openmic2008.passt-u-f128-p16-s10-ap.85-swa.pt',
mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, input_size=(1, 128, 3200), crop_pct=1.0,
classifier=('head.1', 'head_dist'), num_classes=20),
'openmic2008_passt_u_f128_p16_s10_ap85 ': _cfg(
url='https://github.com/kkoutini/PaSST/releases/download/v0.0.4-openmic/openmic2008.passt-u-f128-p16-s10-ap.85.pt',
mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, input_size=(1, 128, 2000), crop_pct=1.0,
classifier=('head.1', 'head_dist'), num_classes=20),
}
def adapt_input_conv(in_chans, conv_weight):
conv_type = conv_weight.dtype
conv_weight = conv_weight.float() # Some weights are in torch.half, ensure it's float for sum on CPU
O, I, J, K = conv_weight.shape
if in_chans == 1:
if I > 3:
assert conv_weight.shape[1] % 3 == 0
# For models with space2depth stems
conv_weight = conv_weight.reshape(O, I // 3, 3, J, K)
conv_weight = conv_weight.sum(dim=2, keepdim=False)
else:
conv_weight = conv_weight.sum(dim=1, keepdim=True)
elif in_chans != 3:
if I != 3:
raise NotImplementedError('Weight format not supported by conversion.')
else:
# NOTE this strategy should be better than random init, but there could be other combinations of
# the original RGB input layer weights that'd work better for specific cases.
repeat = int(math.ceil(in_chans / 3))
conv_weight = conv_weight.repeat(1, repeat, 1, 1)[:, :in_chans, :, :]
conv_weight *= (3 / float(in_chans))
conv_weight = conv_weight.to(conv_type)
return conv_weight
class Mlp(nn.Module):
""" MLP as used in Vision Transformer, MLP-Mixer and related networks
"""
def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
super().__init__()
out_features = out_features or in_features
hidden_features = hidden_features or in_features
self.fc1 = nn.Linear(in_features, hidden_features)
self.act = act_layer()
self.fc2 = nn.Linear(hidden_features, out_features)
self.drop = nn.Dropout(drop)
def forward(self, x):
x = self.fc1(x)
x = self.act(x)
x = self.drop(x)
x = self.fc2(x)
x = self.drop(x)
return x
first_RUN = True
PLUS1_TRICK = False
class PatchEmbed(nn.Module):
""" 2D Image to Patch Embedding
"""
def __init__(self, img_size=224, patch_size=16, stride=16, in_chans=3, embed_dim=768, norm_layer=None,
flatten=True):
super().__init__()
img_size = to_2tuple(img_size)
patch_size = to_2tuple(patch_size)
stride = to_2tuple(stride)
self.img_size = img_size
self.patch_size = patch_size
self.stride = stride
self.grid_size = (img_size[0] // stride[0], img_size[1] // stride[1])
self.num_patches = self.grid_size[0] * self.grid_size[1]
self.flatten = flatten
self.embed_dim = embed_dim
self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=stride)
self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()
def forward(self, x):
B, C, H, W = x.shape
if not (H == self.img_size[0] and W == self.img_size[1]):
warnings.warn(f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]}).")
# to do maybe replace weights
x = self.proj(x)
if self.flatten:
x = x.flatten(2).transpose(1, 2) # BCHW -> BNC
x = self.norm(x)
if first_RUN: print("self.norm(x)", x.size())
return x
class Attention(nn.Module):
def __init__(self, dim, num_heads=8, qkv_bias=False, attn_drop=0., proj_drop=0.):
super().__init__()
self.num_heads = num_heads
head_dim = dim // num_heads
self.scale = head_dim ** -0.5
self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
self.attn_drop = nn.Dropout(attn_drop)
self.proj = nn.Linear(dim, dim)
self.proj_drop = nn.Dropout(proj_drop)
def forward(self, x):
B, N, C = x.shape
qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple)
attn = (q @ k.transpose(-2, -1)) * self.scale
if PLUS1_TRICK:
# +1 trick
attn = torch.cat([attn, torch.zeros(attn.shape[:-1]+(1,), dtype=attn.dtype, device=attn.device)], dim=-1)
attn = attn.softmax(dim=-1)
if PLUS1_TRICK:
# +1 trick
attn = attn[...,:-1]
attn = self.attn_drop(attn)
x = (attn @ v).transpose(1, 2).reshape(B, N, C)
x = self.proj(x)
x = self.proj_drop(x)
return x
class Block(nn.Module):
def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, drop=0., attn_drop=0.,
drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm):
super().__init__()
self.norm1 = norm_layer(dim)
self.attn = Attention(dim, num_heads=num_heads, qkv_bias=qkv_bias, attn_drop=attn_drop, proj_drop=drop)
# NOTE: drop path for stochastic depth, we shall see if this is better than dropout here
self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
self.norm2 = norm_layer(dim)
mlp_hidden_dim = int(dim * mlp_ratio)
self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
def forward(self, x):
x = x + self.drop_path(self.attn(self.norm1(x)))
x = x + self.drop_path(self.mlp(self.norm2(x)))
return x
class PaSST(nn.Module):
"""
Based on the implementation of Vision Transformer in timm library.
Take a look at the get_model function, adapting the weights of pretrained imagenet models.
"""
def __init__(self, u_patchout=0, s_patchout_t=0, s_patchout_f=0, img_size=(128, 998), patch_size=16, stride=16,
in_chans=1, num_classes=527, embed_dim=768, depth=12,
num_heads=12, mlp_ratio=4., qkv_bias=True, representation_size=None, distilled=False,
drop_rate=0., attn_drop_rate=0., drop_path_rate=0., embed_layer=PatchEmbed, norm_layer=None,
act_layer=None, weight_init=''):
"""
Args:
u_patchout: Unstructured Patchout integer, number of items to be removed from the final sequence
s_patchout_t: structured Patchout time integer, number of columns to be removed from the patches grid
s_patchout_f: structured Patchout Frequency integer, number of rows to be removed from the patches grid
img_size (int, tuple): input image size
patch_size (int, tuple): patch size
in_chans (int): number of input channels
num_classes (int): number of classes for classification head
embed_dim (int): embedding dimension
depth (int): depth of transformer
num_heads (int): number of attention heads
mlp_ratio (int): ratio of mlp hidden dim to embedding dim
qkv_bias (bool): enable bias for qkv if True
representation_size (Optional[int]): enable and set representation layer (pre-logits) to this value if set
distilled (bool): model includes a distillation token and head as in DeiT models
drop_rate (float): dropout rate
attn_drop_rate (float): attention dropout rate
drop_path_rate (float): stochastic depth rate
embed_layer (nn.Module): patch embedding layer
norm_layer: (nn.Module): normalization layer
weight_init: (str): weight init scheme
"""
super().__init__()
self.num_classes = num_classes
self.u_patchout = u_patchout
self.s_patchout_t = s_patchout_t
self.s_patchout_f = s_patchout_f
self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models
self.num_tokens = 2 if distilled else 1
norm_layer = norm_layer or partial(nn.LayerNorm, eps=1e-6)
act_layer = act_layer or nn.GELU
self.patch_embed = embed_layer(
img_size=img_size, patch_size=patch_size, stride=stride, in_chans=in_chans, embed_dim=embed_dim,
flatten=False)
num_patches = self.patch_embed.num_patches
self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
self.dist_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) if distilled else None
# PaSST
# refer to https://arxiv.org/abs/2110.05069 Section 2
self.new_pos_embed = nn.Parameter(torch.zeros(1, self.num_tokens, embed_dim)) # for C and D tokens
self.freq_new_pos_embed = nn.Parameter(torch.zeros(1, embed_dim, self.patch_embed.grid_size[0], 1)) # | f
self.time_new_pos_embed = nn.Parameter(torch.zeros(1, embed_dim, 1, self.patch_embed.grid_size[1])) # __ t
####
self.pos_drop = nn.Dropout(p=drop_rate)
dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule
self.blocks = nn.Sequential(*[
Block(
dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, drop=drop_rate,
attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer, act_layer=act_layer)
for i in range(depth)])
self.norm = norm_layer(embed_dim)
# Representation layer
if representation_size and not distilled:
self.num_features = representation_size
self.pre_logits = nn.Sequential(OrderedDict([
('fc', nn.Linear(embed_dim, representation_size)),
('act', nn.Tanh())
]))
else:
self.pre_logits = nn.Identity()
# Classifier head(s)
self.head = nn.Sequential(nn.LayerNorm(self.num_features),
nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity())
self.head_dist = None
if distilled:
self.head_dist = nn.Linear(self.embed_dim, self.num_classes) if num_classes > 0 else nn.Identity()
self.init_weights(weight_init)
def init_weights(self, mode=''):
assert mode in ('jax', 'jax_nlhb', 'nlhb', '')
head_bias = -math.log(self.num_classes) if 'nlhb' in mode else 0.
trunc_normal_(self.new_pos_embed, std=.02)
trunc_normal_(self.freq_new_pos_embed, std=.02)
trunc_normal_(self.time_new_pos_embed, std=.02)
if self.dist_token is not None:
trunc_normal_(self.dist_token, std=.02)
if mode.startswith('jax'):
# leave cls token as zeros to match jax impl
raise RuntimeError("Not supported yet")
else:
trunc_normal_(self.cls_token, std=.02)
self.apply(_init_vit_weights)
def _init_weights(self, m):
# this fn left here for compat with downstream users
_init_vit_weights(m)
@torch.jit.ignore
def no_weight_decay(self):
return {'new_pos_embed', 'freq_new_pos_embed', 'time_new_pos_embed', 'cls_token', 'dist_token'}
def get_classifier(self):
if self.dist_token is None:
return self.head
else:
return self.head, self.head_dist
def reset_classifier(self, num_classes, global_pool=''):
self.num_classes = num_classes
self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity()
if self.num_tokens == 2:
self.head_dist = nn.Linear(self.embed_dim, self.num_classes) if num_classes > 0 else nn.Identity()
def forward_features(self, x):
global first_RUN # not jit friendly? use trace instead
x = self.patch_embed(x) # [b, e, f, t]
B_dim, E_dim, F_dim, T_dim = x.shape # slow
if first_RUN: print(" patch_embed : ", x.shape)
# Adding Time/Freq information
if first_RUN: print(" self.time_new_pos_embed.shape", self.time_new_pos_embed.shape)
time_new_pos_embed = self.time_new_pos_embed
if x.shape[-1] < time_new_pos_embed.shape[-1]:
if self.training:
toffset = torch.randint(1 + time_new_pos_embed.shape[-1] - x.shape[-1], (1,)).item()
if first_RUN: print(f" CUT with randomoffset={toffset} time_new_pos_embed.shape",
time_new_pos_embed.shape)
time_new_pos_embed = time_new_pos_embed[:, :, :, toffset:toffset + x.shape[-1]]
else:
time_new_pos_embed = time_new_pos_embed[:, :, :, :x.shape[-1]]
if first_RUN: print(" CUT time_new_pos_embed.shape", time_new_pos_embed.shape)
else:
warnings.warn(
f"the patches shape:{x.shape} are larger than the expected time encodings {time_new_pos_embed.shape}, x will be cut")
x = x[:, :, :, :time_new_pos_embed.shape[-1]]
x = x + time_new_pos_embed
if first_RUN: print(" self.freq_new_pos_embed.shape", self.freq_new_pos_embed.shape)
x = x + self.freq_new_pos_embed
# Structured Patchout https://arxiv.org/abs/2110.05069 Section 2.2
if self.training and self.s_patchout_t:
if first_RUN: print(f"X Before time Patchout of {self.s_patchout_t} ", x.size())
# ([1, 768, 1, 82])
random_indices = torch.randperm(T_dim)[:T_dim - self.s_patchout_t].sort().values
x = x[:, :, :, random_indices]
if first_RUN: print("X after time Patchout", x.size())
if self.training and self.s_patchout_f:
if first_RUN: print(f"X Before Freq Patchout of {self.s_patchout_f} ", x.size())
# [1, 768, 12, 1]
random_indices = torch.randperm(F_dim)[:F_dim - self.s_patchout_f].sort().values
x = x[:, :, random_indices, :]
if first_RUN: print(" \n X after freq Patchout: ", x.size())
###
# Flatten the sequence
x = x.flatten(2).transpose(1, 2)
# Unstructured Patchout
if first_RUN: print("X flattened", x.size())
if self.training and self.u_patchout:
seq_len = x.shape[1]
random_indices = torch.randperm(seq_len)[:seq_len - self.u_patchout].sort().values
x = x[:, random_indices, :]
if first_RUN: print("X After Unstructured Patchout", x.size())
####
# Add the C/D tokens
if first_RUN: print(" self.new_pos_embed.shape", self.new_pos_embed.shape)
cls_tokens = self.cls_token.expand(B_dim, -1, -1) + self.new_pos_embed[:, :1, :]
if first_RUN: print(" self.cls_tokens.shape", cls_tokens.shape)
if self.dist_token is None:
x = torch.cat((cls_tokens, x), dim=1)
else:
dist_token = self.dist_token.expand(B_dim, -1, -1) + self.new_pos_embed[:, 1:, :]
if first_RUN: print(" self.dist_token.shape", dist_token.shape)
x = torch.cat((cls_tokens, dist_token, x), dim=1)
if first_RUN: print(" final sequence x", x.shape)
x = self.pos_drop(x)
x = self.blocks(x)
if first_RUN: print(f" after {len(self.blocks)} atten blocks x", x.shape)
x = self.norm(x)
if self.dist_token is None:
return self.pre_logits(x[:, 0])
else:
return x[:, 0], x[:, 1]
def forward(self, x):
global first_RUN
if first_RUN: print("x", x.size())
x = self.forward_features(x)
if self.head_dist is not None:
features = (x[0] + x[1]) / 2
if first_RUN: print("forward_features", features.size())
x = self.head(features)
if first_RUN: print("head", x.size())
first_RUN = False
return x, features
else:
features = x
if first_RUN: print("forward_features", features.size())
x = self.head(x)
if first_RUN: print("head", x.size())
first_RUN = False
return x, features
def _init_vit_weights(module: nn.Module, name: str = '', head_bias: float = 0., jax_impl: bool = False):
""" ViT weight initialization
* When called without n, head_bias, jax_impl args it will behave exactly the same
as my original init for compatibility with prev hparam / downstream use cases (ie DeiT).
* When called w/ valid n (module name) and jax_impl=True, will (hopefully) match JAX impl
"""
if isinstance(module, nn.Linear):
if name.startswith('head'):
nn.init.zeros_(module.weight)
nn.init.constant_(module.bias, head_bias)
elif name.startswith('pre_logits'):
lecun_normal_(module.weight)
nn.init.zeros_(module.bias)
else:
if jax_impl:
nn.init.xavier_uniform_(module.weight)
if module.bias is not None:
if 'mlp' in name:
nn.init.normal_(module.bias, std=1e-6)
else:
nn.init.zeros_(module.bias)
else:
trunc_normal_(module.weight, std=.02)
if module.bias is not None:
nn.init.zeros_(module.bias)
elif jax_impl and isinstance(module, nn.Conv2d):
# NOTE conv was left to pytorch default in my original init
lecun_normal_(module.weight)
if module.bias is not None:
nn.init.zeros_(module.bias)
elif isinstance(module, (nn.LayerNorm, nn.GroupNorm, nn.BatchNorm2d)):
nn.init.zeros_(module.bias)
nn.init.ones_(module.weight)
def resize_pos_embed(posemb, posemb_new, num_tokens=1, gs_new=(), mode='bicubic'):
# Rescale the grid of position embeddings when loading from state_dict. Adapted from
# https://github.com/google-research/vision_transformer/blob/00883dd691c63a6830751563748663526e811cee/vit_jax/checkpoint.py#L224
_logger.info('Resized position embedding: %s to %s with %s cls/dis tokens', posemb.shape, posemb_new.shape,
num_tokens)
ntok_new = posemb_new.shape[1]
if num_tokens:
posemb_tok, posemb_grid = posemb[:, :num_tokens], posemb[0, num_tokens:]
ntok_new -= num_tokens
else:
posemb_tok, posemb_grid = posemb[:, :0], posemb[0]
gs_old = int(math.sqrt(len(posemb_grid)))
if not len(gs_new): # backwards compatibility
gs_new = [int(math.sqrt(ntok_new))] * 2
assert len(gs_new) >= 2
_logger.info('Position embedding grid-size from %s to %s', [gs_old, gs_old], gs_new)
posemb_grid = posemb_grid.reshape(1, gs_old, gs_old, -1).permute(0, 3, 1, 2)
posemb_grid = F.interpolate(posemb_grid, size=gs_new, mode=mode, align_corners=False)
posemb_grid = posemb_grid.permute(0, 2, 3, 1).reshape(1, gs_new[0] * gs_new[1], -1)
posemb = torch.cat([posemb_tok, posemb_grid], dim=1)
return posemb
def adapt_image_pos_embed_to_passt(posemb, num_tokens=1, gs_new=(), mode='bicubic'):
# Rescale the grid of position embeddings when loading from state_dict. Adapted from
# https://github.com/google-research/vision_transformer/blob/00883dd691c63a6830751563748663526e811cee/vit_jax/checkpoint.py#L224
_logger.info('Resized position embedding: %s to %s with %s cls/dis tokens', posemb.shape, gs_new,
num_tokens)
if num_tokens:
posemb_tok, posemb_grid = posemb[:, :num_tokens], posemb[0, num_tokens:]
else:
posemb_tok, posemb_grid = posemb[:, :0], posemb[0]
gs_old = int(math.sqrt(len(posemb_grid)))
assert len(gs_new) >= 2
_logger.info('Position embedding grid-size from %s to %s', [gs_old, gs_old], gs_new)
posemb_grid = posemb_grid.reshape(1, gs_old, gs_old, -1).permute(0, 3, 1, 2)
posemb_grid = F.interpolate(posemb_grid, size=gs_new, mode=mode, align_corners=False)
freq_new_pos_embed = posemb_grid.mean(dim=3, keepdim=True)
time_new_pos_embed = posemb_grid.mean(dim=2, keepdim=True)
_logger.info('New Position cls/dstl embedding %s', posemb_tok.shape)
_logger.info('New FREQ Position embedding %s', freq_new_pos_embed.shape)
_logger.info('New TIME Position embedding %s', time_new_pos_embed.shape)
return posemb_tok, freq_new_pos_embed, time_new_pos_embed
def checkpoint_filter_fn(state_dict, model):
""" convert patch embedding weight from manual patchify + linear proj to conv"""
out_dict = {}
if 'model' in state_dict:
# For deit models
state_dict = state_dict['model']
state_dict = {k: v for k, v in state_dict.items()}
if "time_new_pos_embed" not in state_dict:
# we are working with ImageNet model
_logger.info("Adapting pos embedding from ImageNet pretrained model to PaSST.")
v = state_dict.pop("pos_embed")
new_pos_embed, freq_new_pos_embed, time_new_pos_embed = adapt_image_pos_embed_to_passt(
v, getattr(model, 'num_tokens', 1), model.patch_embed.grid_size)
state_dict["new_pos_embed"] = new_pos_embed
state_dict["freq_new_pos_embed"] = freq_new_pos_embed
state_dict["time_new_pos_embed"] = time_new_pos_embed
for k, v in state_dict.items():
if 'patch_embed.proj.weight' in k and len(v.shape) < 4:
# For old models that I trained prior to conv based patchification
O, I, H, W = model.patch_embed.proj.weight.shape
v = v.reshape(O, -1, H, W)
elif k == 'pos_embed' and v.shape != model.pos_embed.shape:
# this should never occur
v = resize_pos_embed(
v, model.pos_embed, getattr(model, 'num_tokens', 1), model.patch_embed.grid_size)
out_dict[k] = v
return out_dict
def _create_vision_transformer(variant, pretrained=False, default_cfg=None, **kwargs):
default_cfg = default_cfg or default_cfgs[variant]
if kwargs.get('features_only', None):
raise RuntimeError('features_only not implemented for Vision Transformer models.')
# NOTE this extra code to support handling of repr size for in21k pretrained models
default_num_classes = default_cfg['num_classes']
num_classes = kwargs.get('num_classes', default_num_classes)
repr_size = kwargs.pop('representation_size', None)
if repr_size is not None and num_classes != default_num_classes:
# Remove representation layer if fine-tuning. This may not always be the desired action,
# but I feel better than doing nothing by default for fine-tuning. Perhaps a better interface?
_logger.warning("Removing representation layer for fine-tuning.")
repr_size = None
model = build_model_with_cfg(
PaSST, variant, pretrained,
default_cfg=default_cfg,
representation_size=repr_size,
pretrained_filter_fn=checkpoint_filter_fn,
pretrained_custom_load='npz' in default_cfg['url'],
**kwargs)
return model
def vit_huge_patch14_224_in21k(pretrained=False, **kwargs):
""" ViT-Huge model (ViT-H/14) from original paper (https://arxiv.org/abs/2010.11929).
ImageNet-21k weights @ 224x224, source https://github.com/google-research/vision_transformer.
NOTE: this model has a representation layer but the 21k classifier head is zero'd out in original weights
"""
model_kwargs = dict(
patch_size=14, embed_dim=1280, depth=32, num_heads=16, representation_size=1280, **kwargs)
model = _create_vision_transformer('vit_huge_patch14_224_in21k', pretrained=pretrained, **model_kwargs)
return model
def deit_base_distilled_patch16_384(pretrained=False, **kwargs):
""" DeiT-base distilled model @ 384x384 from paper (https://arxiv.org/abs/2012.12877).
ImageNet-1k weights from https://github.com/facebookresearch/deit.
"""
print("\n\n Loading DEIT BASE 384\n\n")
model_kwargs = dict(patch_size=16, embed_dim=768, depth=12, num_heads=12, **kwargs)
model = _create_vision_transformer(
'deit_base_distilled_patch16_384', pretrained=pretrained, distilled=True, **model_kwargs)
return model
def passt_s_swa_p16_128_ap476(pretrained=False, **kwargs):
""" PaSST pre-trained on AudioSet
"""
print("\n\n Loading PaSST pre-trained on AudioSet Patch 16 stride 10 structured patchout mAP=476 SWA \n\n")
model_kwargs = dict(patch_size=16, embed_dim=768, depth=12, num_heads=12, **kwargs)
if model_kwargs.get("stride") != (10, 10):
warnings.warn(
f"This model was pre-trained with strides {(10, 10)}, but now you set (fstride,tstride) to {model_kwargs.get('stride')}.")
model = _create_vision_transformer(
'passt_s_swa_p16_128_ap476', pretrained=pretrained, distilled=True, **model_kwargs)
return model
def passt_s_kd_p16_128_ap486(pretrained=False, **kwargs):
""" PaSST pre-trained on AudioSet
"""
print("\n\n Loading PaSST pre-trained on AudioSet (with KD) Patch 16 stride 10 structured patchout mAP=486 \n\n")
model_kwargs = dict(patch_size=16, embed_dim=768, depth=12, num_heads=12, **kwargs)
if model_kwargs.get("stride") != (10, 10):
warnings.warn(
f"This model was pre-trained with strides {(10, 10)}, but now you set (fstride,tstride) to {model_kwargs.get('stride')}.")
model = _create_vision_transformer(
'passt_s_kd_p16_128_ap486', pretrained=pretrained, distilled=True, **model_kwargs)
return model
def passt_l_kd_p16_128_ap47(pretrained=False, **kwargs):
""" PaSST pre-trained on AudioSet
"""
print("\n\n Loading PaSST-L (light, reduced depth=7) pre-trained on AudioSet (with KD) Patch 16 stride 10 structured patchout mAP=4708 \n\n")
model_kwargs = dict(patch_size=16, embed_dim=768,
depth=7, num_heads=12, **kwargs)
if model_kwargs.get("stride") != (10, 10):
warnings.warn(
f"This model was pre-trained with strides {(10, 10)}, but now you set (fstride,tstride) to {model_kwargs.get('stride')}.")
model = _create_vision_transformer(
'passt_l_kd_p16_128_ap47', pretrained=pretrained, distilled=True, **model_kwargs)
return model
def passt_s_swa_p16_128_ap4761(pretrained=False, **kwargs):
""" PaSST pre-trained on AudioSet
"""
print("\n\n Loading PaSST pre-trained on AudioSet Patch 16 stride 10 structured patchout mAP=4763 SWA \n\n")
model_kwargs = dict(patch_size=16, embed_dim=768, depth=12, num_heads=12, **kwargs)
if model_kwargs.get("stride") != (10, 10):
warnings.warn(
f"This model was pre-trained with strides {(10, 10)}, but now you set (fstride,tstride) to {model_kwargs.get('stride')}.")
model = _create_vision_transformer(
'passt_s_swa_p16_128_ap4761', pretrained=pretrained, distilled=True, **model_kwargs)
return model
def passt_s_p16_128_ap472(pretrained=False, **kwargs):
""" PaSST pre-trained on AudioSet
"""
print("\n\n Loading PaSST pre-trained on AudioSet Patch 16 stride 10 structured patchout mAP=472 \n\n")
model_kwargs = dict(patch_size=16, embed_dim=768, depth=12, num_heads=12, **kwargs)
if model_kwargs.get("stride") != (10, 10):
warnings.warn(
f"This model was pre-trained with strides {(10, 10)}, but now you set (fstride,tstride) to {model_kwargs.get('stride')}.")
model = _create_vision_transformer(
'passt_s_p16_128_ap472', pretrained=pretrained, distilled=True, **model_kwargs)
return model
def passt_s_p16_s12_128_ap470(pretrained=False, **kwargs):
""" PaSST pre-trained on AudioSet
"""
print("\n\n Loading PaSST pre-trained on AudioSet Patch 16 stride 12 structured patchout mAP=472 \n\n")
model_kwargs = dict(patch_size=16, embed_dim=768, depth=12, num_heads=12, **kwargs)
if model_kwargs.get("stride") != (12, 12):
warnings.warn(
f"This model was pre-trained with strides {(12, 12)}, but now you set (fstride,tstride) to {model_kwargs.get('stride')}.")
model = _create_vision_transformer(
'passt_s_p16_s12_128_ap470', pretrained=pretrained, distilled=True, **model_kwargs)
return model
def passt_s_f128_20sec_p16_s10_ap474_swa(pretrained=False, **kwargs):
print("\n\n Loading PASST TRAINED ON AUDISET with 20 Second time encodings, with STFT hop of 160 \n\n")
model_kwargs = dict(patch_size=16, embed_dim=768, depth=12, num_heads=12, **kwargs)
model = _create_vision_transformer(
'passt-s-f128-20sec-p16-s10-ap474-swa', pretrained=pretrained, distilled=True, **model_kwargs)
return model
def passt_s_f128_30sec_p16_s10_ap473_swa(pretrained=False, **kwargs):
print("\n\n Loading PASST TRAINED ON AUDISET with 30 Second time encodings, with STFT hop of 160 \n\n")
model_kwargs = dict(patch_size=16, embed_dim=768, depth=12, num_heads=12, **kwargs)
model = _create_vision_transformer(
'passt-s-f128-30sec-p16-s10-ap473-swa', pretrained=pretrained, distilled=True, **model_kwargs)
return model
def passt_s_swa_p16_s12_128_ap473(pretrained=False, **kwargs):
""" PaSST pre-trained on AudioSet
"""
print("\n\n Loading PaSST pre-trained on AudioSet Patch 16 stride 12 structured patchout mAP=472 \n\n")
model_kwargs = dict(patch_size=16, embed_dim=768, depth=12, num_heads=12, **kwargs)
if model_kwargs.get("stride") != (12, 12):
warnings.warn(
f"This model was pre-trained with strides {(12, 12)}, but now you set (fstride,tstride) to {model_kwargs.get('stride')}.")
model = _create_vision_transformer(
'passt_s_swa_p16_s12_128_ap473', pretrained=pretrained, distilled=True, **model_kwargs)
return model
def passt_s_p16_s14_128_ap469(pretrained=False, **kwargs):
""" PaSST pre-trained on AudioSet
"""
print("\n\n Loading PaSST pre-trained on AudioSet Patch 16 stride 14 structured patchout mAP=472 \n\n")
model_kwargs = dict(patch_size=16, embed_dim=768, depth=12, num_heads=12, **kwargs)
if model_kwargs.get("stride") != (14, 14):
warnings.warn(
f"This model was pre-trained with strides {(14, 14)}, but now you set (fstride,tstride) to {model_kwargs.get('stride')}.")
model = _create_vision_transformer(
'passt_s_p16_s14_128_ap469', pretrained=pretrained, distilled=True, **model_kwargs)
return model
def passt_s_swa_p16_s14_128_ap471(pretrained=False, **kwargs):
""" PaSST pre-trained on AudioSet
"""
print("\n\n Loading PaSST pre-trained on AudioSet Patch 16 stride 14 structured patchout mAP=472 \n\n")
model_kwargs = dict(patch_size=16, embed_dim=768, depth=12, num_heads=12, **kwargs)
if model_kwargs.get("stride") != (14, 14):
warnings.warn(
f"This model was pre-trained with strides {(14, 14)}, but now you set (fstride,tstride) to {model_kwargs.get('stride')}.")
model = _create_vision_transformer(
'passt_s_swa_p16_s14_128_ap471', pretrained=pretrained, distilled=True, **model_kwargs)
return model
def passt_s_swa_p16_s16_128_ap473(pretrained=False, **kwargs):
""" PaSST pre-trained on AudioSet
"""
print("\n\n Loading PaSST pre-trained on AudioSet Patch 16 stride 16 structured patchout mAP=472 \n\n")
model_kwargs = dict(patch_size=16, embed_dim=768, depth=12, num_heads=12, **kwargs)
if model_kwargs.get("stride") != (16, 16):
warnings.warn(
f"This model was pre-trained with strides {(16, 16)}, but now you set (fstride,tstride) to {model_kwargs.get('stride')}.")
model = _create_vision_transformer(
'passt_s_swa_p16_s16_128_ap473', pretrained=pretrained, distilled=True, **model_kwargs)
return model
def passt_s_p16_s16_128_ap468(pretrained=False, **kwargs):
""" PaSST pre-trained on AudioSet
"""
print("\n\n Loading PaSST pre-trained on AudioSet Patch 16 stride 16 structured patchout mAP=472 \n\n")
model_kwargs = dict(patch_size=16, embed_dim=768, depth=12, num_heads=12, **kwargs)
if model_kwargs.get("stride") != (16, 16):
warnings.warn(
f"This model was pre-trained with strides {(16, 16)}, but now you set (fstride,tstride) to {model_kwargs.get('stride')}.")
model = _create_vision_transformer(
'passt_s_p16_s16_128_ap468', pretrained=pretrained, distilled=True, **model_kwargs)
return model
from ba3l.ingredients.ingredient import Ingredient
model_ing = Ingredient("passt")
model_ing.add_config(instance_cmd="get_model")
@model_ing.command
def fix_embedding_layer(model, embed="default"):
if embed == "default":
return model
if embed == "overlap":
model.patch_embed = PatchEmbedAdaptiveMean(replace=model.patch_embed)
if embed == "am_keepconv":
model.patch_embed = PatchEmbedAdaptiveMeanKeepConv(replace=model.patch_embed)
return model
@model_ing.command
def lighten_model(model, cut_depth=0):
if cut_depth == 0:
return model
if cut_depth:
if cut_depth < 0:
print(f"\n Reducing model depth by removing every {-cut_depth} layer \n\n")
else:
print(f"\n Reducing model depth by {cut_depth} \n\n")
if len(model.blocks) < cut_depth + 2:
raise ValueError(f"Cut depth a VIT with {len(model.blocks)} "
f"layers should be between 1 and {len(model.blocks) - 2}")
print(f"\n Before Cutting it was {len(model.blocks)} \n\n")
old_blocks = list(model.blocks.children())
if cut_depth < 0:
print(f"cut_depth={cut_depth}")
old_blocks = [old_blocks[0]] + old_blocks[1:-1:-cut_depth] + [old_blocks[-1]]
else:
old_blocks = [old_blocks[0]] + old_blocks[cut_depth + 1:]
model.blocks = nn.Sequential(*old_blocks)
print(f"\n Atfer Cutting it is {len(model.blocks)} \n\n")
return model
@model_ing.command
def get_model(arch="passt_s_kd_p16_128_ap486", pretrained=True, n_classes=527, in_channels=1, fstride=10,
tstride=10,
input_fdim=128, input_tdim=998, u_patchout=0, s_patchout_t=0, s_patchout_f=0,
):
"""
:param arch: Base ViT or Deit architecture
:param pretrained: use pretrained model on imagenet
:param n_classes: number of classes
:param in_channels: number of input channels: 1 for mono
:param fstride: the patches stride over frequency.
:param tstride: the patches stride over time.
:param input_fdim: the expected input frequency bins.
:param input_tdim: the expected input time bins.
:param u_patchout: number of input patches to drop in Unstructured Patchout as defined in https://arxiv.org/abs/2110.05069
:param s_patchout_t: number of input time frames to drop Structured Patchout as defined in https://arxiv.org/abs/2110.05069
:param s_patchout_f: number of input frequency bins to drop Structured Patchout as defined in https://arxiv.org/abs/2110.05069
:param audioset_pretrain: use pretrained models on Audioset.
:return:
"""
model_func = None
input_size = (input_fdim, input_tdim)
stride = (fstride, tstride)
if arch == "passt_deit_bd_p16_384": # base deit
model_func = deit_base_distilled_patch16_384
elif arch == "passt_s_kd_p16_128_ap486": # pretrained
model_func = passt_s_kd_p16_128_ap486
elif arch == "passt_l_kd_p16_128_ap47": # pretrained passt-L
model_func = passt_l_kd_p16_128_ap47
elif arch == "passt_s_swa_p16_128_ap476": # pretrained
model_func = passt_s_swa_p16_128_ap476
elif arch == "passt_s_swa_p16_128_ap4761":
model_func = passt_s_swa_p16_128_ap4761
elif arch == "passt_s_p16_128_ap472":
model_func = passt_s_p16_128_ap472
elif arch == "passt_s_p16_s16_128_ap468":
model_func = passt_s_p16_s16_128_ap468
elif arch == "passt_s_swa_p16_s16_128_ap473":
model_func = passt_s_swa_p16_s16_128_ap473
elif arch == "passt_s_swa_p16_s14_128_ap471":
model_func = passt_s_swa_p16_s14_128_ap471
elif arch == "passt_s_p16_s14_128_ap469":
model_func = passt_s_p16_s14_128_ap469
elif arch == "passt_s_swa_p16_s12_128_ap473":
model_func = passt_s_swa_p16_s12_128_ap473
elif arch == "passt_s_p16_s12_128_ap470":
model_func = passt_s_p16_s12_128_ap470
elif arch == "passt_s_f128_20sec_p16_s10_ap474":
model_func = passt_s_f128_20sec_p16_s10_ap474_swa
elif arch == "passt_s_f128_30sec_p16_s10_ap473":
model_func = passt_s_f128_30sec_p16_s10_ap473_swa
if model_func is None:
raise RuntimeError(f"Unknown model {arch}")
model = model_func(pretrained=pretrained, num_classes=n_classes, in_chans=in_channels,
img_size=input_size, stride=stride, u_patchout=u_patchout,
s_patchout_t=s_patchout_t, s_patchout_f=s_patchout_f)
model = fix_embedding_layer(model)
model = lighten_model(model)
print(model)
return model
class EnsembelerModel(nn.Module):
def __init__(self, models):
super(EnsembelerModel, self).__init__()
self.models = nn.ModuleList(models)
def forward(self, x):
# ModuleList can act as an iterable, or be indexed using ints
all_out = None
for i, m in enumerate(self.models):
out, _ = m(x)
if all_out is None:
all_out = out
else:
all_out = out + all_out
all_out = all_out / len(self.models)
return all_out, all_out
@model_ing.command
def get_ensemble_model(arch_list=[]):
# arch_list = [(passt_s_swa_p16_128_ap476,fstride,tstride)]
models_list = [get_model(arch=arch, fstride=fstride, tstride=tstride) for arch, fstride, tstride in arch_list]
model = EnsembelerModel(models_list)
print(model)
return model
================================================
FILE: models/preprocess.py
================================================
import torch.nn as nn
import torchaudio
from torch.nn.functional import conv1d, conv2d
import torch
from ba3l.ingredients.ingredient import Ingredient
model_ing = Ingredient("spectrograms")
sz_float = 4 # size of a float
epsilon = 10e-8 # fudge factor for normalization
@model_ing.command
class AugmentMelSTFT(nn.Module):
def __init__(self, n_mels=128, sr=32000, win_length=800, hopsize=320, n_fft=1024, freqm=48, timem=192,
htk=False, fmin=0.0, fmax=None, norm=1, fmin_aug_range=1, fmax_aug_range=1000):
torch.nn.Module.__init__(self)
# adapted from: https://github.com/CPJKU/kagglebirds2020/commit/70f8308b39011b09d41eb0f4ace5aa7d2b0e806e
# Similar config to the spectrograms used in AST: https://github.com/YuanGongND/ast
self.win_length = win_length
self.n_mels = n_mels
self.n_fft = n_fft
self.sr = sr
self.htk = htk
self.fmin = fmin
if fmax is None:
fmax = sr // 2 - fmax_aug_range // 2
print(f"Warning: FMAX is None setting to {fmax} ")
self.fmax = fmax
self.norm = norm
self.hopsize = hopsize
self.register_buffer('window',
torch.hann_window(win_length, periodic=False),
persistent=False)
assert fmin_aug_range >= 1, f"fmin_aug_range={fmin_aug_range} should be >=1; 1 means no augmentation"
assert fmin_aug_range >= 1, f"fmax_aug_range={fmax_aug_range} should be >=1; 1 means no augmentation"
self.fmin_aug_range = fmin_aug_range
self.fmax_aug_range = fmax_aug_range
self.register_buffer("preemphasis_coefficient", torch.as_tensor([[[-.97, 1]]]), persistent=False)
if freqm == 0:
self.freqm = torch.nn.Identity()
else:
self.freqm = torchaudio.transforms.FrequencyMasking(freqm, iid_masks=True)
if timem == 0:
self.timem = torch.nn.Identity()
else:
self.timem = torchaudio.transforms.TimeMasking(timem, iid_masks=True)
def forward(self, x):
x = nn.functional.conv1d(x.unsqueeze(1), self.preemphasis_coefficient).squeeze(1)
x = torch.stft(x, self.n_fft, hop_length=self.hopsize, win_length=self.win_length,
center=True, normalized=False, window=self.window, return_complex=False)
x = (x ** 2).sum(dim=-1) # power mag
fmin = self.fmin + torch.randint(self.fmin_aug_range, (1,)).item()
fmax = self.fmax + self.fmax_aug_range // 2 - torch.randint(self.fmax_aug_range, (1,)).item()
# don't augment eval data
if not self.training:
fmin = self.fmin
fmax = self.fmax
mel_basis, _ = torchaudio.compliance.kaldi.get_mel_banks(self.n_mels, self.n_fft, self.sr,
fmin, fmax, vtln_low=100.0, vtln_high=-500., vtln_warp_factor=1.0)
mel_basis = torch.as_tensor(torch.nn.functional.pad(mel_basis, (0, 1), mode='constant', value=0),
device=x.device)
with torch.cuda.amp.autocast(enabled=False):
melspec = torch.matmul(mel_basis, x)
melspec = (melspec + 0.00001).log()
if self.training:
melspec = self.freqm(melspec)
melspec = self.timem(melspec)
melspec = (melspec + 4.5) / 5. # fast normalization
return melspec
def extra_repr(self):
return 'winsize={}, hopsize={}'.format(self.win_length,
self.hopsize
)
================================================
FILE: openmic/README.md
================================================
# Experiments on OpenMIC-2018
[OpenMIC-2018](https://github.com/cosmir/openmic-2018) ([zenodo](https://zenodo.org/record/1432913#.W6dPeJNKjOR)) is a dataset for polyphonic instruments identification.
## Preparing the dataset
Use `openmic2008/prepare_scripts/download_preprocess.py` to download and pack the dataset:
```shell
cd openmic2008/prepare_scripts/
python download_preprocess.py
```
When the script completes, you should have two files inside `audioset_hdf5s/mp3/`
`openmic_train.csv_mp3.hdf` and `openmic_test.csv_mp3.hdf`
these files contains the mp3s of the dataset and the labels.
## Fine-tuning pretrained PaSST on the openmic2008
Similar to audioset you can use:
```shell
# Example call with all the default config:
python ex_openmic.py with trainer.precision=16 -p
```
```shell
# with 2 gpus:
DDP=2 python ex_openmic.py with trainer.precision=16 -p
```
================================================
FILE: openmic/dataset.py
================================================
import io
import os
import pathlib
import random
import av
import librosa
import torchaudio
from torch.utils.data import Dataset as TorchDataset, ConcatDataset, DistributedSampler, WeightedRandomSampler
import torch
from ba3l.ingredients.datasets import Dataset
import pandas as pd
from sacred.config import DynamicIngredient, CMD
from scipy.signal import convolve
from sklearn import preprocessing
from torch.utils.data import Dataset as TorchDataset
import numpy as np
import h5py
from helpers.audiodatasets import PreprocessDataset
LMODE = os.environ.get("LMODE", False)
dataset = Dataset('openMIC')
@dataset.config
def default_config():
name = 'openmic2008' # dataset name
normalize = False # normalize dataset
subsample = False # subsample squares from the dataset
roll = True # apply roll augmentation
fold = 1
base_dir = "audioset_hdf5s/" # base directory of the dataset as downloaded
if LMODE:
base_dir = "/system/user/publicdata/CP/audioset/audioset_hdf5s/"
openmic_train_hdf5 = base_dir + "mp3/openmic_train.csv_mp3.hdf"
openmic_test_hdf5 = base_dir + "mp3/openmic_test.csv_mp3.hdf"
ir_path = base_dir + "irs/"
num_of_classes = 20
def decode_mp3(mp3_arr):
"""
decodes an array if uint8 representing an mp3 file
:rtype: np.array
"""
container = av.open(io.BytesIO(mp3_arr.tobytes()))
stream = next(s for s in container.streams if s.type == 'audio')
# print(stream)
a = []
for i, packet in enumerate(container.demux(stream)):
for frame in packet.decode():
a.append(frame.to_ndarray().reshape(-1))
waveform = np.concatenate(a)
if waveform.dtype != 'float32':
raise RuntimeError("Unexpected wave type")
return waveform
def pad_or_truncate(x, audio_length):
"""Pad all audio to specific length."""
if len(x) <= audio_length:
return np.concatenate((x, np.zeros(audio_length - len(x), dtype=np.float32)), axis=0)
else:
return x[0: audio_length]
irs_arr = None
@dataset.command
def get_ir_sample(ir_path, _run, ir_augment, cut_irs_offset=None):
if not ir_augment:
return
global irs_arr
if irs_arr is None:
all_paths = [path for path in pathlib.Path(os.path.expanduser(ir_path)).rglob('*.wav')]
all_paths = sorted(all_paths)
if cut_irs_offset is not None:
all_paths = all_paths[cut_irs_offset:cut_irs_offset + 10]
all_paths_name = [str(p).rsplit("/", 1)[-1] for p in all_paths]
print("will use these IRs:")
for i in range(len(all_paths_name)):
print(i, ": ", all_paths_name[i])
_run.info["ir_devices"] = all_paths_name
irs_arr = [librosa.load(p, sr=32000)[0] for p in all_paths]
return irs_arr[int(np.random.randint(0, len(irs_arr)))]
@dataset.command
def pydub_augment(waveform, gain_augment=7, ir_augment=0):
if ir_augment and torch.rand(1) < ir_augment:
ir = get_ir_sample()
waveform = convolve(waveform, ir, 'full')
if gain_augment:
gain = torch.randint(gain_augment * 2, (1,)).item() - gain_augment
amp = 10 ** (gain / 20)
waveform = waveform * amp
return waveform
class MixupDataset(TorchDataset):
""" Mixing Up wave forms
"""
def __init__(self, dataset, beta=2, rate=0.5):
self.beta = beta
self.rate = rate
self.dataset = dataset
print(f"Mixing up waveforms from dataset of len {len(dataset)}")
def __getitem__(self, index):
x1, f1, y1 = self.dataset[index]
y1 = torch.as_tensor(y1)
if torch.rand(1) < self.rate:
idx2 = torch.randint(len(self.dataset), (1,)).item()
x2, f2, y2 = self.dataset[idx2]
y2 = torch.as_tensor(y2)
l = np.random.beta(self.beta, self.beta)
l = max(l, 1. - l)
x1 = x1 - x1.mean()
x2 = x2 - x2.mean()
x = (x1 * l + x2 * (1. - l))
x = x - x.mean()
assert len(y1) == 40, "only for openmic, works this"
y_mask1 = (torch.as_tensor(y1[20:]) > 0.5).float()
y_mask2 = (torch.as_tensor(y2[20:]) > 0.5).float()
y1[:20] *= y_mask1
y2[:20] *= y_mask2
yres = (y1 * l + y2 * (1. - l))
yres[20:] = torch.stack([y_mask1, y_mask2]).max(dim=0).values
return x, f1, yres
return x1, f1, y1
def __len__(self):
return len(self.dataset)
class AudioSetDataset(TorchDataset):
def __init__(self, hdf5_file, sample_rate=32000, classes_num=527, clip_length=10, augment=False, in_mem=False):
"""
Reads the mp3 bytes from HDF file decodes using av and returns a fixed length audio wav
"""
self.sample_rate = sample_rate
self.hdf5_file = hdf5_file
if in_mem:
print("\nPreloading in memory\n")
with open(hdf5_file, 'rb') as f:
self.hdf5_file = io.BytesIO(f.read())
with h5py.File(hdf5_file, 'r') as f:
self.length = len(f['audio_name'])
print(f"Dataset from {hdf5_file} with length {self.length}.")
self.dataset_file = None # lazy init
self.clip_length = clip_length * sample_rate
self.classes_num = classes_num
self.augment = augment
if augment:
print(f"Will agument data from {hdf5_file}")
def open_hdf5(self):
self.dataset_file = h5py.File(self.hdf5_file, 'r')
def __len__(self):
return self.length
def __del__(self):
if self.dataset_file is not None:
self.dataset_file.close()
self.dataset_file = None
def __getitem__(self, index):
"""Load waveform and target of an audio clip.
Args:
meta: {
'hdf5_path': str,
'index_in_hdf5': int}
Returns:
data_dict: {
'audio_name': str,
'waveform': (clip_samples,),
'target': (classes_num,)}
"""
if self.dataset_file is None:
self.open_hdf5()
audio_name = self.dataset_file['audio_name'][index].decode()
waveform = decode_mp3(self.dataset_file['mp3'][index])
if self.augment:
waveform = pydub_augment(waveform)
waveform = pad_or_truncate(waveform, self.clip_length)
waveform = self.resample(waveform)
target = self.dataset_file['target'][index]
target = target.astype(np.float32)
return waveform.reshape(1, -1), audio_name, target
def resample(self, waveform):
"""Resample.
Args:
waveform: (clip_samples,)
Returns:
(resampled_clip_samples,)
"""
if self.sample_rate == 32000:
return waveform
elif self.sample_rate == 16000:
return waveform[0:: 2]
elif self.sample_rate == 8000:
return waveform[0:: 4]
else:
raise Exception('Incorrect sample rate!')
@dataset.command
def get_base_training_set(openmic_train_hdf5):
ds = AudioSetDataset(openmic_train_hdf5, augment=True)
return ds
@dataset.command
def get_ft_weighted_sampler(samples_weights=CMD(".get_ft_cls_balanced_sample_weights"),
epoch_len=100000, sampler_replace=False):
num_nodes = int(os.environ.get('num_nodes', 1))
ddp = int(os.environ.get('DDP', 1))
num_nodes = max(ddp, num_nodes)
print("num_nodes= ", num_nodes)
rank = int(os.environ.get('NODE_RANK', 0))
return DistributedSamplerWrapper(sampler=WeightedRandomSampler(samples_weights,
num_samples=epoch_len, replacement=sampler_replace),
dataset=range(epoch_len),
num_replicas=num_nodes,
rank=rank,
)
@dataset.command
def get_base_test_set(openmic_test_hdf5):
ds = AudioSetDataset(openmic_test_hdf5)
return ds
@dataset.command(prefix='roll_conf')
def get_roll_func(axis=1, shift=None, shift_range=50):
print("rolling...")
def roll_func(b):
x, i, y = b
x = torch.as_tensor(x)
sf = shift
if shift is None:
sf = int(np.random.random_integers(-shift_range, shift_range))
global FirstTime
return x.roll(sf, axis), i, y
return roll_func
@dataset.command
def get_training_set(normalize, roll, wavmix=False):
ds = get_base_training_set()
get_ir_sample()
if normalize:
print("normalized train!")
fill_norms()
ds = PreprocessDataset(ds, norm_func)
if roll:
ds = PreprocessDataset(ds, get_roll_func())
if wavmix:
ds = MixupDataset(ds)
return ds
@dataset.command
def get_test_set(normalize):
ds = get_base_test_set()
if normalize:
print("normalized test!")
fill_norms()
ds = PreprocessDataset(ds, norm_func)
return ds
@dataset.command
def print_conf(_config):
print("Config of ", dataset.path, id(dataset))
print(_config)
print()
class DistributedSamplerWrapper(DistributedSampler):
def __init__(
self, sampler, dataset,
num_replicas=None,
rank=None,
shuffle: bool = True):
super(DistributedSamplerWrapper, self).__init__(
dataset, num_replicas, rank, shuffle)
# source: @awaelchli https://github.com/PyTorchLightning/pytorch-lightning/issues/3238
self.sampler = sampler
def __iter__(self):
if self.sampler.generator is None:
self.sampler.generator = torch.Generator()
self.sampler.generator.manual_seed(self.seed + self.epoch)
indices = list(self.sampler)
if self.epoch == 0:
print(f"\n DistributedSamplerWrapper : {indices[:10]} \n\n")
indices = indices[self.rank:self.total_size:self.num_replicas]
return iter(indices)
if __name__ == "__main__":
from sacred import Experiment
ex = Experiment("test_dataset", ingredients=[dataset])
@ex.automain
def default_command():
ex.current_run.get_command_function("print_config")()
get_base_training_set()
ds = get_test_set()
print(ds[0])
ds = get_training_set()
print(ds[0])
print("get_base_training_set", len(get_base_training_set()))
print("get_base_test_set", len(get_base_test_set()))
print("get_training_set", len(get_training_set()))
print("get_test_set", len(get_test_set()))
================================================
FILE: openmic/prepare_scripts/download_preprocess.py
================================================
import os
import tarfile
import multiprocessing
import glob
import h5py
import numpy as np
from torch.hub import download_url_to_file
# global constants
openmicurl = "https://zenodo.org/record/1432913/files/openmic-2018-v1.0.0.tgz?download=1"
download_target = "openmic-2018-v1.0.0.tgz"
extract_target = download_target.replace(".tgz", "")
dataset_path = os.path.join(extract_target, "openmic-2018/")
mp3_path = os.path.join(dataset_path, "mp3/")
hdf5s_dir = "../../audioset_hdf5s/"
train_files_csv = os.path.join(dataset_path, "partitions/split01_train.csv")
test_files_csv = os.path.join(dataset_path, "partitions/split01_test.csv")
def download(force=False):
if force or not os.path.isfile(download_target):
print("Downloading OpenMIC from zenodo...")
download_url_to_file(openmicurl, download_target)
else:
print(f"{download_target} already exists. Skipping download!")
def untar():
my_tar = tarfile.open(download_target)
print(f"Extracting openmic from {download_target} to {extract_target}")
my_tar.extractall(extract_target)
def process_folder(fol="balanced_train_segments"):
print("now working on ", fol)
os.makedirs(mp3_path + fol, exist_ok=True)
all_files = list(glob.glob(os.path.join(dataset_path, "audio/") + "/*/*.ogg")) # openmic format
print(f"it has {len(all_files)}")
print(all_files[:5])
global all_num
all_num = len(all_files)
cmds = [(i, file, mp3_path + fol + "/" + os.path.basename(file)[:-3]) for i, file in enumerate(all_files)]
print(cmds[0])
with multiprocessing.Pool(processes=20) as pool:
pool.starmap(process_one, cmds)
def process_one(i, f1, f2):
if i % 100 == 0:
print(f"{i}/{all_num} \t", f1)
os.system(f"ffmpeg -hide_banner -nostats -loglevel error -n -i {f1} -codec:a mp3 -ar 32000 {f2}mp3")
def make_mp3():
process_folder("audio")
def read_metadata(csv_path, classes_num, id_to_ix, openmicf):
"""Read metadata of AudioSet from a csv file.
source: https://github.com/qiuqiangkong/audioset_tagging_cnn/blob/d2f4b8c18eab44737fcc0de1248ae21eb43f6aa4/utils/utilities.py#L59
Args:
csv_path: str
Returns:
meta_dict: {'audio_name': (audios_num,), 'target': (audios_num, classes_num)}
"""
with open(csv_path) as file:
lines = file.readlines()
lines = [line.rstrip() for line in lines]
audios_num = len(lines)
# class + mask
targets = np.zeros((audios_num, classes_num * 2), dtype=np.float32)
audio_names = []
notfound = set()
for i, row in enumerate(lines):
audio_name = '{}.mp3'.format(row) # Audios are started with an extra 'Y' when downloading
audio_names.append(audio_name)
t = openmicf.Y_true[id_to_ix[row]] #id_to_ix[row]["true"]
m = openmicf.Y_mask[id_to_ix[row]].astype(int)#id_to_ix[row]["mask"]
# Target
targets[i, :classes_num] = t
targets[i, classes_num:] = m
print(notfound)
print("original_targets", len(targets))
mask = targets.astype(np.int).sum(1) > 0
print(len(mask), mask.sum())
print("after: ", len(targets[mask]))
meta_dict = {'audio_name': np.array(audio_names)[mask], 'target': targets[mask]}
return meta_dict
def get_files_labels(balanced_csv, balanced_audio_path, d_files, openmicf, prefix=None, zip_contents=None, classes_num=20):
meta_csv = read_metadata(balanced_csv, classes_num, d_files, openmicf)
# fname,labels,mids
audios_num = len(meta_csv['audio_name'])
found = 0
notfound = 0
available_files = []
available_targets = []
for n in range(audios_num):
audio_path = meta_csv['audio_name'][n]
# print(balanced_audio_path + f"{prefix}/{audio_path}")
if n == 0:
print("checking: ", balanced_audio_path + f"{prefix}/{audio_path}")
if os.path.isfile(balanced_audio_path + f"{prefix}/{audio_path}"):
found += 1
available_files.append(meta_csv['audio_name'][n])
available_targets.append(meta_csv['target'][n])
else:
notfound += 1
print(f"Found {found} . not found {notfound}")
return available_files, available_targets
def pack():
d_files = dict()
opmic = np.load(os.path.join(dataset_path, "openmic-2018.npz"))
opmic.allow_pickle = True
for i, sid in enumerate(opmic.f.sample_key):
d_files[sid] = i #{"mask": opmic.f.Y_mask[i].astype(int),
#"true": opmic.f.Y_true[i]}
print("len=",len(d_files))
for read_file, prefix in [(train_files_csv, "audio/"), (test_files_csv, "audio/")]:
print("now working on ", read_file, prefix)
# files, y = torch.load(read_file+".pth")
files, y = get_files_labels(read_file, mp3_path, d_files=d_files, openmicf=opmic.f, prefix=prefix)
y = np.array(y)
# y = np.packbits(y, axis=-1)
packed_len = y.shape[1]
print(files[0], "classes: ", packed_len, y.dtype)
available_size = len(files)
f = files[0]
a = np.fromfile(mp3_path + prefix + "/" + f, dtype='uint8')
dt = h5py.vlen_dtype(np.dtype('uint8'))
save_file = read_file.rsplit("/", 1)[1].replace("split01", "openmic")
os.makedirs(hdf5s_dir + "mp3/" ,exist_ok=True)
if os.path.isfile(hdf5s_dir + "mp3/" + save_file + "_mp3.hdf"):
print(hdf5s_dir + "mp3/" + save_file + "_mp3.hdf", "exists!\n\n\n contiue")
continue
with h5py.File(hdf5s_dir + "mp3/" + save_file + "_mp3.hdf", 'w') as hf:
audio_name = hf.create_dataset('audio_name', shape=((available_size,)), dtype='S20')
waveform = hf.create_dataset('mp3', shape=((available_size,)), dtype=dt)
target = hf.create_dataset('target', shape=((available_size, packed_len)), dtype=y.dtype)
for i, file in enumerate(files):
if i % 1000 == 0:
print(f"{i}/{available_size}")
f = file
a = np.fromfile(mp3_path + prefix + f, dtype='uint8')
audio_name[i] = f
waveform[i] = a
target[i] = y[i]
print("Saved h5py file into ", hdf5s_dir + "mp3/" + save_file + "_mp3.hdf")
print(a.shape)
print("Done!", prefix)
def preprocess():
download()
untar()
make_mp3()
pack()
preprocess()
================================================
FILE: pip_list.txt
================================================
Package Version
----------------------- ------------
absl-py 1.4.0
aiohttp 3.8.4
aiosignal 1.3.1
appdirs 1.4.4
async-timeout 4.0.2
attrs 22.2.0
audioread 3.0.0
autopep8 1.6.0
av 10.0.0
brotlipy 0.7.0
cachetools 5.3.0
certifi 2022.12.7
cffi 1.15.1
charset-normalizer 2.0.4
click 8.1.3
colorama 0.4.6
cryptography 39.0.1
decorator 5.1.1
docker-pycreds 0.4.0
docopt 0.6.2
flit_core 3.8.0
frozenlist 1.3.3
fsspec 2023.3.0
future 0.18.3
gitdb 4.0.10
GitPython 3.1.31
google-auth 2.17.1
google-auth-oauthlib 0.4.6
grpcio 1.53.0
h5py 3.8.0
idna 3.4
importlib-metadata 6.1.0
joblib 1.2.0
jsonpickle 3.0.1
kk-sacred 0.8.4
lazy_loader 0.2
librosa 0.10.0.post2
lightning-utilities 0.8.0
llvmlite 0.39.1
Markdown 3.4.3
MarkupSafe 2.1.2
mkl-fft 1.3.1
mkl-random 1.2.2
mkl-service 2.4.0
msgpack 1.0.5
multidict 6.0.4
munch 2.5.0
numba 0.56.4
numpy 1.23.5
oauthlib 3.2.2
packaging 23.0
pandas 1.5.3
pathtools 0.1.2
Pillow 9.4.0
pip 23.0.1
pooch 1.6.0
protobuf 4.22.1
psutil 5.9.4
py-cpuinfo 9.0.0
pyasn1 0.4.8
pyasn1-modules 0.2.8
pycodestyle 2.10.0
pycparser 2.21
pyDeprecate 0.3.0
pyOpenSSL 23.0.0
PySocks 1.7.1
python-dateutil 2.8.2
pytorch-lightning 1.3.1
pytz 2023.3
PyYAML 5.4.1
requests 2.28.1
requests-oauthlib 1.3.1
rsa 4.9
scikit-learn 1.2.2
scipy 1.10.1
sentry-sdk 1.19.1
setproctitle 1.3.2
setuptools 65.6.3
six 1.16.0
smmap 5.0.0
soundfile 0.12.1
soxr 0.3.4
tensorboard 2.12.0
tensorboard-data-server 0.7.0
tensorboard-plugin-wit 1.8.1
threadpoolctl 3.1.0
timm 0.4.12
toml 0.10.2
torch 1.11.0
torchaudio 0.11.0
torchmetrics 0.2.0
torchvision 0.12.0
tqdm 4.65.0
typing_extensions 4.4.0
urllib3 1.26.15
wandb 0.14.2
Werkzeug 2.2.3
wheel 0.38.4
wrapt 1.15.0
yarl 1.8.2
zipp 3.15.0
================================================
FILE: requirements.txt
================================================
av>=10.0.0
h5py>=3.8.0
jsonpickle>=3.0.1
kk-sacred>=0.8.4
librosa>=0.10.0.post2
timm>=0.4.12
torchmetrics>=0.2.0
pytorch-lightning<2.0.0
wandb>=0.14.2