Repository: kkoutini/PaSST
Branch: main
Commit: 2a5c818afcc2
Files: 52
Total size: 341.1 KB
Directory structure:
gitextract_blhd6h9_/
├── .github/
│   └── workflows/
│       └── pylint.yml
├── .gitignore
├── .markdownlint.json
├── LICENSE
├── README.md
├── __init__.py
├── audioset/
│   ├── README.md
│   ├── dataset.py
│   └── prepare_scripts/
│       ├── convert_to_mp3.py
│       └── create_h5pymp3_dataset.py
├── ba3l/
│   ├── __init__.py
│   ├── experiment.py
│   ├── ingredients/
│   │   ├── __init__.py
│   │   ├── datasets.py
│   │   ├── ingredient.py
│   │   ├── models.py
│   │   └── trainer.py
│   ├── module.py
│   ├── plutils/
│   │   ├── __init__.py
│   │   ├── lr_monitor.py
│   │   └── progress_bar.py
│   └── util/
│       ├── __init__.py
│       ├── functions.py
│       └── sacred_logger.py
├── config_updates.py
├── environment.yml
├── environment_old_2021.yml
├── esc50/
│   ├── README.md
│   └── dataset.py
├── ex_audioset.py
├── ex_esc50.py
├── ex_fsd50k.py
├── ex_openmic.py
├── fsd50k/
│   ├── README.md
│   ├── dataset.py
│   └── prepare_scripts/
│       ├── convert_to_mp3.py
│       └── create_h5pymp3_dataset.py
├── helpers/
│   ├── audiodatasets.py
│   ├── mixup.py
│   ├── models_size.py
│   ├── ramp.py
│   ├── swa_callback.py
│   ├── swa_legacy.py
│   └── workersinit.py
├── models/
│   ├── helpers/
│   │   └── vit_helpers.py
│   ├── passt.py
│   └── preprocess.py
├── openmic/
│   ├── README.md
│   ├── dataset.py
│   └── prepare_scripts/
│       └── download_preprocess.py
├── pip_list.txt
└── requirements.txt
================================================
FILE CONTENTS
================================================
================================================
FILE: .github/workflows/pylint.yml
================================================
name: Pylint
on: [push]
jobs:
  build:
    runs-on: ubuntu-latest
    strategy:
      matrix:
        python-version: ["3.8", "3.9", "3.10"]
    steps:
    - uses: actions/checkout@v3
    - name: Set up Python ${{ matrix.python-version }}
      uses: actions/setup-python@v3
      with:
        python-version: ${{ matrix.python-version }}
    - name: Install dependencies
      run: |
        python -m pip install --upgrade pip
        pip install pylint
    - name: Analysing the code with pylint
      run: |
        pylint $(git ls-files '*.py')
================================================
FILE: .gitignore
================================================
# project
/output/
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class
example.py
timit_data/
LJSpeech-1.1/
# C extensions
*.so
.idea/
# Distribution / packaging
.Python
env/
ide_layouts/
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
*.egg-info/
.installed.cfg
*.egg
# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec
# Installer logs
pip-log.txt
pip-delete-this-directory.txt
# Unit test / coverage reports
htmlcov/
.tox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
.hypothesis/
# Translations
*.mo
*.pot
# Django stuff:
*.log
local_settings.py
# Flask stuff:
instance/
.webassets-cache
# Scrapy stuff:
.scrapy
# Sphinx documentation
docs/_build/
# PyBuilder
target/
# Jupyter Notebook
.ipynb_checkpoints
# pyenv
.python-version
# celery beat schedule file
celerybeat-schedule
# SageMath parsed files
*.sage.py
# dotenv
.env
# virtualenv
.venv
venv/
ENV/
# Spyder project settings
.spyderproject
.spyproject
# Rope project settings
.ropeproject
# mkdocs documentation
/site
# mypy
.mypy_cache/
/openmic/prepare_scripts/openmic-2018-v1.0.0/
/openmic/prepare_scripts/openmic-2018-v1.0.0.tgz
/audioset_hdf5s/mp3/openmic_test.csv_mp3.hdf
/audioset_hdf5s/mp3/openmic_train.csv_mp3.hdf
environment_builds.yml
# Output
lightning_logs/*
audioset_hdf5s/*
.vscode/settings.json
wandb/*
.vscode/launch.json
================================================
FILE: .markdownlint.json
================================================
{
  "MD033": {
    "allowed_elements": [
      "p",
      "img"
    ]
  },
  "MD013": false
}
================================================
FILE: LICENSE
================================================
Apache License
Version 2.0, January 2004
http://www.apache.org/licenses/
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
1. Definitions.
"License" shall mean the terms and conditions for use, reproduction,
and distribution as defined by Sections 1 through 9 of this document.
"Licensor" shall mean the copyright owner or entity authorized by
the copyright owner that is granting the License.
"Legal Entity" shall mean the union of the acting entity and all
other entities that control, are controlled by, or are under common
control with that entity. For the purposes of this definition,
"control" means (i) the power, direct or indirect, to cause the
direction or management of such entity, whether by contract or
otherwise, or (ii) ownership of fifty percent (50%) or more of the
outstanding shares, or (iii) beneficial ownership of such entity.
"You" (or "Your") shall mean an individual or Legal Entity
exercising permissions granted by this License.
"Source" form shall mean the preferred form for making modifications,
including but not limited to software source code, documentation
source, and configuration files.
"Object" form shall mean any form resulting from mechanical
transformation or translation of a Source form, including but
not limited to compiled object code, generated documentation,
and conversions to other media types.
"Work" shall mean the work of authorship, whether in Source or
Object form, made available under the License, as indicated by a
copyright notice that is included in or attached to the work
(an example is provided in the Appendix below).
"Derivative Works" shall mean any work, whether in Source or Object
form, that is based on (or derived from) the Work and for which the
editorial revisions, annotations, elaborations, or other modifications
represent, as a whole, an original work of authorship. For the purposes
of this License, Derivative Works shall not include works that remain
separable from, or merely link (or bind by name) to the interfaces of,
the Work and Derivative Works thereof.
"Contribution" shall mean any work of authorship, including
the original version of the Work and any modifications or additions
to that Work or Derivative Works thereof, that is intentionally
submitted to Licensor for inclusion in the Work by the copyright owner
or by an individual or Legal Entity authorized to submit on behalf of
the copyright owner. For the purposes of this definition, "submitted"
means any form of electronic, verbal, or written communication sent
to the Licensor or its representatives, including but not limited to
communication on electronic mailing lists, source code control systems,
and issue tracking systems that are managed by, or on behalf of, the
Licensor for the purpose of discussing and improving the Work, but
excluding communication that is conspicuously marked or otherwise
designated in writing by the copyright owner as "Not a Contribution."
"Contributor" shall mean Licensor and any individual or Legal Entity
on behalf of whom a Contribution has been received by Licensor and
subsequently incorporated within the Work.
2. Grant of Copyright License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
copyright license to reproduce, prepare Derivative Works of,
publicly display, publicly perform, sublicense, and distribute the
Work and such Derivative Works in Source or Object form.
3. Grant of Patent License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
(except as stated in this section) patent license to make, have made,
use, offer to sell, sell, import, and otherwise transfer the Work,
where such license applies only to those patent claims licensable
by such Contributor that are necessarily infringed by their
Contribution(s) alone or by combination of their Contribution(s)
with the Work to which such Contribution(s) was submitted. If You
institute patent litigation against any entity (including a
cross-claim or counterclaim in a lawsuit) alleging that the Work
or a Contribution incorporated within the Work constitutes direct
or contributory patent infringement, then any patent licenses
granted to You under this License for that Work shall terminate
as of the date such litigation is filed.
4. Redistribution. You may reproduce and distribute copies of the
Work or Derivative Works thereof in any medium, with or without
modifications, and in Source or Object form, provided that You
meet the following conditions:
(a) You must give any other recipients of the Work or
Derivative Works a copy of this License; and
(b) You must cause any modified files to carry prominent notices
stating that You changed the files; and
(c) You must retain, in the Source form of any Derivative Works
that You distribute, all copyright, patent, trademark, and
attribution notices from the Source form of the Work,
excluding those notices that do not pertain to any part of
the Derivative Works; and
(d) If the Work includes a "NOTICE" text file as part of its
distribution, then any Derivative Works that You distribute must
include a readable copy of the attribution notices contained
within such NOTICE file, excluding those notices that do not
pertain to any part of the Derivative Works, in at least one
of the following places: within a NOTICE text file distributed
as part of the Derivative Works; within the Source form or
documentation, if provided along with the Derivative Works; or,
within a display generated by the Derivative Works, if and
wherever such third-party notices normally appear. The contents
of the NOTICE file are for informational purposes only and
do not modify the License. You may add Your own attribution
notices within Derivative Works that You distribute, alongside
or as an addendum to the NOTICE text from the Work, provided
that such additional attribution notices cannot be construed
as modifying the License.
You may add Your own copyright statement to Your modifications and
may provide additional or different license terms and conditions
for use, reproduction, or distribution of Your modifications, or
for any such Derivative Works as a whole, provided Your use,
reproduction, and distribution of the Work otherwise complies with
the conditions stated in this License.
5. Submission of Contributions. Unless You explicitly state otherwise,
any Contribution intentionally submitted for inclusion in the Work
by You to the Licensor shall be under the terms and conditions of
this License, without any additional terms or conditions.
Notwithstanding the above, nothing herein shall supersede or modify
the terms of any separate license agreement you may have executed
with Licensor regarding such Contributions.
6. Trademarks. This License does not grant permission to use the trade
names, trademarks, service marks, or product names of the Licensor,
except as required for reasonable and customary use in describing the
origin of the Work and reproducing the content of the NOTICE file.
7. Disclaimer of Warranty. Unless required by applicable law or
agreed to in writing, Licensor provides the Work (and each
Contributor provides its Contributions) on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
implied, including, without limitation, any warranties or conditions
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
PARTICULAR PURPOSE. You are solely responsible for determining the
appropriateness of using or redistributing the Work and assume any
risks associated with Your exercise of permissions under this License.
8. Limitation of Liability. In no event and under no legal theory,
whether in tort (including negligence), contract, or otherwise,
unless required by applicable law (such as deliberate and grossly
negligent acts) or agreed to in writing, shall any Contributor be
liable to You for damages, including any direct, indirect, special,
incidental, or consequential damages of any character arising as a
result of this License or out of the use or inability to use the
Work (including but not limited to damages for loss of goodwill,
work stoppage, computer failure or malfunction, or any and all
other commercial damages or losses), even if such Contributor
has been advised of the possibility of such damages.
9. Accepting Warranty or Additional Liability. While redistributing
the Work or Derivative Works thereof, You may choose to offer,
and charge a fee for, acceptance of support, warranty, indemnity,
or other liability obligations and/or rights consistent with this
License. However, in accepting such obligations, You may act only
on Your own behalf and on Your sole responsibility, not on behalf
of any other Contributor, and only if You agree to indemnify,
defend, and hold each Contributor harmless for any liability
incurred by, or claims asserted against, such Contributor by reason
of your accepting any such warranty or additional liability.
END OF TERMS AND CONDITIONS
APPENDIX: How to apply the Apache License to your work.
To apply the Apache License to your work, attach the following
boilerplate notice, with the fields enclosed by brackets "[]"
replaced with your own identifying information. (Don't include
the brackets!) The text should be enclosed in the appropriate
comment syntax for the file format. We also recommend that a
file or class name and description of purpose be included on the
same "printed page" as the copyright notice for easier
identification within third-party archives.
Copyright [yyyy] [name of copyright owner]
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
================================================
FILE: README.md
================================================
# PaSST: Efficient Training of Audio Transformers with Patchout
This is the implementation for [Efficient Training of Audio Transformers with Patchout](https://arxiv.org/abs/2110.05069)
Patchout significantly reduces the training time and GPU memory requirements to train transformers on audio spectrograms, while improving their performance.
<p align="center"><img src="https://github.com/kkoutini/PaSST/blob/main/.github/speed_mem_map.png?raw=true" width="600"/></p>
Patchout works by dropping out some of the input patches during training,
either in an unstructured way (randomly, similar to dropout),
or by dropping entire time frames or frequency bins of the extracted patches (similar to SpecAugment),
which corresponds to rows/columns in step 3 of the figure below.
<p align="center"><img src="https://github.com/kkoutini/PaSST/raw/main/.github/passt_diag.png?raw=true" width="600"/></p>
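To make the idea concrete, structured patchout can be sketched in a few lines of NumPy. This is only an illustration of the concept (the repository's actual implementation lives in `models/passt.py` and operates on PyTorch tensors); the function name and shapes here are made up for the example:

```python
import numpy as np

def structured_patchout(patches, n_drop_t, n_drop_f, rng=None):
    """Drop random time columns and frequency rows from a patch grid.

    patches: array of shape [freq_patches, time_patches, embed_dim].
    Returns the surviving patches flattened to [n_kept, embed_dim].
    """
    rng = np.random.default_rng() if rng is None else rng
    f, t, _ = patches.shape
    keep_f = np.sort(rng.choice(f, f - n_drop_f, replace=False))
    keep_t = np.sort(rng.choice(t, t - n_drop_t, replace=False))
    # np.ix_ selects the kept rows and columns of the grid
    kept = patches[np.ix_(keep_f, keep_t)]
    return kept.reshape(-1, patches.shape[-1])

# e.g. a 12x99 grid of 768-dim patch embeddings, with the default
# s_patchout_f=4 and s_patchout_t=40
grid = np.random.randn(12, 99, 768)
out = structured_patchout(grid, n_drop_t=40, n_drop_f=4)
print(out.shape)  # (472, 768): (12-4) * (99-40) patches remain
```

Because the transformer only processes the kept patches, the sequence length (and hence attention cost and memory) shrinks during training.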
## Table of contents
- [Pre-trained models for Inference and embeddings extractions](#pre-trained-models-for-inference-and-embeddings-extractions)
- [Getting the logits from the pretrained models](#getting-the-logits-from-the-pretrained-models)
- [Getting a pre-trained model for fine-tuning](#getting-a-pre-trained-model-for-fine-tuning)
- [Development environment](#development-environment)
- [Setting up the development experiments environment](#setting-up-the-development-experiments-environment)
- [Setting up using the exported conda environment](#setting-up-using-the-exported-conda-environment)
- [Checking the environment](#checking-the-environment)
- [Getting started](#getting-started)
- [General information](#general-information)
- [Configuring the experiment](#configuring-the-experiment)
- [Training on Audioset](#training-on-audioset)
- [Examples with Pre-trained models](#examples-with-pre-trained-models)
- [Examples of fine-tuning on downstream datasets](#examples-of-fine-tuning-on-downstream-datasets)
- [Citation](#citation)
- [Contact](#contact)
## Pre-trained models for Inference and embeddings extractions
If you only want to use the embeddings generated by the pretrained models, plug them into
your own fine-tuning framework, or need the models only for inference, you can find a stripped-down version of this repo [here](https://github.com/kkoutini/passt_hear21).
The package follows the [HEAR 2021 NeurIPS Challenge](https://neuralaudio.ai/hear2021-results.html) API and can be installed with:
```shell
pip install hear21passt
```
This repo is a complete framework for training the models on Audioset and fine-tuning the pre-trained models on downstream tasks.
### Getting the logits from the pretrained models
```python
from hear21passt.base import get_basic_model,get_model_passt
import torch
# get the PaSST model wrapper, includes Melspectrogram and the default pre-trained transformer
model = get_basic_model(mode="logits")
print(model.mel) # Extracts mel spectrogram from raw waveforms.
print(model.net) # the transformer network.
# example inference
model.eval()
model = model.cuda()
with torch.no_grad():
# audio_wave has the shape of [batch, seconds*32000] sampling rate is 32k
# example audio_wave of batch=3 and 10 seconds
audio = torch.ones((3, 32000 * 10))*0.5
audio_wave = audio.cuda()
logits=model(audio_wave)
```
### Getting a pre-trained model for fine-tuning
```python
from hear21passt.base import get_basic_model,get_model_passt
import torch
# get the PaSST model wrapper, includes Melspectrogram and the default pre-trained transformer
model = get_basic_model(mode="logits")
print(model.mel) # Extracts mel spectrogram from raw waveforms.
# optionally replace the transformer with one that has the required number of classes, e.g. 50
model.net = get_model_passt(arch="passt_s_swa_p16_128_ap476", n_classes=50)
print(model.net) # the transformer network.
# now the model contains the mel frontend + the pre-trained transformer, ready to be fine-tuned.
# It still expects input of shape [batch, seconds*32000]; the sampling rate is 32k
model.train()
model = model.cuda()
```
## Development environment
If you want to use the same environment as in the paper, you can follow the instructions below.
### Setting up the development experiments environment
For training models from scratch or fine-tuning using the same setup as in the paper:
1. If needed, create a new environment with python 3.8 and activate it:
```bash
conda create -n passt python=3.8
conda activate passt
```
1. Install a pytorch build that suits your system. For example:
```bash
conda install pytorch==1.11.0 torchvision==0.12.0 torchaudio==0.11.0 cudatoolkit=11.3 -c pytorch
```
1. Install the requirements:
```bash
pip install -r requirements.txt
```
### Setting up using the exported conda environment
Alternatively, you can use the exported conda environment `environment.yml` to create the environment.
For setting up the environment, [Mamba](https://github.com/mamba-org/mamba) is recommended, since it is faster than `conda`:
```shell
conda install mamba -n base -c conda-forge
```
Now you can import the environment from `environment.yml`
```shell
mamba env create -f environment.yml
```
Now you have an environment named `ba3l`.
### Checking the environment
To check whether your environment matches the environment we used in our runs, compare against the `environment.yml` and `pip_list.txt` files, which were exported using:
```shell
conda env export --no-builds | grep -v "prefix" > environment.yml
pip list > pip_list.txt
```
## Getting started
If you want to use your own setup and only need the models from this repo, you can load the models, train them from scratch, or fine-tune them on your own dataset as explained above in [Pre-trained models for Inference and embeddings extractions](#pre-trained-models-for-inference-and-embeddings-extractions). The rest of this section explains how to use this repo for training and fine-tuning the models. For that, you first need to set up the development environment as explained above.
### General information
The repo is built using [sacred](https://sacred.readthedocs.io/en/) for experiment management and configuration, pytorch-lightning for training, and wandb for logging.
Each dataset has a main experiment file such as `ex_audioset.py` and `ex_openmic.py` and a dataset folder. The experiment file contains the main training and validation logic. The dataset folder contains the dataset specific code needed to download, preprocess and load the dataset for training.
In general, you can probe the experiment file for help; this will print the available commands and basic options:
```shell
python ex_audioset.py help
```
### Configuring the experiment
Each experiment has a set of default configuration options, defined in the experiment file, e.g. `ex_audioset.py`. You can override any of the configuration options using the [sacred syntax](https://sacred.readthedocs.io/en/stable/command_line.html). You can use the `print_config` command to print the configuration values without training a model:
```shell
python ex_audioset.py print_config
```
You can then use the command line interface to override any of the configuration options ([sacred syntax](https://sacred.readthedocs.io/en/stable/command_line.html)) using `with`, e.g.:
```shell
python ex_audioset.py with trainer.precision=16
```
This will train on Audioset using 16-bit precision.
The overall configurations look like this:
```yaml
...
seed = 542198583  # the random seed for this experiment
slurm_job_id = ''
speed_test_batch_size = 100
swa = True
swa_epoch_start = 50
swa_freq = 5
use_mixup = True
warm_up_len = 5
weight_decay = 0.0001
basedataset:
  base_dir = 'audioset_hdf5s/'  # base directory of the dataset, change it or make a link
  eval_hdf5 = 'audioset_hdf5s/mp3/eval_segments_mp3.hdf'
  wavmix = 1
  ....
  roll_conf:
    axis = 1
    shift = None
    shift_range = 50
datasets:
  test:
    batch_size = 20
    dataset = {CMD!}'/basedataset.get_test_set'
    num_workers = 16
    validate = True
  training:
    batch_size = 12
    dataset = {CMD!}'/basedataset.get_full_training_set'
    num_workers = 16
    sampler = {CMD!}'/basedataset.get_ft_weighted_sampler'
    shuffle = None
    train = True
models:
  mel:
    freqm = 48
    timem = 192
    hopsize = 320
    htk = False
    n_fft = 1024
    n_mels = 128
    norm = 1
    sr = 32000
    ...
  net:
    arch = 'passt_s_swa_p16_128_ap476'
    fstride = 10
    in_channels = 1
    input_fdim = 128
    input_tdim = 998
    n_classes = 527
    s_patchout_f = 4
    s_patchout_t = 40
    tstride = 10
    u_patchout = 0
    ...
trainer:
  accelerator = None
  accumulate_grad_batches = 1
  amp_backend = 'native'
  amp_level = 'O2'
  auto_lr_find = False
  auto_scale_batch_size = False
  ...
```
There are many things that can be updated from the command line.
In short:
- All the configuration options under `trainer` are pytorch lightning trainer [api](https://pytorch-lightning.readthedocs.io/en/1.4.1/common/trainer.html#trainer-class-api). For example, to turn off cuda benchmarking add `trainer.benchmark=False` to the command line.
- `wandb` is the wandb configuration. For example, to change the wandb project, add `wandb.project="test_project"` to the command line.
- `models.net` are the PaSST (or the chosen NN) options. Examples: `models.net.u_patchout`, `models.net.s_patchout_f` `models.net.s_patchout_t` control the unstructured patchout and structured patchout over frequency and time. `input_fdim` and `input_tdim` are the input spectrogram dimensions over frequency and time. `models.net.fstride` and `models.net.tstride` are the strides of the input patches over frequency and time, setting these to 16 means no patch overlap.
- `models.mel` are the preprocessing options (mel spectrograms). `mel.sr` is the sampling rate, `mel.hopsize` is the hop size of the STFT window, `mel.n_mels` is the number of mel bins, `mel.freqm` and `mel.timem` are the frequency and time masking parameters of spec-augment.
There are many pre-defined configuration bundles (called named_configs) in `config_updates.py`. These include different models, setups etc...
You can list these configurations with:
```shell
python ex_audioset.py print_named_configs
```
For example, `passt_s_20sec` is a configuration bundle that sets the model to PaSST-S pre-trained on Audioset, and accepts up to 20 second clips.
## Training on Audioset
Download and prepare the dataset as explained on the [audioset page](audioset/).
The base PaSST model can be trained for example like this:
```bash
python ex_audioset.py with trainer.precision=16 models.net.arch=passt_deit_bd_p16_384 -p
```
For example, using only unstructured patchout with 400 dropped patches:
```bash
python ex_audioset.py with trainer.precision=16 models.net.arch=passt_deit_bd_p16_384 models.net.u_patchout=400 models.net.s_patchout_f=0 models.net.s_patchout_t=0 -p
```
Multi-gpu training can be enabled by setting the environment variable `DDP`, for example with 2 gpus:
```shell
DDP=2 python ex_audioset.py with trainer.precision=16 models.net.arch=passt_deit_bd_p16_384 -p -c "PaSST base 2 GPU"
```
## Examples with Pre-trained models
Please check the [releases page](https://github.com/kkoutini/PaSST/releases/), to download pre-trained models.
In general, you can get a pretrained model on Audioset using
```python
from models.passt import get_model
model = get_model(arch="passt_s_swa_p16_128_ap476", pretrained=True, n_classes=527, in_channels=1,
fstride=10, tstride=10,input_fdim=128, input_tdim=998,
u_patchout=0, s_patchout_t=40, s_patchout_f=4)
```
This will automatically download PaSST pre-trained on Audioset with an mAP of ```0.476```. The model was trained with ```s_patchout_t=40, s_patchout_f=4```, but you can change these to better fit your task/computational needs.
There are several pretrained models available with different strides (overlap) and with/without SWA: `passt_s_p16_s16_128_ap468, passt_s_swa_p16_s16_128_ap473, passt_s_swa_p16_s14_128_ap471, passt_s_p16_s14_128_ap469, passt_s_swa_p16_s12_128_ap473, passt_s_p16_s12_128_ap470`.
For example, in `passt_s_swa_p16_s16_128_ap473`: `p16` means the patch size is `16x16`, `s16` means no overlap (stride=16), `128` is the number of mel bands, and `ap473` refers to this model's performance on Audioset, mAP=0.473.
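This naming scheme can be unpacked mechanically. As an illustration, a hypothetical helper (not part of this repo) might parse the checkpoint names like this:

```python
import re

def parse_arch(name):
    """Parse a PaSST checkpoint name such as 'passt_s_swa_p16_s16_128_ap473'.

    Hypothetical helper: returns the patch size, stride (None when not
    encoded in the name), mel bands, SWA flag, and the Audioset mAP the
    suffix encodes.
    """
    m = re.match(
        r"passt_\w+?_(swa_)?p(?P<patch>\d+)_(?:s(?P<stride>\d+)_)?(?P<mels>\d+)_ap(?P<ap>\d+)",
        name,
    )
    if m is None:
        raise ValueError(f"unrecognized arch name: {name}")
    return {
        "swa": m.group(1) is not None,
        "patch_size": int(m.group("patch")),
        "stride": int(m.group("stride")) if m.group("stride") else None,
        "n_mels": int(m.group("mels")),
        "map": int(m.group("ap")) / 1000,
    }

print(parse_arch("passt_s_swa_p16_s16_128_ap473"))
# {'swa': True, 'patch_size': 16, 'stride': 16, 'n_mels': 128, 'map': 0.473}
```

Note that some names (e.g. `passt_s_swa_p16_128_ap476`) do not encode the stride at all; the parser returns `None` in that case rather than guessing.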
In general, you can get a pretrained model using:
```python
from models.passt import get_model
passt = get_model(arch="passt_s_swa_p16_s16_128_ap473", fstride=16, tstride=16)
```
Using the framework, you can evaluate this model using:
```shell
python ex_audioset.py evaluate_only with trainer.precision=16 passt_s_swa_p16_s16_128_ap473 -p
```
Ensembles of these models are provided as well.
A large ensemble giving `mAP=.4956`:
```shell
python ex_audioset.py evaluate_only with trainer.precision=16 ensemble_many
```
An ensemble of 2 models with `stride=14` and `stride=16` giving `mAP=.4858`:
```shell
python ex_audioset.py evaluate_only with trainer.precision=16 ensemble_s16_14
```
As well as other ensembles: `ensemble_4` and `ensemble_5`.
## Examples of fine-tuning on downstream datasets
1. [ESC-50: Dataset for Environmental Sound Classification](esc50/)
2. [OpenMIC-2018 dataset](openmic/)
3. [FSD50K](fsd50k/)
## Citation
The citation for the paper accepted at Interspeech 2022:
```bib
@inproceedings{koutini22passt,
author = {Khaled Koutini and
Jan Schl{\"{u}}ter and
Hamid Eghbal{-}zadeh and
Gerhard Widmer},
title = {Efficient Training of Audio Transformers with Patchout},
booktitle = {Interspeech 2022, 23rd Annual Conference of the International Speech
Communication Association, Incheon, Korea, 18-22 September 2022},
pages = {2753--2757},
publisher = {{ISCA}},
year = {2022},
url = {https://doi.org/10.21437/Interspeech.2022-227},
doi = {10.21437/Interspeech.2022-227},
}
```
## Contact
The repo will be updated; in the meantime, if you have any questions or problems, feel free to open an issue on GitHub or contact the authors directly.
================================================
FILE: __init__.py
================================================
================================================
FILE: audioset/README.md
================================================
# Experiments on Audioset
Audioset has around 2M segments. The total size of the dataset as wav files with a `32khz` sampling rate is around 1.2 TB. In our setup, this results in a huge IO bottleneck that slows down the training process significantly.
Therefore, we encode the dataset to mp3, pack the mp3s into HDF5 files and decode them on the fly. If you have enough cpu cores (10-16 dataloading workers) you should not notice any slowdowns.
In `dataset.py` we read the samples from the hdf files, decode the mp3, apply waveform augmentations and return the raw waveform for the model.
`AudioSetDataset` is the main class that reads from the hdf files.
## Preparing the dataset
### Downloading Audioset
We used the scripts provided by [PANNS](https://github.com/qiuqiangkong/audioset_tagging_cnn) to download the dataset.
### Converting to mp3
Once the Datasets are downloaded we convert all the files to mp3 using the script:
`prepare_scripts/convert_to_mp3.py`.
```bash
python convert_to_mp3.py --source pann_download_folder --out mp3_folder
```
This will significantly reduce the size of the dataset and overcome the IO bottleneck in our setup. The trade-off is that more cpu is needed during training to decode the mp3s.
We use the [av](https://pypi.org/project/av/) library (check `decode_mp3` in `dataset.py`) to decode the mp3s in the data loading workers; this is much faster than calling ffmpeg.
As a result, approximately 10 decoding threads should be enough to keep a 2080ti busy.
You can test how much time it takes to load and decode one epoch on your system:
```bash
python ex_audioset.py test_loaders_train_speed
```
This step is not necessary if you have a more powerful setup, and `decode_mp3` also supports other ffmpeg codecs.
### Packing into HDF5 files
Finally, you need to pack the mp3 files into a single HDF5 file using `create_h5pymp3_dataset.py`.
You just need to set the paths in the script to match your local paths. The script goes through the csv files, checks whether the corresponding mp3 file exists, and stores it in an h5py file.
The output of this step should be 3 files: `balanced_train_segments_mp3.hdf`, `eval_segments_mp3.hdf` and `unbalanced_train_segments_mp3.hdf`.
Make sure the paths match the default config in `dataset.py`.
================================================
FILE: audioset/dataset.py
================================================
import io
import os
import pathlib
import random
import av
import librosa
import torchaudio
from torch.utils.data import Dataset as TorchDataset, ConcatDataset, DistributedSampler, WeightedRandomSampler
import torch
from ba3l.ingredients.datasets import Dataset
from sacred.config import DynamicIngredient, CMD
from scipy.signal import convolve
import numpy as np
from helpers.audiodatasets import PreprocessDataset
import h5py
LMODE = os.environ.get("LMODE", False)
#$TMPDIR
dataset = Dataset('audiodataset')
@dataset.config
def default_config():
    name = 'audioset'  # dataset name
    normalize = False  # normalize dataset
    subsample = False  # subsample squares from the dataset
    roll = True  # apply roll augmentation
    fold = 1
    base_dir = "audioset_hdf5s/"  # base directory of the dataset, change it or make a link
    if LMODE:
        base_dir = "/system/user/publicdata/CP/audioset/audioset_hdf5s/"
    balanced_train_hdf5 = base_dir + "mp3/balanced_train_segments_mp3.hdf"
    eval_hdf5 = base_dir + "mp3/eval_segments_mp3.hdf"
    unbalanced_train_hdf5 = base_dir + "mp3/unbalanced_train_segments_mp3.hdf"
    if LMODE:
        balanced_train_hdf5 = balanced_train_hdf5.replace(base_dir, os.environ.get("TMPDIR", base_dir) + "/")
        unbalanced_train_hdf5 = unbalanced_train_hdf5.replace(base_dir, os.environ.get("TMPDIR", base_dir) + "/")
        eval_hdf5 = eval_hdf5.replace(base_dir, os.environ.get("TMPDIR", base_dir) + "/")
    ir_path = base_dir + "irs/"
    num_of_classes = 527


if LMODE:
    @dataset.config
    def LMODE_default_config():
        cache_root_path = "/system/user/publicdata/CP/DCASE/cached_datasets/"
def decode_mp3(mp3_arr):
"""
Decodes an array of uint8 representing an mp3 file.
:rtype: np.array
"""
container = av.open(io.BytesIO(mp3_arr.tobytes()))
stream = next(s for s in container.streams if s.type == 'audio')
# print(stream)
a = []
for i, packet in enumerate(container.demux(stream)):
for frame in packet.decode():
a.append(frame.to_ndarray().reshape(-1))
waveform = np.concatenate(a)
if waveform.dtype != 'float32':
raise RuntimeError("Unexpected wave type")
return waveform
def pad_or_truncate(x, audio_length):
"""Pad all audio to specific length."""
if len(x) <= audio_length:
return np.concatenate((x, np.zeros(audio_length - len(x), dtype=np.float32)), axis=0)
else:
return x[0: audio_length]
irs_arr = None
@dataset.command
def get_ir_sample(ir_path, _run, ir_augment, cut_irs_offset=None):
if not ir_augment:
return
global irs_arr
if irs_arr is None:
all_paths = [path for path in pathlib.Path(os.path.expanduser(ir_path)).rglob('*.wav')]
all_paths = sorted(all_paths)
if cut_irs_offset is not None:
all_paths = all_paths[cut_irs_offset:cut_irs_offset + 10]
all_paths_name = [str(p).rsplit("/", 1)[-1] for p in all_paths]
print("will use these IRs:")
for i in range(len(all_paths_name)):
print(i, ": ", all_paths_name[i])
_run.info["ir_devices"] = all_paths_name
irs_arr = [librosa.load(p, sr=32000)[0] for p in all_paths]
return irs_arr[int(np.random.randint(0, len(irs_arr)))]
@dataset.command
def pydub_augment(waveform, gain_augment=7, ir_augment=0):
if ir_augment and torch.rand(1) < ir_augment:
ir = get_ir_sample()
waveform = convolve(waveform, ir, 'full')
if gain_augment:
gain = torch.randint(gain_augment * 2, (1,)).item() - gain_augment
amp = 10 ** (gain / 20)
waveform = waveform * amp
return waveform
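The random gain above is drawn in decibels and converted to a linear amplitude factor. A quick sanity check with a hypothetical gain value:

```python
import numpy as np

gain_db = 6                 # e.g. drawn from [-gain_augment, gain_augment)
amp = 10 ** (gain_db / 20)  # dB -> linear amplitude factor
waveform = np.ones(4, dtype=np.float32)
scaled = waveform * amp

assert abs(amp - 1.9953) < 1e-3  # +6 dB roughly doubles the amplitude
```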
class MixupDataset(TorchDataset):
""" Mixing Up wave forms
"""
def __init__(self, dataset, beta=2, rate=0.5):
self.beta = beta
self.rate = rate
self.dataset = dataset
print(f"Mixing up waveforms from dataset of len {len(dataset)}")
def __getitem__(self, index):
if torch.rand(1) < self.rate:
x1, f1, y1 = self.dataset[index]
idx2 = torch.randint(len(self.dataset), (1,)).item()
x2, f2, y2 = self.dataset[idx2]
l = np.random.beta(self.beta, self.beta)
l = max(l, 1. - l)
x1 = x1-x1.mean()
x2 = x2-x2.mean()
x = (x1 * l + x2 * (1. - l))
x = x - x.mean()
return x, f1, (y1 * l + y2 * (1. - l))
return self.dataset[index]
def __len__(self):
return len(self.dataset)
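The mixing performed by `MixupDataset` can be sketched standalone (numpy instead of torch, hypothetical waveforms and targets):

```python
import numpy as np

rng = np.random.default_rng(0)
x1 = rng.normal(size=8).astype(np.float32)  # two hypothetical waveforms
x2 = rng.normal(size=8).astype(np.float32)
y1 = np.array([1.0, 0.0])                   # their (one-hot) targets
y2 = np.array([0.0, 1.0])

l = float(rng.beta(2, 2))
l = max(l, 1.0 - l)                # keep the indexed sample dominant
x = (x1 - x1.mean()) * l + (x2 - x2.mean()) * (1.0 - l)
x = x - x.mean()                   # re-center the mixture
y = y1 * l + y2 * (1.0 - l)        # mix targets with the same weight

assert l >= 0.5
assert abs(float(x.mean())) < 1e-6
```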
class AudioSetDataset(TorchDataset):
def __init__(self, hdf5_file, sample_rate=32000, classes_num=527, clip_length=10, augment=False, in_mem=False):
"""
Reads the mp3 bytes from the HDF5 file, decodes them using av, and returns a fixed-length waveform.
"""
self.sample_rate = sample_rate
self.hdf5_file = hdf5_file
if in_mem:
print("\nPreloading in memory\n")
with open(hdf5_file, 'rb') as f:
self.hdf5_file = io.BytesIO(f.read())
with h5py.File(hdf5_file, 'r') as f:
self.length = len(f['audio_name'])
print(f"Dataset from {hdf5_file} with length {self.length}.")
self.dataset_file = None # lazy init
self.clip_length = clip_length * sample_rate
self.classes_num = classes_num
self.augment = augment
if augment:
print(f"Will augment data from {hdf5_file}")
def open_hdf5(self):
self.dataset_file = h5py.File(self.hdf5_file, 'r')
def __len__(self):
return self.length
def __del__(self):
if self.dataset_file is not None:
self.dataset_file.close()
self.dataset_file = None
def __getitem__(self, index):
"""Load waveform and target of an audio clip.
Args:
meta: {
'hdf5_path': str,
'index_in_hdf5': int}
Returns:
data_dict: {
'audio_name': str,
'waveform': (clip_samples,),
'target': (classes_num,)}
"""
if self.dataset_file is None:
self.open_hdf5()
audio_name = self.dataset_file['audio_name'][index].decode()
waveform = decode_mp3(self.dataset_file['mp3'][index])
if self.augment:
waveform = pydub_augment(waveform)
waveform = pad_or_truncate(waveform, self.clip_length)
waveform = self.resample(waveform)
target = self.dataset_file['target'][index]
target = np.unpackbits(target, axis=-1,
count=self.classes_num).astype(np.float32)
return waveform.reshape(1, -1), audio_name, target
def resample(self, waveform):
"""Resample.
Args:
waveform: (clip_samples,)
Returns:
(resampled_clip_samples,)
"""
if self.sample_rate == 32000:
return waveform
elif self.sample_rate == 16000:
return waveform[0:: 2]
elif self.sample_rate == 8000:
return waveform[0:: 4]
else:
raise Exception('Incorrect sample rate!')
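The fixed-length clipping and the stride-based resampling of `AudioSetDataset` can be checked in isolation (hypothetical clip lengths):

```python
import numpy as np

def pad_or_truncate(x, audio_length):
    # same logic as above: zero-pad short clips, cut long ones
    if len(x) <= audio_length:
        return np.concatenate((x, np.zeros(audio_length - len(x), dtype=np.float32)))
    return x[:audio_length]

clip_length = 10 * 32000                       # 10 s at 32 kHz
short = np.ones(5 * 32000, dtype=np.float32)   # hypothetical 5 s clip
long_ = np.ones(12 * 32000, dtype=np.float32)  # hypothetical 12 s clip

assert len(pad_or_truncate(short, clip_length)) == clip_length
assert len(pad_or_truncate(long_, clip_length)) == clip_length
# stride-2 decimation: 32 kHz -> 16 kHz (no anti-aliasing filter, as in `resample`)
assert len(pad_or_truncate(long_, clip_length)[::2]) == clip_length // 2
```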
@dataset.command
def get_base_training_set(balanced_train_hdf5):
ds = AudioSetDataset(balanced_train_hdf5, augment=True)
return ds
@dataset.command
def get_unbalanced_training_set(unbalanced_train_hdf5):
ds = AudioSetDataset(unbalanced_train_hdf5, augment=True)
return ds
@dataset.command
def get_norms_dataset(unbalanced_train_hdf5, balanced_train_hdf5):
ds = ConcatDataset(
[AudioSetDataset(balanced_train_hdf5, augment=False), AudioSetDataset(unbalanced_train_hdf5, augment=False)])
return ds
@dataset.command
def get_base_full_training_set():
sets = [get_base_training_set(), get_unbalanced_training_set()]
ds = ConcatDataset(sets)
return ds
@dataset.command
def preload_mp3(balanced_train_hdf5, unbalanced_train_hdf5, num_of_classes):
for hdf5_file in [balanced_train_hdf5, unbalanced_train_hdf5]:
print(f"\n \n will now preload {hdf5_file} \n\n ")
with h5py.File(hdf5_file, 'r') as dataset_file:
target = dataset_file['mp3'][:]
print(len(target))
print(f"\n \n done with {hdf5_file} \n\n ")
return target[1000]
@dataset.command
def get_ft_cls_balanced_sample_weights(balanced_train_hdf5, unbalanced_train_hdf5, num_of_classes,
sample_weight_offset=100, sample_weight_sum=True):
"""
:return: float tensor of shape (len(full_training_set),) representing the weight of each sample.
"""
# the order of balanced_train_hdf5,unbalanced_train_hdf5 is important.
# should match get_full_training_set
all_y = []
for hdf5_file in [balanced_train_hdf5, unbalanced_train_hdf5]:
with h5py.File(hdf5_file, 'r') as dataset_file:
target = dataset_file['target']
target = np.unpackbits(target, axis=-1,
count=num_of_classes)
all_y.append(target)
all_y = np.concatenate(all_y, axis=0)
all_y = torch.as_tensor(all_y)
per_class = all_y.long().sum(0).float().reshape(1, -1) # frequencies per class
per_class = sample_weight_offset + per_class # offset low freq classes
if sample_weight_offset > 0:
print(f"Warning: sample_weight_offset={sample_weight_offset} min now={per_class.min()}")
per_class_weights = 1000. / per_class
all_weight = all_y * per_class_weights
# print(all_weight.shape)
# print(all_weight[1510])
if sample_weight_sum:
print("\nsample_weight_sum\n")
all_weight = all_weight.sum(dim=1)
else:
all_weight, _ = all_weight.max(dim=1)
# print(all_weight.shape)
# print(all_weight[1510])
return all_weight
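The weighting scheme above (offset the per-class frequencies, take 1000/frequency as the class weight, then sum over each sample's active classes) can be sketched with a tiny hypothetical multi-hot matrix, using numpy instead of torch:

```python
import numpy as np

# hypothetical targets: class 0 is frequent, classes 1 and 2 are rare
all_y = np.array([[1, 0, 0],
                  [1, 1, 0],
                  [0, 0, 1]], dtype=np.float32)
offset = 1
per_class = all_y.sum(0) + offset             # offset low-frequency classes
per_class_weights = 1000.0 / per_class        # rarer classes get larger weights
weights = (all_y * per_class_weights).sum(1)  # sample weight = sum over active classes

# samples containing rare classes end up with a higher sampling weight
assert weights[0] < weights[2] < weights[1]
```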
@dataset.command
def get_ft_weighted_sampler(samples_weights=CMD(".get_ft_cls_balanced_sample_weights"),
epoch_len=100000, sampler_replace=False):
num_nodes = int(os.environ.get('num_nodes', 1))
ddp = int(os.environ.get('DDP', 1))
num_nodes = max(ddp, num_nodes)
print("num_nodes= ", num_nodes)
rank = int(os.environ.get('NODE_RANK', 0))
return DistributedSamplerWrapper(sampler=WeightedRandomSampler(samples_weights,
num_samples=epoch_len, replacement=sampler_replace),
dataset=range(epoch_len),
num_replicas=num_nodes,
rank=rank,
)
@dataset.command
def get_base_test_set(eval_hdf5):
ds = AudioSetDataset(eval_hdf5)
return ds
@dataset.command(prefix='roll_conf')
def get_roll_func(axis=1, shift=None, shift_range=50):
print("rolling...")
def roll_func(b):
x, i, y = b
x = torch.as_tensor(x)
sf = shift
if shift is None:
sf = int(np.random.randint(-shift_range, shift_range + 1))  # random_integers is deprecated; randint's high bound is exclusive
return x.roll(sf, axis), i, y
return roll_func
@dataset.command
def get_training_set(normalize, roll, wavmix=False):
ds = get_base_training_set()
get_ir_sample()
if normalize:
print("normalized train!")
fill_norms()
ds = PreprocessDataset(ds, norm_func)
if roll:
ds = PreprocessDataset(ds, get_roll_func())
if wavmix:
ds = MixupDataset(ds)
return ds
@dataset.command
def get_full_training_set(normalize, roll, wavmix=False):
ds = get_base_full_training_set()
get_ir_sample()
if normalize:
print("normalized train!")
fill_norms()
ds = PreprocessDataset(ds, norm_func)
if roll:
ds = PreprocessDataset(ds, get_roll_func())
if wavmix:
ds = MixupDataset(ds)
return ds
@dataset.command
def get_test_set(normalize):
ds = get_base_test_set()
if normalize:
print("normalized test!")
fill_norms()
ds = PreprocessDataset(ds, norm_func)
return ds
@dataset.command
def print_conf(_config):
print("Config of ", dataset.path, id(dataset))
print(_config)
print()
class DistributedSamplerWrapper(DistributedSampler):
def __init__(
self, sampler, dataset,
num_replicas=None,
rank=None,
shuffle: bool = True):
super(DistributedSamplerWrapper, self).__init__(
dataset, num_replicas, rank, shuffle)
# source: @awaelchli https://github.com/PyTorchLightning/pytorch-lightning/issues/3238
self.sampler = sampler
def __iter__(self):
if self.sampler.generator is None:
self.sampler.generator = torch.Generator()
self.sampler.generator.manual_seed(self.seed + self.epoch)
indices = list(self.sampler)
if self.epoch == 0:
print(f"\n DistributedSamplerWrapper : {indices[:10]} \n\n")
indices = indices[self.rank:self.total_size:self.num_replicas]
return iter(indices)
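The wrapper above splits one epoch of sampled indices across replicas by rank-striding (`indices[rank:total_size:num_replicas]`); the partition logic can be checked in isolation with a hypothetical epoch:

```python
# hypothetical epoch of indices produced by the weighted sampler
indices = list(range(10))
num_replicas = 2
shards = [indices[rank::num_replicas] for rank in range(num_replicas)]

assert shards[0] == [0, 2, 4, 6, 8]
assert shards[1] == [1, 3, 5, 7, 9]
# every sampled index is consumed exactly once across all replicas
assert sorted(shards[0] + shards[1]) == indices
```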
if __name__ == "__main__":
from sacred import Experiment
ex = Experiment("test_dataset", ingredients=[dataset])
@ex.automain
def default_command():
ex.current_run.get_command_function("print_config")()
get_base_training_set()
ds = get_test_set()
print(ds[0])
ds = get_training_set()
print(ds[0])
print("get_base_training_set", len(get_base_training_set()))
print("get_base_test_set", len(get_base_test_set()))
print("get_training_set", len(get_training_set()))
print("get_test_set", len(get_test_set()))
================================================
FILE: audioset/prepare_scripts/convert_to_mp3.py
================================================
import argparse
import multiprocessing
import glob
import os
# Replace this with the dataset downloaded using PANN scripts
source_path = "/share/cp/datasets/full_audioset/audioset201906/audios/"
# Replace with the output directory
out_path = "./mp3_audio/"
all_num = 0
def process_folder(fol="balanced_train_segments"):
print("now working on ", fol)
os.makedirs(out_path + fol, exist_ok=True)
all_files = list(glob.glob(source_path + fol + "/*.wav"))
print(f"it has {len(all_files)}")
global all_num
all_num = len(all_files)
cmds = [(i, file, out_path + fol + "/" + os.path.basename(file)[:-3]) for i, file in enumerate(all_files)]
print(cmds[0])
with multiprocessing.Pool(processes=20) as pool:
pool.starmap(process_one, cmds)
def process_one(i, f1, f2):
if i % 100 == 0:
print(f"{i}/{all_num} \t", f1)
os.system(f"ffmpeg -hide_banner -nostats -loglevel error -n -i {f1} -codec:a mp3 -ar 32000 {f2}mp3")
if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument('--source', type=str, required=False, default=None,
help='Path of folder containing the wave files, expected format to be as downloaded '
'using the PANNs script, containing balanced_train_segments, eval_segments, '
'unbalanced_train_segments folders.')
parser.add_argument('--out', type=str, required=False, default=None,
help='Directory to save out the converted mp3s.')
args = parser.parse_args()
source_path = args.source or source_path
out_path = args.out or out_path
folders = ['balanced_train_segments', 'eval_segments'] + ["unbalanced_train_segments/" + x for x in sorted(
os.listdir(source_path + "unbalanced_train_segments/"))]
print("I will work on these folders:")
print(folders)
for fol in folders:
process_folder(fol)
================================================
FILE: audioset/prepare_scripts/create_h5pymp3_dataset.py
================================================
# %%
import h5py
import pandas as pd
import numpy as np
import csv
import os
# %%
base_dir = "../../audioset_hdf5s/"
balanced_csv= base_dir+ "metadata/balanced_train_segments.csv"
eval_csv= base_dir+ "metadata/eval_segments.csv"
mp3_path = "../../mp3_audio/"
# %%
def read_metadata(csv_path, classes_num, id_to_ix):
"""Read metadata of AudioSet from a csv file.
source: https://github.com/qiuqiangkong/audioset_tagging_cnn/blob/d2f4b8c18eab44737fcc0de1248ae21eb43f6aa4/utils/utilities.py#L59
Args:
csv_path: str
Returns:
meta_dict: {'audio_name': (audios_num,), 'target': (audios_num, classes_num)}
"""
with open(csv_path, 'r') as fr:
lines = fr.readlines()
lines = lines[3:] # Remove heads
audios_num = len(lines)
targets = np.zeros((audios_num, classes_num), dtype=bool)  # np.bool is removed in recent numpy
audio_names = []
for n, line in enumerate(lines):
items = line.split(', ')
"""items: ['--4gqARaEJE', '0.000', '10.000', '"/m/068hy,/m/07q6cd_,/m/0bt9lr,/m/0jbk"\n']"""
audio_name = 'Y{}.mp3'.format(items[0]) # Audios are started with an extra 'Y' when downloading
label_ids = items[3].split('"')[1].split(',')
audio_names.append(audio_name)
# Target
for id in label_ids:
ix = id_to_ix[id]
targets[n, ix] = 1
meta_dict = {'audio_name': np.array(audio_names), 'target': targets}
return meta_dict
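The per-line parsing in `read_metadata` assumes the AudioSet segments CSV format (`YTID, start, end, "labels"`). On the example line from the docstring:

```python
# example line taken from the docstring above
line = '--4gqARaEJE, 0.000, 10.000, "/m/068hy,/m/07q6cd_,/m/0bt9lr,/m/0jbk"\n'
items = line.split(', ')
audio_name = 'Y{}.mp3'.format(items[0])           # files are prefixed with 'Y'
label_ids = items[3].split('"')[1].split(',')     # labels sit between the quotes

assert audio_name == 'Y--4gqARaEJE.mp3'
assert label_ids == ['/m/068hy', '/m/07q6cd_', '/m/0bt9lr', '/m/0jbk']
```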
# Load label
with open(base_dir+'metadata/class_labels_indices.csv', 'r') as f:
reader = csv.reader(f, delimiter=',')
lines = list(reader)
labels = []
ids = [] # Each label has a unique id such as "/m/068hy"
for i1 in range(1, len(lines)):
id = lines[i1][1]
label = lines[i1][2]
ids.append(id)
labels.append(label)
classes_num = len(labels)
lb_to_ix = {label : i for i, label in enumerate(labels)}
ix_to_lb = {i : label for i, label in enumerate(labels)}
id_to_ix = {id : i for i, id in enumerate(ids)}
ix_to_id = {i : id for i, id in enumerate(ids)}
# %%
def check_available(balanced_csv,balanced_audio_path,prefix=None):
meta_csv = read_metadata(balanced_csv,classes_num,id_to_ix)
audios_num = len(meta_csv['audio_name'])
found=0
notfound=0
available_files=[]
available_targets=[]
if prefix is None:
prefix = os.path.basename(balanced_csv)[:-4]
for n in range(audios_num):
audio_path = meta_csv['audio_name'][n]
#print(balanced_audio_path + f"{prefix}/{audio_path}")
if os.path.isfile(balanced_audio_path + f"{prefix}/{audio_path}" ):
found+=1
available_files.append(meta_csv['audio_name'][n])
available_targets.append(meta_csv['target'][n])
else:
notfound+=1
print(f"Found {found} . not found {notfound}")
return available_files,available_targets
# %%
# %%
# %%
for read_file,prefix in [(balanced_csv,"balanced_train_segments/"), (eval_csv,"eval_segments/"),]:
print("now working on ",read_file,prefix)
#files, y = torch.load(read_file+".pth")
files, y = check_available(read_file, mp3_path)
y = np.packbits(y, axis=-1)
packed_len = y.shape[1]
print(files[0], "classes: ",packed_len, y.dtype)
available_size = len(files)
f = files[0][:-3]+"mp3"
a = np.fromfile(mp3_path+prefix + "/"+f, dtype='uint8')
dt = h5py.vlen_dtype(np.dtype('uint8'))
save_file = prefix.split("/")[0]
with h5py.File(base_dir+ "mp3/" + save_file+"_mp3.hdf", 'w') as hf:
audio_name = hf.create_dataset('audio_name', shape=((available_size,)), dtype='S20')
waveform = hf.create_dataset('mp3', shape=((available_size,)), dtype=dt)
target = hf.create_dataset('target', shape=((available_size, packed_len)), dtype=y.dtype)
for i,file in enumerate(files):
if i%1000==0:
print(f"{i}/{available_size}")
f = file[:-3] + "mp3"
a = np.fromfile(mp3_path + prefix + f, dtype='uint8')
audio_name[i]=f
waveform[i] = a
target[i] = y[i]
print(a.shape)
print("Done!" , prefix)
# %%
print("working on unbalanced...")
all_x,all_y = None, None
for idx in range(41):
print("working on ",idx)
tmp_csv = base_dir+ f"metadata/unbalanced_partial_csvs/unbalanced_train_segments_part{idx:02}.csv"
prefix = f"unbalanced_train_segments/unbalanced_train_segments_part{idx:02}"
x,y = check_available(tmp_csv,mp3_path,prefix=prefix)
x = np.array([f"{idx:02}/"+one for one in x])
y=np.packbits(y, axis=-1)
print("x,y",x.shape, y.shape)
if all_x is None:
all_x = x
all_y = y
else:
all_x = np.concatenate((all_x,x))
all_y = np.concatenate((all_y,y))
print(f"done {idx}! all x,y",all_x.shape, all_y.shape)
print("now working on packing unbalanced")
prefix = "unbalanced_train_segments/unbalanced_train_segments_part"
files = all_x
y = all_y
packed_len = y.shape[1]
print(files[0], "classes: ",packed_len, y.dtype)
available_size = len(files)
f = files[0][:-3]+"mp3"
a = np.fromfile(mp3_path+prefix + f, dtype='uint8')
dt = h5py.vlen_dtype(np.dtype('uint8'))
save_file = prefix.split("/")[0]
with h5py.File(base_dir+ "mp3/" + save_file+"_mp3.hdf", 'w') as hf:
audio_name = hf.create_dataset('audio_name', shape=((available_size,)), dtype='S20')
waveform = hf.create_dataset('mp3', shape=((available_size,)), dtype=dt)
target = hf.create_dataset('target', shape=((available_size, packed_len)), dtype=y.dtype)
for i,file in enumerate(files):
if i%1000==0:
print(f"{i}/{available_size}")
f = file[:-3] + "mp3"
a = np.fromfile(mp3_path + prefix + f, dtype='uint8')
audio_name[i]=f
waveform[i] = a
target[i] = y[i]
print(a.shape)
print("Done!" , prefix)
================================================
FILE: ba3l/__init__.py
================================================
"""Package info"""
__version__ = "0.0.2"
__author__ = "Koutini et al."
__license__ = "Apache-2.0"
__copyright__ = "Copyright (c) 2019-, %s." % __author__
__homepage__ = "https://github.com//"
__docs__ = (
"The researcher-friendly pytorch environment. Ba3l = sacred + pytorch-lightning."
)
__author_email__ = "first.last@jku.at"
================================================
FILE: ba3l/experiment.py
================================================
import inspect
from importlib import import_module
from ba3l.ingredients.datasets import Datasets
from ba3l.ingredients.models import Models, Model
#from ba3l.trainer import Trainer
from ba3l.util.sacred_logger import SacredLogger
from sacred import Experiment as Sacred_Experiment, Ingredient
from typing import Sequence, Optional, List
from sacred.commandline_options import CLIOption
from sacred.config import CMD
from sacred.host_info import HostInfoGetter
from sacred.utils import PathType, optional_kwargs_decorator
from pytorch_lightning import loggers as pl_loggers
from ba3l.util.functions import get_default_kwargs_dict
def ingredients_recursive_apply(ing, fn):
fn(ing)
for kid in ing.ingredients:
ingredients_recursive_apply(kid, fn)
def config_recursive_apply(conf, fn):
for k,v in conf.items():
if isinstance(v, dict):
config_recursive_apply(v,fn)
else:
fn(k,v)
def get_loggers(experiment, use_tensorboard_logger=False, use_sacred_logger=False):
# pass the experiment explicitly; `expr` and `sacred_logger` were undefined globals here
loggers = []
if use_sacred_logger:
loggers.append(SacredLogger(experiment))
if use_tensorboard_logger:
loggers.append(pl_loggers.TensorBoardLogger(experiment.name))
return loggers
class Experiment(Sacred_Experiment):
"""
Main Ba3l Experiment class; extends the sacred Experiment class.
"""
def __init__(
self,
name: Optional[str] = None,
ingredients: Sequence[Ingredient] = (),
datasets: Optional[Ingredient] = None,
models: Optional[Ingredient] = None,
interactive: bool = False,
base_dir: Optional[PathType] = None,
additional_host_info: Optional[List[HostInfoGetter]] = None,
additional_cli_options: Optional[Sequence[CLIOption]] = None,
save_git_info: bool = True,
):
"""
Create a new experiment with the given name and optional ingredients. (from Sacred)
Parameters
----------
name
Optional name of this experiment, defaults to the filename.
(Required in interactive mode)
ingredients : list[sacred.Ingredient], optional
A list of ingredients to be used with this experiment.
interactive
If set to True will allow the experiment to be run in interactive
mode (e.g. IPython or Jupyter notebooks).
However, this mode is discouraged since it won't allow storing the
source-code or reliable reproduction of the runs.
base_dir
Optional full path to the base directory of this experiment. This
will set the scope for automatic source file discovery.
additional_host_info
Optional dictionary containing as keys the names of the pieces of
host info you want to collect, and as
values the functions collecting those pieces of information.
save_git_info:
Optionally save the git commit hash and the git state
(clean or dirty) for all source files. This requires the GitPython
package.
"""
if models is None:
models = Models.get_instance()
self.models = models
if datasets is None:
datasets = Datasets.get_instance()
self.datasets = datasets
if ingredients is None:
ingredients = []
ingredients = list(ingredients) + [models, datasets]
caller_globals = inspect.stack()[1][0].f_globals
self.last_default_configuration_position = 0
super().__init__(
name=name,
ingredients=ingredients,
interactive=interactive,
base_dir=base_dir,
additional_host_info=additional_host_info,
additional_cli_options=additional_cli_options,
save_git_info=save_git_info,
caller_globals=caller_globals
)
def get_run_identifier(self):
return str(self.current_run.db_identifier) \
+ "_" + str(self.current_run._id)
def get_dataloaders(self, filter={}):
results = {}
for ds in self.datasets.get_datasets(filter):
results[ds.name] = ds.get_iterator()
if len(results) == 1:
for k, v in results.items():
return v
return results
def get_train_dataloaders(self):
return self.get_dataloaders(dict(train=True))
def get_val_dataloaders(self):
return self.get_dataloaders(dict(validate=True))
def _create_run(
self,
command_name=None,
config_updates=None,
named_configs=(),
info=None,
meta_info=None,
options=None,
dry_run=False,
):
if self.current_run is not None:
# @todo replace with logger
print("Warning: multiple runs are not yet supported")
run = super()._create_run(
command_name,
config_updates,
named_configs,
info,
meta_info,
options,
dry_run=False,
)
# self.current_run=run
# def update_current_run(ing):
# ing.current_run = run
#
# ingredients_recursive_apply(self, update_current_run)
return run
@optional_kwargs_decorator
def command(
self, function=None, prefix=None, unobserved=False, add_default_args_config=True, static_args={},
**extra_args
):
"""
Decorator to define a new Command.
a command is a function whose parameters are filled automatically by sacred.
The command can be given a prefix, to restrict its configuration space
to a subtree. (see ``capture`` for more information)
A command can be made unobserved (i.e. ignoring all observers) by
passing the unobserved=True keyword argument.
:param add_default_args_config: whether to add the default arguments of the function to the config automatically.
:param function: the function to return a Dataset Object
:param prefix: sacred configuration prefix
:param unobserved: sacred unobserved
:param static_args: static args to be passed to the function; these args need not be serializable and
are not stored in the config
:param extra_args: explicit arguments to be added to the config; you can use these to override the function's
default values, e.g. by wrapping a value with CMD, in which case the parameter is filled by executing
the command specified by the CMD string
:return:
"""
add_default_args_config = (not unobserved) and add_default_args_config
if add_default_args_config:
self.add_default_args_config(function, prefix, extra_args, static_args=static_args)
captured_f = self.capture(function, prefix=prefix, static_args=static_args)
captured_f.unobserved = unobserved
self.commands[function.__name__] = captured_f
return captured_f
def add_default_args_config(self, function, prefix, extra_args={}, static_args={}):
"""
Adds the default parameters of a function to the ingredient config at the lowest priority.
The default-args config removes the need to declare all the configuration values manually.
:param function: the function
"""
# @todo get the doc of the params as well
config_candidate = {**get_default_kwargs_dict(function), **extra_args}
# remove "static_args" from config
for k in static_args:
config_candidate.pop(k, None)
# respect the prefix for the added default parameters
if prefix is not None:
for pr in prefix.split('.')[::-1]:
config_candidate={pr: config_candidate}
self.configurations.insert(self.last_default_configuration_position, self._create_config_dict(config_candidate, None))
self.last_default_configuration_position += 1
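The prefix handling in `add_default_args_config` nests the default-args dict one level per dot, innermost key first. For a hypothetical prefix `"trainer.opt"`:

```python
# hypothetical default args and prefix
config_candidate = {"lr": 0.001}
prefix = "trainer.opt"
for pr in prefix.split('.')[::-1]:
    config_candidate = {pr: config_candidate}

assert config_candidate == {"trainer": {"opt": {"lr": 0.001}}}
```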
================================================
FILE: ba3l/ingredients/__init__.py
================================================
================================================
FILE: ba3l/ingredients/datasets.py
================================================
import inspect
import os
from functools import partial
from ba3l.util.functions import get_default_kwargs_dict
from sacred.config import CMD
from .ingredient import Ingredient
from typing import Sequence, Optional, List
from sacred.utils import PathType, optional_kwargs_decorator
from munch import DefaultFactoryMunch, Munch
def raise_(ex):
raise ex
class Dataset(Ingredient):
"""
The class that represents a Dataset of a Ba3l experiment.
"""
DATASET_STRING_PREFIX = "get_dataset"
ITER_STRING_PREFIX = "get_iterator"
def __init__(
self,
name: str,
ingredients: Sequence[Ingredient] = (),
interactive: bool = False,
base_dir: Optional[PathType] = None,
save_git_info: bool = True,
):
"""
Create a new dataset ingredient with the given name and optional ingredients.
Parameters
----------
name
Optional name of this experiment, defaults to the filename.
(Required in interactive mode)
ingredients : list[sacred.Ingredient], optional
A list of ingredients to be used with this experiment.
interactive
If set to True will allow the experiment to be run in interactive
mode (e.g. IPython or Jupyter notebooks).
However, this mode is discouraged since it won't allow storing the
source-code or reliable reproduction of the runs.
base_dir
Optional full path to the base directory of this experiment. This
will set the scope for automatic source file discovery.
additional_host_info
Optional dictionary containing as keys the names of the pieces of
host info you want to collect, and as
values the functions collecting those pieces of information.
save_git_info:
Optionally save the git commit hash and the git state
(clean or dirty) for all source files. This requires the GitPython
package.
"""
caller_globals = inspect.stack()[1][0].f_globals
if name is None:
name = "dataset"
self.name = name.rsplit(".", 1)[-1]
super().__init__(
path=name,
ingredients=ingredients,
interactive=interactive,
base_dir=base_dir,
_caller_globals=caller_globals,
save_git_info=save_git_info,
)
self.get_dataset_command = None
self.get_dataset_iterator_command = None
self.current_run = None
self.get_dataset = lambda: raise_(
NotImplementedError(
"Use dataset.dataset_name.dataset to annotate the "
"get_dataset function!."
)
)
self.get_iter = lambda: raise_(
NotImplementedError(
"Use dataset.dataset_name.iter to annotate the " "get_iter function!."
)
)
@optional_kwargs_decorator
def dataset(
self, function=None, prefix=None, unobserved=False, static_args={}, **extra_args
):
"""
Decorator to define a new Dataset.
The name of the dataset is used to get an instance of the dataset; registering it creates a command.
Datasets are sacred commands.
The command can be given a prefix, to restrict its configuration space
to a subtree. (see ``capture`` for more information)
A command can be made unobserved (i.e. ignoring all observers) by
passing the unobserved=True keyword argument.
:param function: the function to return a Dataset Object
:param prefix: sacred configuration prefix
:param unobserved: sacred unobserved
:param static_args: static args to be passed to the function; these args need not be serializable and
are not stored in the config
:param extra_args: explicit arguments to be added to the config; you can use these to override the function's
default values, e.g. by wrapping a value with CMD, in which case the parameter is filled by executing
the command specified by the CMD string
:return:
"""
self.add_default_args_config(function, prefix, extra_args, static_args=static_args)
captured_f = self.capture(function, prefix=prefix, static_args=static_args)
captured_f.unobserved = unobserved
self.commands[Dataset.DATASET_STRING_PREFIX] = captured_f
self.get_dataset = captured_f
self.add_config(dataset=CMD("get_dataset"))
return captured_f
@optional_kwargs_decorator
def iter(
self, function=None, prefix=None, unobserved=False, static_args={}, **extra_args
):
"""
Decorator to define a new Iterator.
The name of the iterator is used to get an instance of the iterator; registering it creates a command.
Iterators are sacred commands.
The command can be given a prefix, to restrict its configuration space
to a subtree. (see ``capture`` for more information)
A command can be made unobserved (i.e. ignoring all observers) by
passing the unobserved=True keyword argument.
:param function: the function to return a Dataset Object
:param prefix: sacred configuration prefix
:param unobserved: sacred unobserved
:param static_args: static args to be passed to the function; these args need not be serializable and
are not stored in the config
:param extra_args: explicit arguments to be added to the config; you can use these to override the function's
default values, e.g. by wrapping a value with CMD, in which case the parameter is filled by executing
the command specified by the CMD string
"""
self.add_default_args_config(function, prefix, extra_args, static_args=static_args)
captured_f = self.capture(function, prefix=prefix, static_args=static_args)
captured_f.unobserved = unobserved
self.commands[Dataset.ITER_STRING_PREFIX] = captured_f
self.get_iter = captured_f
return captured_f
# def get_dataset(self):
# assert self.current_run is not None, "Can only be called during a run."
# return self.commands[Datasets.DATASET_STRING_PREFIX + name]()
# # return self.current_run.get_command_function(
# # self.path + "." + Datasets.DATASET_STRING_PREFIX + name)()
# #
def __getattr__(self, k):
if k == "iterator":
return self.__getattribute__("iter")
if k == "get_iterator":
return self.__getattribute__("get_iter")
return super().__getattribute__(k)
# @todo maybe run commands from here after running
class Datasets(Ingredient, Munch):
"""
The class that encapsulates all the datasets in an experiment
"""
__instance = None
@classmethod
def get_instance(cls):
if Datasets.__instance is None:
Datasets.__instance = Datasets()
return Datasets.__instance
def __init__(
self,
name: Optional[str] = None,
ingredients: Sequence[Ingredient] = (),
interactive: bool = False,
base_dir: Optional[PathType] = None,
save_git_info: bool = True,
):
"""
Create a new datasets collection with the given name and optional ingredients.
Parameters
----------
name
Optional name of this experiment, defaults to the filename.
(Required in interactive mode)
ingredients : list[sacred.Ingredient], optional
A list of ingredients to be used with this experiment.
interactive
If set to True will allow the experiment to be run in interactive
mode (e.g. IPython or Jupyter notebooks).
However, this mode is discouraged since it won't allow storing the
source-code or reliable reproduction of the runs.
base_dir
Optional full path to the base directory of this experiment. This
will set the scope for automatic source file discovery.
additional_host_info
Optional dictionary containing as keys the names of the pieces of
host info you want to collect, and as
values the functions collecting those pieces of information.
save_git_info:
Optionally save the git commit hash and the git state
(clean or dirty) for all source files. This requires the GitPython
package.
"""
caller_globals = inspect.stack()[1][0].f_globals
if name is None:
name = "datasets"
Ingredient.__init__(
self,
path=name,
ingredients=ingredients,
interactive=interactive,
base_dir=base_dir,
_caller_globals=caller_globals,
save_git_info=save_git_info,
)
self.get_datasets_list_command = None
self.get_dataset_command = None
self.get_dataset_iterator_command = None
self.current_run = None
self.get_dataset = None
# self.command(get_dataset_iterator_command, unobserved=True)
def __getattr__(self, k):
""" Gets key if it exists, otherwise returns the default value."""
try:
return Munch.__getattr__(self, k)
except AttributeError:
return self.__getitem__(k)
def __setattr__(self, k, v):
try:
# Throws exception if not in prototype chain
object.__getattribute__(self, k)
except AttributeError:
try:
self[k] = v
except:
raise AttributeError(k)
else:
object.__setattr__(self, k, v)
def __getitem__(self, k):
""" Gets key if it exists, otherwise returns the default value."""
try:
return Munch.__getitem__(self, k)
except KeyError:
self[k] = Dataset(
self.path + "." + k,
base_dir=self.base_dir,
save_git_info=self.save_git_info,
)
assert self
self.ingredients.append(self[k])
return self[k]
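The `__getitem__` above autovivifies: a missing key creates a new `Dataset` ingredient under a dotted path and registers it in `self.ingredients`. The pattern can be sketched without sacred as a plain `dict` subclass (`AutoRegistry` is an illustrative name, not part of the codebase):

```python
class AutoRegistry(dict):
    """Minimal sketch of the Datasets auto-creation pattern: a missing key
    is created on first access and recorded in a registry list, mirroring
    self.ingredients.append(self[k]) above."""

    def __init__(self, path):
        super().__init__()
        self.path = path
        self.registered = []  # stands in for self.ingredients

    def __missing__(self, k):
        # The child gets a dotted path under the parent,
        # like Dataset(self.path + "." + k) in the real class.
        child = AutoRegistry(self.path + "." + k)
        self[k] = child
        self.registered.append(child)
        return child

datasets = AutoRegistry("datasets")
train = datasets["training"]          # created on first access
assert train.path == "datasets.training"
assert datasets["training"] is train  # second access returns the same object
assert datasets.registered == [train]
```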
def __hash__(self):
return Ingredient.__hash__(self)
def get_datasets(self, config_conditions={}, return_datasets_names=False):
""" Return all the datasets whose configuration matches config_conditions."""
results = []
for dataset in self.ingredients:
all_ok = True
for cond_k, cond_v in config_conditions.items():
if (
self.current_run.get_config_path_value(dataset.path + "." + cond_k)
!= cond_v
):
all_ok = False
break
if all_ok:
if return_datasets_names:
results.append((dataset.path, dataset))
else:
results.append(dataset)
return results
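The matching loop in `get_datasets` keeps an ingredient only if every `(key, value)` condition equals the configured value under that ingredient's path. A minimal stand-in for that filter (names here are illustrative; the real lookup goes through `current_run.get_config_path_value`):

```python
def filter_by_config(items, config, conditions):
    """Keep items whose config entries match every (key, value) condition.

    `config` maps dotted paths ("<item name>.<key>") to values, standing in
    for current_run.get_config_path_value in the Datasets class above.
    """
    results = []
    for name, item in items:
        if all(config.get(name + "." + k) == v for k, v in conditions.items()):
            results.append(item)
    return results

config = {"audioset.split": "train", "esc50.split": "test"}
items = [("audioset", "AudioSetDS"), ("esc50", "ESC50DS")]
assert filter_by_config(items, config, {"split": "train"}) == ["AudioSetDS"]
assert filter_by_config(items, config, {}) == ["AudioSetDS", "ESC50DS"]
```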
================================================
FILE: ba3l/ingredients/ingredient.py
================================================
import inspect
import os
from functools import partial
from ba3l.util.functions import get_default_kwargs_dict
from sacred import Ingredient as sacred_Ingredient
from typing import Sequence, Optional, List
from sacred.utils import PathType, optional_kwargs_decorator
from munch import DefaultFactoryMunch, Munch
def raise_(ex):
raise ex
class Ingredient(sacred_Ingredient):
"""
The base class of all Ba3l ingredients (Datasets, Models, Trainer).
"""
def __init__(
self,
path: str,
ingredients: Sequence[sacred_Ingredient] = (),
interactive: bool = False,
_caller_globals: Optional[dict] = None,
base_dir: Optional[PathType] = None,
save_git_info: bool = True,
):
"""
The Base Ingredient of all Ba3l ingredients
Parameters
----------
path
The dotted configuration path of this ingredient.
ingredients : list[sacred.Ingredient], optional
A list of ingredients to be used with this experiment.
interactive
If set to True will allow the experiment to be run in interactive
mode (e.g. IPython or Jupyter notebooks).
However, this mode is discouraged since it won't allow storing the
source-code or reliable reproduction of the runs.
base_dir
Optional full path to the base directory of this experiment. This
will set the scope for automatic source file discovery.
save_git_info:
Optionally save the git commit hash and the git state
(clean or dirty) for all source files. This requires the GitPython
package.
"""
_caller_globals = _caller_globals or inspect.stack()[1][0].f_globals
if path is None:
path = "Ingredient"
super().__init__(
path=path,
ingredients=ingredients,
interactive=interactive,
base_dir=base_dir,
_caller_globals=_caller_globals,
save_git_info=save_git_info,
)
self.current_run = None
self.last_default_configuration_position = 0
def add_default_args_config(self, function, prefix, extra_args={}, static_args={}):
"""
Adds the default parameters of a function to the ingredient config at the lowest priority.
The default-args config removes the need to declare all configuration values manually.
:param function: the function whose default keyword arguments are collected
"""
# @todo get the doc of the params as well
config_candidate = {**get_default_kwargs_dict(function), **extra_args}
# remove "static_args" from config
for k in static_args:
config_candidate.pop(k, None)
if prefix is not None:
for pr in prefix.split('.')[::-1]:
config_candidate={pr: config_candidate}
self.configurations.insert(self.last_default_configuration_position, self._create_config_dict(config_candidate, None))
self.last_default_configuration_position += 1
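The prefix handling above wraps the collected defaults in one nested dict level per dotted segment, walking the prefix right to left. A standalone sketch of just that step (`nest_under_prefix` is an illustrative name, not part of the codebase):

```python
def nest_under_prefix(config, prefix):
    """Wrap a flat config dict in nested dicts, one level per dotted prefix
    segment, innermost segment first (the right-to-left walk used in
    add_default_args_config above)."""
    if prefix is not None:
        for segment in prefix.split(".")[::-1]:
            config = {segment: config}
    return config

assert nest_under_prefix({"lr": 1e-4}, "trainer.optimizer") == {
    "trainer": {"optimizer": {"lr": 1e-4}}
}
assert nest_under_prefix({"lr": 1e-4}, None) == {"lr": 1e-4}
```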
@optional_kwargs_decorator
def command(
self, function=None, prefix=None, unobserved=False, add_default_args_config=True, static_args={}, **extra_args
):
"""
Decorator to define a new Command.
a command is a function whose parameters are filled automatically by sacred.
The command can be given a prefix, to restrict its configuration space
to a subtree. (see ``capture`` for more information)
A command can be made unobserved (i.e. ignoring all observers) by
passing the unobserved=True keyword argument.
:param function: the function to register as a command
:param prefix: sacred configuration prefix
:param unobserved: sacred unobserved
:param static_args: static args passed directly to the function; these need not be serializable and
are not stored in the config
:param extra_args: explicit arguments to add to the config; you can use these to override the function's
default values. For example, wrapping a value with CMD makes the parameter be filled by executing the
command named by the CMD string. CMD strings have a special context.
:return:
"""
if add_default_args_config:
self.add_default_args_config(function, prefix, extra_args, static_args=static_args)
captured_f = self.capture(function, prefix=prefix, static_args=static_args)
captured_f.unobserved = unobserved
self.commands[function.__name__] = captured_f
return captured_f
================================================
FILE: ba3l/ingredients/models.py
================================================
import inspect
import os
from functools import partial
from ba3l.util.functions import get_default_kwargs_dict
from sacred.config import CMD
from .ingredient import Ingredient
from typing import Sequence, Optional, List
from sacred.utils import PathType, optional_kwargs_decorator
from munch import DefaultFactoryMunch, Munch
class Models(Ingredient):
"""
The class that annotates the models of Ba3l experiment
"""
__instance = None
@classmethod
def get_instance(cls):
if Models.__instance is None:
Models.__instance = Models()
return Models.__instance
def __init__(
self,
name: Optional[str] = None,
ingredients: Sequence[Ingredient] = (),
interactive: bool = False,
base_dir: Optional[PathType] = None,
save_git_info: bool = True,
):
"""
Create a new models ingredient with the given name and optional sub-ingredients.
Parameters
----------
name
Optional name of this ingredient; defaults to "models".
ingredients : list[sacred.Ingredient], optional
A list of ingredients to be used with this experiment.
interactive
If set to True will allow the experiment to be run in interactive
mode (e.g. IPython or Jupyter notebooks).
However, this mode is discouraged since it won't allow storing the
source-code or reliable reproduction of the runs.
base_dir
Optional full path to the base directory of this experiment. This
will set the scope for automatic source file discovery.
save_git_info:
Optionally save the git commit hash and the git state
(clean or dirty) for all source files. This requires the GitPython
package.
"""
caller_globals = inspect.stack()[1][0].f_globals
if name is None:
name = "models"
super().__init__(
path=name,
ingredients=ingredients,
interactive=interactive,
base_dir=base_dir,
_caller_globals=caller_globals,
save_git_info=save_git_info,
)
self.get_models_command = None
self.current_run = None
# self.command(print_config, unobserved=True)
def raise_(ex):
raise ex
class Model(Ingredient):
"""
The class that annotates a Model of a Ba3l experiment.
"""
MODEL_STRING_PREFIX = "get_instance"
def __init__(
self,
name: str,
ingredients: Sequence[Ingredient] = (),
interactive: bool = False,
base_dir: Optional[PathType] = None,
save_git_info: bool = True,
):
"""
Create a new model ingredient with the given name and optional sub-ingredients.
Parameters
----------
name
Name of this model; used as the configuration path.
ingredients : list[sacred.Ingredient], optional
A list of ingredients to be used with this experiment.
interactive
If set to True will allow the experiment to be run in interactive
mode (e.g. IPython or Jupyter notebooks).
However, this mode is discouraged since it won't allow storing the
source-code or reliable reproduction of the runs.
base_dir
Optional full path to the base directory of this experiment. This
will set the scope for automatic source file discovery.
save_git_info:
Optionally save the git commit hash and the git state
(clean or dirty) for all source files. This requires the GitPython
package.
"""
caller_globals = inspect.stack()[1][0].f_globals
if name is None:
name = "model"
self.name = name.rsplit(".", 1)[-1]
super().__init__(
path=name,
ingredients=ingredients,
interactive=interactive,
base_dir=base_dir,
_caller_globals=caller_globals,
save_git_info=save_git_info,
)
self.get_instance_command = None
self.current_run = None
self.get_instance = lambda: raise_(
NotImplementedError(
"Use dataset.dataset_name.dataset to annotate the "
"get_dataset function!."
)
)
@optional_kwargs_decorator
def instance(
self, function=None, prefix=None, unobserved=False, static_args={}, **extra_args
):
"""
Decorator to define a new model.
The name of the model is used to get an instance of the model; this registers a command.
The command can be given a prefix, to restrict its configuration space
to a subtree. (see ``capture`` for more information)
A command can be made unobserved (i.e. ignoring all observers) by
passing the unobserved=True keyword argument.
:param function: the function that returns a model instance
:param prefix: sacred configuration prefix
:param unobserved: sacred unobserved
:param static_args: static args passed directly to the function; these need not be serializable and
are not stored in the config
:param extra_args: explicit arguments to add to the config; you can use these to override the function's
default values. For example, wrapping a value with CMD makes the parameter be filled by executing the
command named by the CMD string. CMD strings have a special context.
:return:
"""
self.add_default_args_config(function, prefix, extra_args, static_args=static_args)
captured_f = self.capture(function, prefix=prefix, static_args=static_args)
captured_f.unobserved = unobserved
self.commands[Model.MODEL_STRING_PREFIX] = captured_f
self.get_instance = captured_f
self.add_config(get_instance=CMD("get_instance"))
return captured_f
def __getattr__(self, k):
if k == "get_instance":
return self.__getattribute__("get_instance")
return super().__getattribute__(k)
# @todo maybe run commands from here after running
================================================
FILE: ba3l/ingredients/trainer.py
================================================
import inspect
import os
from ba3l.util.functions import get_default_kwargs_dict
from .datasets import Datasets, raise_
from .models import Models
from sacred import Ingredient
from typing import Sequence, Optional, List
from sacred.utils import PathType, optional_kwargs_decorator
class Trainer(Ingredient):
"""
The class that annotates the main Trainer of Ba3l experiment
"""
TRAINER_STRING_PREFIX = "get_trainer"
================================================
FILE: ba3l/module.py
================================================
import pytorch_lightning as pl
import warnings
from abc import ABC
import torch.distributed as dist
from munch import DefaultMunch
try:
# loading for pyTorch 1.3
from torch.utils.data import IterableDataset
except ImportError:
# loading for pyTorch 1.1
import torch
warnings.warn(
"Your version of pyTorch %s does not support `IterableDataset`,"
" please upgrade to 1.2+" % torch.__version__,
ImportWarning,
)
EXIST_ITER_DATASET = False
else:
EXIST_ITER_DATASET = True
try:
from apex import amp
APEX_AVAILABLE = True
except ImportError:
APEX_AVAILABLE = False
class Ba3lModule(pl.LightningModule):
def __init__(self, experiment):
super(Ba3lModule, self).__init__()
self.experiment = experiment
self.run = experiment.current_run
self.config = DefaultMunch.fromDict(experiment.current_run.config)
for key,model in experiment.current_run.config['models'].items():
setattr(self, key, experiment.current_run.get_command_function("models."+key+"."+model['instance_cmd'])())
self.save_hyperparameters(self.config)
================================================
FILE: ba3l/plutils/__init__.py
================================================
================================================
FILE: ba3l/plutils/lr_monitor.py
================================================
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
r"""
Learning Rate Monitor
=====================
Monitors and logs the learning rate of lr schedulers during training.
"""
from typing import Dict, List, Optional
from pytorch_lightning.callbacks.base import Callback
from pytorch_lightning.utilities import rank_zero_warn
from pytorch_lightning.utilities.exceptions import MisconfigurationException
class LearningRateMonitor(Callback):
r"""
Automatically monitors and logs the learning rate of learning rate schedulers during training.
Args:
logging_interval: set to ``'epoch'`` or ``'step'`` to log ``lr`` of all optimizers
at the same interval, set to ``None`` to log at individual interval
according to the ``interval`` key of each scheduler. Defaults to ``None``.
log_momentum: option to also log the momentum values of the optimizer, if the optimizer
has the ``momentum`` or ``betas`` attribute. Defaults to ``False``.
Raises:
MisconfigurationException:
If ``logging_interval`` is none of ``"step"``, ``"epoch"``, or ``None``.
Example::
>>> from pytorch_lightning import Trainer
>>> from pytorch_lightning.callbacks import LearningRateMonitor
>>> lr_monitor = LearningRateMonitor(logging_interval='step')
>>> trainer = Trainer(callbacks=[lr_monitor])
Logging names are automatically determined based on optimizer class name.
In case of multiple optimizers of same type, they will be named ``Adam``,
``Adam-1`` etc. If an optimizer has multiple parameter groups they will
be named ``Adam/pg1``, ``Adam/pg2`` etc. To control naming, pass a
``name`` keyword in the construction of the learning rate schedulers.
Example::
def configure_optimizer(self):
optimizer = torch.optim.Adam(...)
lr_scheduler = {
'scheduler': torch.optim.lr_scheduler.LambdaLR(optimizer, ...),
'name': 'my_logging_name'
}
return [optimizer], [lr_scheduler]
"""
def __init__(self, logging_interval: Optional[str] = None, log_momentum: bool = False):
if logging_interval not in (None, 'step', 'epoch'):
raise MisconfigurationException('logging_interval should be `step` or `epoch` or `None`.')
self.logging_interval = logging_interval
self.log_momentum = log_momentum
self.lrs = None
self.lr_sch_names = []
def on_train_start(self, trainer, *args, **kwargs):
"""
Called before training, determines unique names for all lr
schedulers in the case of multiple of the same type or in
the case of multiple parameter groups
Raises:
MisconfigurationException:
If ``Trainer`` has no ``logger``.
"""
if not trainer.logger:
raise MisconfigurationException(
'Cannot use `LearningRateMonitor` callback with `Trainer` that has no logger.'
)
if not trainer.lr_schedulers:
rank_zero_warn(
'You are using `LearningRateMonitor` callback with models that'
' have no learning rate schedulers. Please see documentation'
' for `configure_optimizers` method.', RuntimeWarning
)
if self.log_momentum:
def _check_no_key(key):
return any(key not in sch['scheduler'].optimizer.defaults for sch in trainer.lr_schedulers)
if _check_no_key('momentum') and _check_no_key('betas'):
rank_zero_warn(
"You have set log_momentum=True, but some optimizers do not"
" have momentum. This will log a value 0 for the momentum.", RuntimeWarning
)
# Find names for schedulers
names = self._find_names(trainer.lr_schedulers)
# Initialize for storing values
self.lrs = {name: [] for name in names}
self.last_momentum_values = {name + "-momentum": None for name in names}
def on_train_batch_start(self, trainer, *args, **kwargs):
if not self._should_log(trainer):
return
if self.logging_interval != 'epoch':
interval = 'step' if self.logging_interval is None else 'any'
latest_stat = self._extract_stats(trainer, interval)
if latest_stat:
trainer.logger.log_metrics(latest_stat, step=trainer.global_step)
def on_train_epoch_start(self, trainer, *args, **kwargs):
if self.logging_interval != 'step':
interval = 'epoch' if self.logging_interval is None else 'any'
latest_stat = self._extract_stats(trainer, interval)
if latest_stat:
trainer.logger.log_metrics(latest_stat, step=trainer.current_epoch)
def _extract_stats(self, trainer, interval: str) -> Dict[str, float]:
latest_stat = {}
for name, scheduler in zip(self.lr_sch_names, trainer.lr_schedulers):
if scheduler['interval'] == interval or interval == 'any':
opt = scheduler['scheduler'].optimizer
param_groups = opt.param_groups
use_betas = 'betas' in opt.defaults
for i, pg in enumerate(param_groups):
suffix = f'/pg{i + 1}' if len(param_groups) > 1 else ''
lr = self._extract_lr(param_group=pg, name=f'{name}{suffix}')
latest_stat.update(lr)
momentum = self._extract_momentum(
param_group=pg, name=f'{name}-momentum{suffix}', use_betas=use_betas
)
latest_stat.update(momentum)
return latest_stat
def _extract_lr(self, param_group, name: str) -> Dict[str, float]:
lr = param_group.get('lr')
self.lrs[name].append(lr)
return {name: lr}
def _extract_momentum(self, param_group, name: str, use_betas: bool) -> Dict[str, float]:
if not self.log_momentum:
return {}
momentum = param_group.get('betas')[0] if use_betas else param_group.get('momentum', 0)
self.last_momentum_values[name] = momentum
return {name: momentum}
def _find_names(self, lr_schedulers) -> List[str]:
# Create unique names in the case we have multiple of the same learning
# rate scheduler + multiple parameter groups
names = []
for scheduler in lr_schedulers:
sch = scheduler['scheduler']
if scheduler['name'] is not None:
name = scheduler['name']
else:
opt_name = 'lr-' + sch.optimizer.__class__.__name__
i, name = 1, opt_name
# Multiple schedulers of the same type
while True:
if name not in names:
break
i, name = i + 1, f'{opt_name}-{i}'
# Multiple param groups for the same scheduler
param_groups = sch.optimizer.param_groups
if len(param_groups) != 1:
for i, pg in enumerate(param_groups):
temp = f'{name}/pg{i + 1}'
names.append(temp)
else:
names.append(name)
self.lr_sch_names.append(name)
return names
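The de-duplication loop inside `_find_names` can be isolated as a small helper; `unique_name` below is an illustrative stand-in that reproduces the same `-1`, `-2`, ... suffixing:

```python
def unique_name(opt_name, names):
    """Reproduce the de-duplication loop in _find_names above: append
    -1, -2, ... to the base name until the candidate is not taken."""
    i, name = 1, opt_name
    while name in names:
        i, name = i + 1, f"{opt_name}-{i}"
    return name

names = []
for _ in range(3):
    names.append(unique_name("lr-Adam", names))
assert names == ["lr-Adam", "lr-Adam-1", "lr-Adam-2"]
```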
@staticmethod
def _should_log(trainer) -> bool:
should_log = ((trainer.global_step + 1) % trainer.log_every_n_steps == 0 or trainer.should_stop)
return should_log
================================================
FILE: ba3l/plutils/progress_bar.py
================================================
import importlib
import io
import os
import sys
# check if ipywidgets is installed before importing tqdm.auto
# to ensure it won't fail and a progress bar is displayed
from typing import Optional, Union
from pytorch_lightning.callbacks import Callback,ProgressBarBase , ProgressBar as PlProgressBar
from pytorch_lightning.callbacks.progress import tqdm
class ProgressBar(PlProgressBar):
r"""
This is the default progress bar used by Lightning. It prints to `stdout` using the
:mod:`tqdm` package and shows up to four different bars:
- **sanity check progress:** the progress during the sanity check run
- **main progress:** shows training + validation progress combined. It also accounts for
multiple validation runs during training when
:paramref:`~pytorch_lightning.trainer.trainer.Trainer.val_check_interval` is used.
- **validation progress:** only visible during validation;
shows total progress over all validation datasets.
- **test progress:** only active when testing; shows total progress over all test datasets.
For infinite datasets, the progress bar never ends.
If you want to customize the default ``tqdm`` progress bars used by Lightning, you can override
specific methods of the callback class and pass your custom implementation to the
:class:`~pytorch_lightning.trainer.trainer.Trainer`:
Example::
class LitProgressBar(ProgressBar):
def init_validation_tqdm(self):
bar = super().init_validation_tqdm()
bar.set_description('running validation ...')
return bar
bar = LitProgressBar()
trainer = Trainer(callbacks=[bar])
Args:
refresh_rate:
Determines at which rate (in number of batches) the progress bars get updated.
Set it to ``0`` to disable the display. By default, the
:class:`~pytorch_lightning.trainer.trainer.Trainer` uses this implementation of the progress
bar and sets the refresh rate to the value provided to the
:paramref:`~pytorch_lightning.trainer.trainer.Trainer.progress_bar_refresh_rate` argument in the
:class:`~pytorch_lightning.trainer.trainer.Trainer`.
process_position:
Set this to a value greater than ``0`` to offset the progress bars by this many lines.
This is useful when you have progress bars defined elsewhere and want to show all of them
together. This corresponds to
:paramref:`~pytorch_lightning.trainer.trainer.Trainer.process_position` in the
:class:`~pytorch_lightning.trainer.trainer.Trainer`.
"""
def init_validation_tqdm(self) -> tqdm:
""" Override this to customize the tqdm bar for validation. """
# The main progress bar doesn't exist in `trainer.validate()`
has_main_bar = False  # always position the bar as if no main bar exists
bar = tqdm(
desc='Validating',
position=(2 * self.process_position + has_main_bar),
disable=self.is_disabled,
leave=False,
dynamic_ncols=True,
file=sys.stdout
)
return bar
def on_epoch_start(self, trainer, pl_module):
ProgressBarBase.on_epoch_start(self, trainer, pl_module)
self.main_progress_bar = self.init_train_tqdm()
total_train_batches = self.total_train_batches
total_val_batches = self.total_val_batches
if total_train_batches != float('inf'):
# val can be checked multiple times per epoch
val_checks_per_epoch = total_train_batches // trainer.val_check_batch
total_val_batches = total_val_batches * val_checks_per_epoch
total_val_batches = 0  # validation uses its own bar; exclude it from the main bar's total
total_batches = total_train_batches + total_val_batches
reset(self.main_progress_bar, total_batches)
self.main_progress_bar.set_description(f'Epoch {trainer.current_epoch}')
def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx):
ProgressBarBase.on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx)
if self._should_update(self.train_batch_idx, self.total_train_batches):
self._update_bar(self.main_progress_bar)
self.main_progress_bar.set_postfix(trainer.progress_bar_dict)
def on_validation_start(self, trainer, pl_module):
ProgressBarBase.on_validation_start(self, trainer, pl_module)
if trainer.sanity_checking:
reset(self.val_progress_bar, sum(trainer.num_sanity_val_batches))
else:
if self.main_progress_bar is not None:
self.main_progress_bar.close()
self.val_progress_bar = self.init_validation_tqdm()
reset(self.val_progress_bar, self.total_val_batches)
def on_validation_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx):
ProgressBarBase.on_validation_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx)
if self._should_update(self.val_batch_idx, self.total_val_batches):
self._update_bar(self.val_progress_bar)
#self._update_bar(self.main_progress_bar)
def on_validation_end(self, trainer, pl_module):
ProgressBarBase.on_validation_end(self, trainer, pl_module)
if self.main_progress_bar is not None:
self.main_progress_bar.set_postfix(trainer.progress_bar_dict)
self.val_progress_bar.close()
def convert_inf(x: Optional[Union[int, float]]) -> Optional[Union[int, float]]:
""" The tqdm doesn't support inf values. We have to convert it to None. """
if x == float('inf'):
return None
return x
def reset(bar: tqdm, total: Optional[int] = None) -> None:
""" Resets the tqdm bar to 0 progress with a new total, unless it is disabled. """
if not bar.disable:
bar.reset(total=convert_inf(total))
================================================
FILE: ba3l/util/__init__.py
================================================
================================================
FILE: ba3l/util/functions.py
================================================
import inspect
from collections import OrderedDict
def get_default_kwargs_dict(f):
sig = inspect.signature(f)
return OrderedDict(
[
(p.name, p.default)
for p in sig.parameters.values()
if p.default != inspect._empty
]
)
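A quick standalone check of this helper (the function is repeated here, in a slightly more idiomatic form, so the snippet runs on its own):

```python
import inspect
from collections import OrderedDict

def get_default_kwargs_dict(f):
    # Same helper as above: collect the parameters that declare a default.
    sig = inspect.signature(f)
    return OrderedDict(
        (p.name, p.default)
        for p in sig.parameters.values()
        if p.default is not inspect.Parameter.empty
    )

def make_model(arch, fstride=10, tstride=10, pretrained=True):
    pass

# `arch` has no default and is skipped; the rest become config candidates.
assert get_default_kwargs_dict(make_model) == OrderedDict(
    [("fstride", 10), ("tstride", 10), ("pretrained", True)]
)
```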
================================================
FILE: ba3l/util/sacred_logger.py
================================================
from pytorch_lightning.utilities import rank_zero_only
try:
from pytorch_lightning.loggers import Logger as LightningLoggerBase
from pytorch_lightning.loggers.logger import rank_zero_experiment
except ImportError:
from pytorch_lightning.loggers import LightningLoggerBase
from pytorch_lightning.loggers.base import rank_zero_experiment
from warnings import warn
from logging import getLogger
try:
import sacred
except ImportError:
raise ImportError("Missing sacred package. Run `pip install sacred`")
logger = getLogger(__name__)
class SacredLogger(LightningLoggerBase):
def __init__(self, sacred_experiment):
"""Initialize a sacred logger.
:param sacred.experiment.Experiment sacred_experiment: Required. Experiment object with desired observers
already appended.
source: https://github.com/expectopatronum/pytorch-lightning/blob/9fcb238ec03e3f0b0378fd058119f1563a11650c/pytorch_lightning/logging/sacred.py
"""
super().__init__()
self.sacred_experiment = sacred_experiment
self.experiment_name = sacred_experiment.path
self._run_id = None
warn('SacredLogger is deprecated', DeprecationWarning, stacklevel=2)
@property
def experiment(self):
return self.sacred_experiment
@property
def run_id(self):
if self._run_id is not None:
return self._run_id
self._run_id = self.sacred_experiment.get_run_identifier()
return self._run_id
@rank_zero_only
def log_hyperparams(self, params):
# likely not needed, since sacred already records the full config
pass
@rank_zero_only
def log_metrics(self, metrics, step=None):
for k, v in metrics.items():
if isinstance(v, str):
logger.warning(f"Discarding metric with string value {k}={v}")
continue
self.experiment.log_scalar(k, v, step)
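The string-discarding behaviour of `log_metrics` can be checked in isolation; `split_loggable_metrics` below is a sketch of the same guard, not part of sacred's API:

```python
def split_loggable_metrics(metrics):
    """Partition metrics the way log_metrics above does: numeric values
    are kept, string values are dropped (only scalars can be logged)."""
    kept, dropped = {}, []
    for k, v in metrics.items():
        if isinstance(v, str):
            dropped.append(k)
            continue
        kept[k] = v
    return kept, dropped

kept, dropped = split_loggable_metrics({"loss": 0.12, "phase": "val", "epoch": 3})
assert kept == {"loss": 0.12, "epoch": 3}
assert dropped == ["phase"]
```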
@property
def name(self):
return self.experiment_name
@property
def version(self):
return self.run_id
@rank_zero_only
def save(self):
# Optional. Any code necessary to save logger data goes here
# If you implement this, remember to call `super().save()`
# at the start of the method (important for aggregation of metrics)
super().save()
@rank_zero_only
def finalize(self, status):
# Optional. Any code that needs to be run after training
# finishes goes here
pass
================================================
FILE: config_updates.py
================================================
from sacred.config_helpers import DynamicIngredient, CMD
def add_configs(ex):
'''
This function adds generic configurations for the experiments, such as mix-up, architectures, etc.
@param ex: Ba3l Experiment
@return:
'''
@ex.named_config
def nomixup():
'Don\'t apply mix-up (spectrogram level).'
use_mixup = False
mixup_alpha = 0.3
@ex.named_config
def mixup():
' Apply mix-up (spectrogram level).'
use_mixup = True
mixup_alpha = 0.3
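Spectrogram-level mix-up, as toggled by `use_mixup`/`mixup_alpha`, blends two examples and their label vectors with a Beta-sampled weight. A pure-Python sketch of the idea (the project's actual implementation lives in `helpers/mixup.py` and may differ in details; the `max(lam, 1 - lam)` step is an assumption here):

```python
import random

def mixup_pair(x1, y1, x2, y2, alpha=0.3, rng=random):
    """Blend two (spectrogram, label) pairs with lam ~ Beta(alpha, alpha).
    Spectrograms and labels are flat lists of floats for simplicity."""
    lam = rng.betavariate(alpha, alpha)
    lam = max(lam, 1.0 - lam)  # keep the first example dominant (assumed variant)
    x = [lam * a + (1.0 - lam) * b for a, b in zip(x1, x2)]
    y = [lam * a + (1.0 - lam) * b for a, b in zip(y1, y2)]
    return x, y, lam

x, y, lam = mixup_pair([1.0, 0.0], [1.0, 0.0], [0.0, 1.0], [0.0, 1.0])
assert 0.5 <= lam <= 1.0
assert abs(x[0] - lam) < 1e-9          # mixed spectrogram bin
assert abs(y[1] - (1.0 - lam)) < 1e-9  # mixed label mass of the second class
```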
@ex.named_config
def mini_train():
'limit training/validation to 5 batches for debugging.'
trainer = dict(limit_train_batches=5, limit_val_batches=5)
@ex.named_config
def passt():
'use PaSST model'
models = {
"net": DynamicIngredient("models.passt.model_ing")
}
@ex.named_config
def passt_s_20sec():
'use PaSST model pretrained on Audioset (with SWA) ap=474; time encodings for up to 20 seconds'
# python ex_audioset.py evaluate_only with passt_s_20sec
models = {
"net": DynamicIngredient("models.passt.model_ing", arch="passt_s_f128_20sec_p16_s10_ap474", fstride=10,
tstride=10, input_tdim=2000)
}
basedataset = dict(clip_length=20)
@ex.named_config
def passt_s_30sec():
'use PaSST model pretrained on Audioset (with SWA) ap=473; time encodings for up to 30 seconds'
# python ex_audioset.py evaluate_only with passt_s_30sec
models = {
"net": DynamicIngredient("models.passt.model_ing", arch="passt_s_f128_30sec_p16_s10_ap473", fstride=10,
tstride=10, input_tdim=3000)
}
basedataset = dict(clip_length=30)
@ex.named_config
def passt_s_ap476():
'use PaSST model pretrained on Audioset (with SWA) ap=476'
# python ex_audioset.py evaluate_only with passt_s_ap476
models = {
"net": DynamicIngredient("models.passt.model_ing", arch="passt_s_swa_p16_128_ap476", fstride=10,
tstride=10)
}
@ex.named_config
def passt_s_ap4763():
'use PaSST model pretrained on Audioset (with SWA) ap=4763'
# test with: python ex_audioset.py evaluate_only with passt_s_ap4763
models = {
"net": DynamicIngredient("models.passt.model_ing", arch="passt_s_swa_p16_128_ap4763", fstride=10,
tstride=10)
}
@ex.named_config
def passt_s_ap472():
'use PaSST model pretrained on Audioset (no SWA) ap=472'
# test with: python ex_audioset.py evaluate_only with passt_s_ap472
models = {
"net": DynamicIngredient("models.passt.model_ing", arch="passt_s_p16_128_ap472", fstride=10,
tstride=10)
}
@ex.named_config
def passt_s_p16_s16_128_ap468():
'use PaSST model pretrained on Audioset (no SWA) ap=468 NO overlap'
# test with: python ex_audioset.py evaluate_only with passt_s_p16_s16_128_ap468
models = {
"net": DynamicIngredient("models.passt.model_ing", arch="passt_s_p16_s16_128_ap468", fstride=16,
tstride=16)
}
@ex.named_config
def passt_s_swa_p16_s16_128_ap473():
'use PaSST model pretrained on Audioset (SWA) ap=473 NO overlap'
# test with: python ex_audioset.py evaluate_only with passt_s_swa_p16_s16_128_ap473
models = {
"net": DynamicIngredient("models.passt.model_ing", arch="passt_s_swa_p16_s16_128_ap473", fstride=16,
tstride=16)
}
@ex.named_config
def passt_s_swa_p16_s14_128_ap471():
'use PaSST model pretrained on Audioset stride=14 (SWA) ap=471 '
# test with: python ex_audioset.py evaluate_only with passt_s_swa_p16_s14_128_ap471
models = {
"net": DynamicIngredient("models.passt.model_ing", arch="passt_s_swa_p16_s14_128_ap471", fstride=14,
tstride=14)
}
@ex.named_config
def passt_s_p16_s14_128_ap469():
'use PaSST model pretrained on Audioset stride=14 (No SWA) ap=469 '
# test with: python ex_audioset.py evaluate_only with passt_s_p16_s14_128_ap469
models = {
"net": DynamicIngredient("models.passt.model_ing", arch="passt_s_p16_s14_128_ap469", fstride=14,
tstride=14)
}
@ex.named_config
def passt_s_swa_p16_s12_128_ap473():
'use PaSST model pretrained on Audioset stride=12 (SWA) ap=473 '
# test with: python ex_audioset.py evaluate_only with passt_s_swa_p16_s12_128_ap473
models = {
"net": DynamicIngredient("models.passt.model_ing", arch="passt_s_swa_p16_s12_128_ap473", fstride=12,
tstride=12)
}
@ex.named_config
def passt_s_p16_s12_128_ap470():
'use PaSST model pretrained on Audioset stride=12 (No SWA) ap=470 '
# test with: python ex_audioset.py evaluate_only with passt_s_p16_s12_128_ap470
models = {
"net": DynamicIngredient("models.passt.model_ing", arch="passt_s_p16_s12_128_ap470", fstride=12,
tstride=12)
}
@ex.named_config
def ensemble_s10():
'use ensemble of PaSST models pretrained on Audioset with S10 mAP=.4864'
# test with: python ex_audioset.py evaluate_only with trainer.precision=16 ensemble_s10
models = {
"net": DynamicIngredient("models.passt.model_ing", arch="ensemble_s10", fstride=None,
tstride=None, instance_cmd="get_ensemble_model",
# don't call get_model but rather get_ensemble_model
arch_list=[
("passt_s_swa_p16_128_ap476", 10, 10),
("passt_s_swa_p16_128_ap4761", 10, 10),
("passt_s_p16_128_ap472", 10, 10),
]
)
}
@ex.named_config
def ensemble_many():
'use ensemble of PaSST models pretrained on Audioset with different strides mAP=.4956'
# test with: python ex_audioset.py evaluate_only with trainer.precision=16 ensemble_many
models = {
"net": DynamicIngredient("models.passt.model_ing", arch="ensemble_many", fstride=None,
tstride=None, instance_cmd="get_ensemble_model",
# don't call get_model but rather get_ensemble_model
arch_list=[
("passt_s_swa_p16_128_ap476", 10, 10),
("passt_s_swa_p16_128_ap4761", 10, 10),
("passt_s_p16_128_ap472", 10, 10),
("passt_s_p16_s12_128_ap470", 12, 12),
("passt_s_swa_p16_s12_128_ap473", 12, 12),
("passt_s_p16_s14_128_ap469", 14, 14),
("passt_s_swa_p16_s14_128_ap471", 14, 14),
("passt_s_swa_p16_s16_128_ap473", 16, 16),
("passt_s_p16_s16_128_ap468", 16, 16),
]
)
}
@ex.named_config
def ensemble_4():
'use ensemble of PaSST models pretrained on Audioset with different strides mAP=.4926'
# test with: python ex_audioset.py evaluate_only with trainer.precision=16 ensemble_4
models = {
"net": DynamicIngredient("models.passt.model_ing", arch="ensemble_many", fstride=None,
tstride=None, instance_cmd="get_ensemble_model",
# don't call get_model but rather get_ensemble_model
arch_list=[
("passt_s_swa_p16_128_ap476", 10, 10),
("passt_s_swa_p16_s12_128_ap473", 12, 12),
("passt_s_swa_p16_s14_128_ap471", 14, 14),
("passt_s_swa_p16_s16_128_ap473", 16, 16),
]
)
}
@ex.named_config
def ensemble_5():
'use ensemble of PaSST models pretrained on Audioset with different strides mAP=.49459'
# test with: python ex_audioset.py evaluate_only with trainer.precision=16 ensemble_5
models = {
"net": DynamicIngredient("models.passt.model_ing", arch="ensemble_many", fstride=None,
tstride=None, instance_cmd="get_ensemble_model",
# don't call get_model but rather get_ensemble_model
arch_list=[
("passt_s_swa_p16_128_ap476", 10, 10),
("passt_s_swa_p16_128_ap4761", 10, 10),
("passt_s_swa_p16_s12_128_ap473", 12, 12),
("passt_s_swa_p16_s14_128_ap471", 14, 14),
("passt_s_swa_p16_s16_128_ap473", 16, 16),
]
)
}
@ex.named_config
def ensemble_s16_14():
'use ensemble of two PaSST models pretrained on Audioset with stride 16 and 14 mAP=.48579'
# test with: python ex_audioset.py evaluate_only with trainer.precision=16 ensemble_s16_14
models = {
"net": DynamicIngredient("models.passt.model_ing", arch="ensemble_s16", fstride=None,
tstride=None, instance_cmd="get_ensemble_model",
# don't call get_model but rather get_ensemble_model
arch_list=[
("passt_s_swa_p16_s14_128_ap471", 14, 14),
("passt_s_swa_p16_s16_128_ap473", 16, 16),
]
)
}
@ex.named_config
def dynamic_roll():
# dynamically roll the spectrograms/waveforms
# updates the dataset config
basedataset = dict(roll=True, roll_conf=dict(axis=1, shift_range=10000)
)
# extra commands
@ex.command
def test_loaders_train_speed():
# test how fast data is being loaded from the data loaders.
itr = ex.datasets.training.get_iter()
import time
start = time.time()
for i, b in enumerate(itr):
if i % 20 == 0:
print(f"{i}/{len(itr)}", end="\r")
end = time.time()
print("total time:", end - start)
start = time.time()
print("retry:")
for i, b in enumerate(itr):
if i % 20 == 0:
print(f"{i}/{len(itr)}", end="\r")
end = time.time()
print("total time:", end - start)
================================================
FILE: environment.yml
================================================
name: ba3l
channels:
- defaults
dependencies:
- _libgcc_mutex=0.1
- _openmp_mutex=5.1
- ca-certificates=2022.10.11
- certifi=2022.12.7
- ld_impl_linux-64=2.38
- libffi=3.4.2
- libgcc-ng=11.2.0
- libgomp=11.2.0
- libstdcxx-ng=11.2.0
- ncurses=6.3
- openssl=1.1.1s
- pip=22.3.1
- python=3.8.16
- readline=8.2
- setuptools=65.6.3
- sqlite=3.40.1
- tk=8.6.12
- wheel=0.37.1
- xz=5.2.10
- zlib=1.2.13
- pip:
- absl-py==1.4.0
- aiohttp==3.8.4
- aiosignal==1.3.1
- appdirs==1.4.4
- async-timeout==4.0.2
- attrs==22.2.0
- audioread==3.0.0
- av==10.0.0
- cachetools==5.3.0
- cffi==1.15.1
- charset-normalizer==3.0.1
- colorama==0.4.6
- decorator==5.1.1
- docopt==0.6.2
- frozenlist==1.3.3
- fsspec==2023.3.0
- future==0.18.3
- gitdb==4.0.10
- gitpython==3.1.31
- google-auth==2.17.0
- google-auth-oauthlib==0.4.6
- grpcio==1.53.0
- h5py==3.8.0
- idna==3.4
- imageio==2.27.0
- importlib-metadata==6.1.0
- joblib==1.2.0
- jsonpickle==3.0.1
- kk-sacred==0.8.4
- lazy-loader==0.2
- librosa==0.10.0.post2
- llvmlite==0.39.1
- markdown==3.4.3
- markupsafe==2.1.2
- msgpack==1.0.5
- multidict==6.0.4
- munch==2.5.0
- numba==0.56.4
- numpy==1.23.5
- nvidia-cublas-cu11==11.10.3.66
- nvidia-cuda-nvrtc-cu11==11.7.99
- nvidia-cuda-runtime-cu11==11.7.99
- nvidia-cudnn-cu11==8.5.0.96
- oauthlib==3.2.2
- packaging==23.0
- pandas==1.5.3
- pillow==9.4.0
- pooch==1.6.0
- protobuf==4.22.1
- py-cpuinfo==9.0.0
- pyasn1==0.4.8
- pyasn1-modules==0.2.8
- pycparser==2.21
- python-dateutil==2.8.2
- pytorch-lightning==1.3.0.dev0
- pytz==2023.3
- pyyaml==6.0
- requests==2.28.2
- requests-oauthlib==1.3.1
- rsa==4.9
- scikit-learn==1.2.2
- scipy==1.10.1
- six==1.16.0
- smmap==5.0.0
- soundfile==0.12.1
- soxr==0.3.4
- tensorboard==2.12.0
- tensorboard-data-server==0.7.0
- tensorboard-plugin-wit==1.8.1
- test-tube==0.7.5
- threadpoolctl==3.1.0
- timm==0.4.12
- torch==1.13.1
- torchaudio==0.13.1
- torchmetrics==0.2.0
- torchvision==0.14.1
- tqdm==4.65.0
- typing-extensions==4.4.0
- urllib3==1.26.14
- werkzeug==2.2.3
- wrapt==1.15.0
- yarl==1.8.2
- zipp==3.15.0
================================================
FILE: environment_old_2021.yml
================================================
name: ba3l
channels:
- conda-forge
- defaults
dependencies:
- _libgcc_mutex=0.1
- _openmp_mutex=4.5
- _pytorch_select=0.1
- appdirs=1.4.4
- audioread=2.1.9
- blas=1.0
- brotlipy=0.7.0
- bzip2=1.0.8
- c-ares=1.17.1
- ca-certificates=2020.12.5
- cached-property=1.5.2
- cached_property=1.5.2
- certifi=2020.12.5
- cffi=1.14.5
- chardet=4.0.0
- colorama=0.4.4
- cryptography=3.4.6
- cycler=0.10.0
- decorator=4.4.2
- docopt=0.6.2
- ffmpeg=4.3.1
- freetype=2.10.4
- gettext=0.19.8.1
- gitdb=4.0.5
- gitpython=3.1.14
- gmp=6.2.1
- gnutls=3.6.13
- h5py=3.1.0
- hdf5=1.10.6
- idna=2.10
- importlib-metadata=3.7.3
- importlib_metadata=3.7.3
- intel-openmp=2020.2
- joblib=1.0.1
- jpeg=9d
- jsonpickle=1.4.1
- kiwisolver=1.3.1
- krb5=1.17.2
- lame=3.100
- lcms2=2.12
- ld_impl_linux-64=2.35.1
- libblas=3.9.0
- libcblas=3.9.0
- libcurl=7.75.0
- libedit=3.1.20191231
- libev=4.33
- libffi=3.3
- libflac=1.3.3
- libgcc-ng=9.3.0
- libgfortran-ng=9.3.0
- libgfortran5=9.3.0
- libgomp=9.3.0
- liblapack=3.9.0
- libllvm10=10.0.1
- libnghttp2=1.43.0
- libogg=1.3.4
- libopenblas=0.3.12
- libopus=1.3.1
- libpng=1.6.37
- librosa=0.8.0
- libsndfile=1.0.31
- libssh2=1.9.0
- libstdcxx-ng=9.3.0
- libtiff=4.2.0
- libvorbis=1.3.7
- libwebp-base=1.2.0
- llvm-openmp=11.1.0
- llvmlite=0.36.0
- lz4-c=1.9.3
- matplotlib-base=3.3.4
- mkl=2020.2
- mkl-service=2.3.0
- munch=2.5.0
- ncurses=6.2
- nettle=3.6
- ninja=1.10.2
- numba=0.53.0
- numpy=1.20.1
- olefile=0.46
- openblas=0.3.12
- openh264=2.1.1
- openssl=1.1.1k
- packaging=20.9
- pandas=1.2.3
- pillow=8.1.2
- pip=21.0.1
- pooch=1.3.0
- py-cpuinfo=7.0.0
- pycparser=2.20
- pyopenssl=20.0.1
- pyparsing=2.4.7
- pysocks=1.7.1
- pysoundfile=0.10.3.post1
- python=3.7.10
- python-dateutil=2.8.1
- python_abi=3.7
- pytz=2021.1
- readline=8.0
- requests=2.25.1
- resampy=0.2.2
- scikit-learn=0.24.1
- scipy=1.6.1
- setuptools=49.6.0
- six=1.15.0
- smmap=3.0.5
- sqlite=3.34.0
- threadpoolctl=2.1.0
- tk=8.6.10
- tornado=6.1
- typing_extensions=3.7.4.3
- urllib3=1.26.4
- wrapt=1.12.1
- x264=1!161.3030
- xz=5.2.5
- zipp=3.4.1
- zlib=1.2.11
- zstd=1.4.9
- pip:
- absl-py==0.12.0
- aiohttp==3.7.4.post0
- async-timeout==3.0.1
- attrs==20.3.0
- av==8.0.3
- black==20.8b1
- cachetools==4.2.1
- click==7.1.2
- einops==0.3.0
- fsspec==0.8.7
- future==0.18.2
- google-auth==1.28.0
- google-auth-oauthlib==0.4.3
- gpuinfo==1.0.0a7
- grpcio==1.36.1
- imageio==2.9.0
- jedi==0.18.0
- libtmux==0.8.5
- markdown==3.3.4
- multidict==5.1.0
- mypy-extensions==0.4.3
- oauthlib==3.1.0
- parso==0.8.2
- pathspec==0.8.1
- prompt-toolkit==3.0.18
- protobuf==3.15.6
- ptpython==3.0.17
- pyasn1==0.4.8
- pyasn1-modules==0.2.8
- pydub==0.25.1
- pygments==2.8.1
- pymongo==3.11.3
- pyyaml==5.3.1
- regex==2021.4.4
- requests-oauthlib==1.3.0
- rsa==4.7.2
- tensorboard==2.4.1
- tensorboard-plugin-wit==1.8.0
- test-tube==0.7.5
- timm==0.4.12
- toml==0.10.2
- torch==1.8.1+cu111
- torchaudio==0.8.1
- torchmetrics==0.2.0
- torchvision==0.6.0
- tqdm==4.59.0
- typed-ast==1.4.3
- wcwidth==0.2.5
- werkzeug==1.0.1
- wheel==0.36.2
- yarl==1.6.3
================================================
FILE: esc50/README.md
================================================
# Experiments on ESC-50 Environmental Sound Classification
[ESC-50](https://github.com/karolpiczak/ESC-50) consists of 2000 5-second recordings to be classified into 50 semantic classes.
## Setting up the fine-tuning experiments
- Download the preprocessed (resampled) dataset [esc50.zip](https://github.com/kkoutini/PaSST/releases/download/v.0.0.6/esc50.zip) from the [releases page](https://github.com/kkoutini/PaSST/releases/tag/v.0.0.6) and
unpack the zip file into a directory (the default path is `./audioset_hdf5s/`). The `base_dir` config in the dataset file ([here](https://github.com/kkoutini/PaSST/blob/main/esc50/dataset.py#L35)) should point to the extracted contents of the dataset zip file.
- Run the experiments using the common configurations (similar to Audioset):
```shell
python3 ex_esc50.py with models.net.s_patchout_t=10 models.net.s_patchout_f=5 basedataset.fold=1 -p
```
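ESC-50 results are reported as 5-fold cross-validation, so the command above is typically repeated once per held-out fold. A minimal sketch (the `echo` keeps this a dry run; drop it to actually launch the five experiments):

```shell
# run the fine-tuning experiment once per ESC-50 fold (dry run)
for fold in 1 2 3 4 5; do
  echo python3 ex_esc50.py with models.net.s_patchout_t=10 models.net.s_patchout_f=5 basedataset.fold=$fold -p
done
```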
## Pre-trained models
Pre-trained models on ESC-50 can be found [here](https://github.com/kkoutini/PaSST/releases/tag/v.0.0.6).
To use the pre-trained models for fine-tuning or inference with minimal dependencies, refer to [PaSST-HEAR](https://github.com/kkoutini/passt_hear21). For example, after installing passt_hear21:
```python
from hear21passt.base import get_basic_model,get_model_passt
import torch
# model wrapper, includes Melspectrogram and the default transformer
model = get_basic_model(mode="logits")
# replace the transformer with one that outputs 50 classes
model.net = get_model_passt(arch="passt_s_swa_p16_128_ap476", n_classes=50)
# load the pre-trained checkpoint (adjust the path to your downloaded file)
state_dict = torch.load('/home/khaled/esc50-passt-s-n-f128-p16-s10-fold1-acc.967.pt')
# load the weights into the transformer
model.net.load_state_dict(state_dict)
# example inference
model.eval()
model = model.cuda()
with torch.no_grad():
# audio_wave has shape [batch, seconds*32000]; sampling rate is 32 kHz
logits=model(audio_wave)
```
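The `audio_wave` in the snippet above is a batch of raw waveforms sampled at 32 kHz. A minimal sketch of the expected shape, using a silent dummy clip (in practice, load real audio at 32 kHz, e.g. with `librosa.load(path, sr=32000)`):

```python
import numpy as np

SR = 32000       # PaSST models expect 32 kHz input
SECONDS = 5      # ESC-50 clips are 5 seconds long

# dummy silent clip; replace with a real waveform loaded at 32 kHz
wav = np.zeros(SECONDS * SR, dtype=np.float32)
audio_wave = wav[np.newaxis, :]  # add batch dimension -> [batch, seconds*32000]
print(audio_wave.shape)  # (1, 160000)
```

Convert the array to a tensor (`torch.from_numpy(audio_wave)`) before passing it to the model.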
================================================
FILE: esc50/dataset.py
================================================
import io
import os
import pathlib
import random
import av
import librosa
import torchaudio
from torch.utils.data import Dataset as TorchDataset, ConcatDataset, DistributedSampler, WeightedRandomSampler
import torch
from ba3l.ingredients.datasets import Dataset
import pandas as pd
from sacred.config import DynamicIngredient, CMD
from scipy.signal import convolve
from sklearn import preprocessing
from torch.utils.data import Dataset as TorchDataset
import numpy as np
import h5py
from helpers.audiodatasets import PreprocessDataset
LMODE = os.environ.get("LMODE", False)
dataset = Dataset('Esc50')
@dataset.config
def default_config():
name = 'esc50' # dataset name
normalize = False # normalize dataset
subsample = False # subsample squares from the dataset
roll = True # apply roll augmentation
fold = 1
base_dir = "audioset_hdf5s/esc50/" # base directory of the dataset as downloaded
if LMODE:
base_dir = "/system/user/publicdata/CP/audioset/audioset_hdf5s/esc50/"
meta_csv = base_dir + "meta/esc50.csv"
audio_path = base_dir + "audio_32k/"
ir_path = base_dir + "irs/"
num_of_classes = 50
def decode_mp3(mp3_arr):
"""
Decodes an array of uint8 representing an mp3 file.
:rtype: np.array
"""
container = av.open(io.BytesIO(mp3_arr.tobytes()))
stream = next(s for s in container.streams if s.type == 'audio')
# print(stream)
a = []
for i, packet in enumerate(container.demux(stream)):
for frame in packet.decode():
a.append(frame.to_ndarray().reshape(-1))
waveform = np.concatenate(a)
if waveform.dtype != 'float32':
raise RuntimeError("Unexpected wave type")
return waveform
def pad_or_truncate(x, audio_length):
"""Pad or truncate audio to a specific length."""
if len(x) <= audio_length:
return np.concatenate((x, np.zeros(audio_length - len(x), dtype=np.float32)), axis=0)
else:
return x[0: audio_length]
irs_arr = None
@dataset.command
def get_ir_sample(ir_path, _run, ir_augment, cut_irs_offset=None):
if not ir_augment:
return
global irs_arr
if irs_arr is None:
all_paths = [path for path in pathlib.Path(os.path.expanduser(ir_path)).rglob('*.wav')]
all_paths = sorted(all_paths)
if cut_irs_offset is not None:
all_paths = all_paths[cut_irs_offset:cut_irs_offset + 10]
all_paths_name = [str(p).rsplit("/", 1)[-1] for p in all_paths]
print("will use these IRs:")
for i in range(len(all_paths_name)):
print(i, ": ", all_paths_name[i])
_run.info["ir_devices"] = all_paths_name
irs_arr = [librosa.load(p, sr=32000)[0] for p in all_paths]
return irs_arr[int(np.random.randint(0, len(irs_arr)))]
@dataset.command
def pydub_augment(waveform, gain_augment=7, ir_augment=0):
if ir_augment and torch.rand(1) < ir_augment:
ir = get_ir_sample()
waveform = convolve(waveform, ir, 'full')
if gain_augment:
gain = torch.randint(gain_augment * 2, (1,)).item() - gain_augment
amp = 10 ** (gain / 20)
waveform = waveform * amp
return waveform
class MixupDataset(TorchDataset):
""" Mixing up waveforms
"""
def __init__(self, dataset, beta=2, rate=0.5):
self.beta = beta
self.rate = rate
self.dataset = dataset
print(f"Mixing up waveforms from dataset of len {len(dataset)}")
def __getitem__(self, index):
if torch.rand(1) < self.rate:
x1, f1, y1 = self.dataset[index]
idx2 = torch.randint(len(self.dataset), (1,)).item()
x2, f2, y2 = self.dataset[idx2]
l = np.random.beta(self.beta, self.beta)
l = max(l, 1. - l)
x1 = x1 - x1.mean()
x2 = x2 - x2.mean()
x = (x1 * l + x2 * (1. - l))
x = x - x.mean()
return x, f1, (y1 * l + y2 * (1. - l))
return self.dataset[index]
def __len__(self):
return len(self.dataset)
class AudioSetDataset(TorchDataset):
def __init__(self, meta_csv, audiopath, fold, train=False, sample_rate=32000, classes_num=527,
clip_length=5, augment=False):
"""
Reads the mp3 bytes from HDF file decodes using av and returns a fixed length audio wav
"""
self.sample_rate = sample_rate
self.meta_csv = meta_csv
self.df = pd.read_csv(meta_csv)
if train: # training all except this
print(f"Dataset training fold {fold} selection out of {len(self.df)}")
self.df = self.df[self.df.fold != fold]
print(f" for training remains {len(self.df)}")
else:
print(f"Dataset testing fold {fold} selection out of {len(self.df)}")
self.df = self.df[self.df.fold == fold]
print(f" for testing remains {len(self.df)}")
self.clip_length = clip_length * sample_rate
self.sr = sample_rate
self.classes_num = classes_num
self.augment = augment
self.audiopath=audiopath
if augment:
print(f"Will augment data from {meta_csv}")
def __len__(self):
return len(self.df)
def __getitem__(self, index):
"""Load waveform and target of an audio clip.
Args:
meta: {
'hdf5_path': str,
'index_in_hdf5': int}
Returns:
data_dict: {
'audio_name': str,
'waveform': (clip_samples,),
'target': (classes_num,)}
"""
row = self.df.iloc[index]
#waveform = decode_mp3(np.fromfile(self.audiopath + row.filename, dtype='uint8'))
waveform, _ = librosa.load(self.audiopath + row.filename, sr=self.sr, mono=True)
if self.augment:
waveform = pydub_augment(waveform)
waveform = pad_or_truncate(waveform, self.clip_length)
waveform = self.resample(waveform)
target = row.target
return waveform.reshape(1, -1), row.filename, target
def resample(self, waveform):
"""Resample.
Args:
waveform: (clip_samples,)
Returns:
(resampled_clip_samples,)
"""
if self.sample_rate == 32000:
return waveform
elif self.sample_rate == 16000:
return waveform[0:: 2]
elif self.sample_rate == 8000:
return waveform[0:: 4]
else:
raise Exception('Incorrect sample rate!')
@dataset.command
def get_base_training_set(meta_csv, audio_path, fold=1):
ds = AudioSetDataset(meta_csv, audio_path, fold, train=True, augment=True)
return ds
@dataset.command
def get_ft_weighted_sampler(samples_weights=CMD(".get_ft_cls_balanced_sample_weights"),
epoch_len=100000, sampler_replace=False):
num_nodes = int(os.environ.get('num_nodes', 1))
ddp = int(os.environ.get('DDP', 1))
num_nodes = max(ddp, num_nodes)
print("num_nodes= ", num_nodes)
rank = int(os.environ.get('NODE_RANK', 0))
return DistributedSamplerWrapper(sampler=WeightedRandomSampler(samples_weights,
num_samples=epoch_len, replacement=sampler_replace),
dataset=range(epoch_len),
num_replicas=num_nodes,
rank=rank,
)
@dataset.command
def get_base_test_set(meta_csv, audio_path, fold=1):
ds = AudioSetDataset(meta_csv, audio_path, fold, train=False)
return ds
@dataset.command(prefix='roll_conf')
def get_roll_func(axis=1, shift=None, shift_range=50):
print("rolling...")
def roll_func(b):
x, i, y = b
x = torch.as_tensor(x)
sf = shift
if shift is None:
sf = int(np.random.randint(-shift_range, shift_range + 1))  # random_integers is deprecated; randint's high bound is exclusive
return x.roll(sf, axis), i, y
return roll_func
@dataset.command
def get_training_set(normalize, roll, wavmix=False):
ds = get_base_training_set()
get_ir_sample()
if normalize:
print("normalized train!")
fill_norms()
ds = PreprocessDataset(ds, norm_func)
if roll:
ds = PreprocessDataset(ds, get_roll_func())
if wavmix:
ds = MixupDataset(ds)
return ds
@dataset.command
def get_test_set(normalize):
ds = get_base_test_set()
if normalize:
print("normalized test!")
fill_norms()
ds = PreprocessDataset(ds, norm_func)
return ds
@dataset.command
def print_conf(_config):
print("Config of ", dataset.path, id(dataset))
print(_config)
print()
class DistributedSamplerWrapper(DistributedSampler):
def __init__(
self, sampler, dataset,
num_replicas=None,
rank=None,
shuffle: bool = True):
super(DistributedSamplerWrapper, self).__init__(
dataset, num_replicas, rank, shuffle)
# source: @awaelchli https://github.com/PyTorchLightning/pytorch-lightning/issues/3238
self.sampler = sampler
def __iter__(self):
if self.sampler.generator is None:
self.sampler.generator = torch.Generator()
self.sampler.generator.manual_seed(self.seed + self.epoch)
indices = list(self.sampler)
if self.epoch == 0:
print(f"\n DistributedSamplerWrapper : {indices[:10]} \n\n")
indices = indices[self.rank:self.total_size:self.num_replicas]
return iter(indices)
if __name__ == "__main__":
from sacred import Experiment
ex = Experiment("test_dataset", ingredients=[dataset])
@ex.automain
def default_command():
ex.current_run.get_command_function("print_config")()
get_base_training_set()
ds = get_test_set()
print(ds[0])
ds = get_training_set()
print(ds[0])
print("get_base_training_set", len(get_base_training_set()))
print("get_base_test_set", len(get_base_test_set()))
print("get_training_set", len(get_training_set()))
print("get_test_set", len(get_test_set()))
================================================
FILE: ex_audioset.py
================================================
import os
import sys
import PIL
import pytorch_lightning
import torch
from pytorch_lightning.callbacks import ModelCheckpoint
from sacred.config_helpers import DynamicIngredient, CMD
from torch.nn import functional as F
import numpy as np
import wandb
from ba3l.experiment import Experiment
from ba3l.module import Ba3lModule
from torch.utils.data import DataLoader
from config_updates import add_configs
from helpers.mixup import my_mixup
from helpers.models_size import count_non_zero_params
from helpers.ramp import exp_warmup_linear_down, cosine_cycle
from helpers.workersinit import worker_init_fn
from sklearn import metrics
from pytorch_lightning import Trainer as plTrainer
from pytorch_lightning.loggers import WandbLogger
from pytorch_lightning.callbacks import LearningRateMonitor
ex = Experiment("audioset")
# Example call with all the default config:
# python ex_audioset.py with trainer.precision=16 models.net.arch=passt_deit_bd_p16_384 -p -m mongodb_server:27000:audioset21_balanced -c "PaSST base"
# with 2 gpus:
# DDP=2 python ex_audioset.py with trainer.precision=16 models.net.arch=passt_deit_bd_p16_384 -p -m mongodb_server:27000:audioset21_balanced -c "PaSST base 2 GPU"
# capture the config of the trainer with the prefix "trainer", this allows to use sacred to update PL trainer config
# now you can use in the cmd trainer.precision=16 for example
get_trainer = ex.command(plTrainer, prefix="trainer")
# capture the WandbLogger and prefix it with "wandb", this allows to use sacred to update WandbLogger config from the command line
get_logger = ex.command(WandbLogger, prefix="wandb")
wandb_logger = None
# define datasets and loaders
get_train_loader = ex.datasets.training.iter(DataLoader, static_args=dict(worker_init_fn=worker_init_fn), train=True, batch_size=12,
num_workers=16, shuffle=None, dataset=CMD("/basedataset.get_full_training_set"),
sampler=CMD("/basedataset.get_ft_weighted_sampler"))
get_validate_loader = ex.datasets.test.iter(DataLoader, static_args=dict(worker_init_fn=worker_init_fn),
validate=True, batch_size=20, num_workers=16,
dataset=CMD("/basedataset.get_test_set"))
@ex.config
def default_conf():
cmd = " ".join(sys.argv) # command line arguments
saque_cmd = os.environ.get("SAQUE_CMD", "").strip()
saque_id = os.environ.get("SAQUE_ID", "").strip()
slurm_job_id = os.environ.get("SLURM_JOB_ID", "").strip()
if os.environ.get("SLURM_ARRAY_JOB_ID", False):
slurm_job_id = os.environ.get("SLURM_ARRAY_JOB_ID", "").strip() + "_" + os.environ.get("SLURM_ARRAY_TASK_ID",
"").strip()
process_id = os.getpid()
models = {
"net": DynamicIngredient("models.passt.model_ing", arch="passt_deit_bd_p16_384", n_classes=527, s_patchout_t=40,
s_patchout_f=4), # network config
"mel": DynamicIngredient("models.preprocess.model_ing",
instance_cmd="AugmentMelSTFT",
n_mels=128, sr=32000, win_length=800, hopsize=320, n_fft=1024, freqm=48,
timem=192,
htk=False, fmin=0.0, fmax=None, norm=1, fmin_aug_range=10,
fmax_aug_range=2000)
}
basedataset = DynamicIngredient("audioset.dataset.dataset", wavmix=1)
wandb = dict(project="passt_audioset2", log_model=True)
watch_model = True
trainer = dict(max_epochs=130, gpus=1, weights_summary='full', benchmark=True, num_sanity_val_steps=0, precision=16,
reload_dataloaders_every_epoch=True)
lr = 0.00002 # learning rate
use_mixup = True
mixup_alpha = 0.3
compile = True # compile the model, requires pytorch >= 2.0
# register extra possible configs
add_configs(ex)
@ex.command
def get_scheduler_lambda(warm_up_len=5, ramp_down_start=50, ramp_down_len=50, last_lr_value=0.01,
schedule_mode="exp_lin"):
if schedule_mode == "exp_lin":
return exp_warmup_linear_down(warm_up_len, ramp_down_len, ramp_down_start, last_lr_value)
if schedule_mode == "cos_cyc":
return cosine_cycle(warm_up_len, ramp_down_start, last_lr_value)
raise RuntimeError(
f"schedule_mode={schedule_mode} Unknown for a lambda function.")
@ex.command
def get_lr_scheduler(optimizer, schedule_mode):
if schedule_mode in {"exp_lin", "cos_cyc"}:
return torch.optim.lr_scheduler.LambdaLR(optimizer, get_scheduler_lambda())
raise RuntimeError(f"schedule_mode={schedule_mode} Unknown.")
@ex.command
def get_optimizer(params, lr, adamw=True, weight_decay=0.0001):
if adamw:
print(f"\nUsing adamw weight_decay={weight_decay}!\n")
return torch.optim.AdamW(params, lr=lr, weight_decay=weight_decay)
return torch.optim.Adam(params, lr=lr)
class M(Ba3lModule):
def __init__(self, experiment):
self.mel = None
self.da_net = None
super(M, self).__init__(experiment)
self.use_mixup = self.config.use_mixup or False
self.mixup_alpha = self.config.mixup_alpha
desc, sum_params, sum_non_zero = count_non_zero_params(self.net)
self.experiment.info["start_sum_params"] = sum_params
self.experiment.info["start_sum_params_non_zero"] = sum_non_zero
# in case we need embeddings for the DA
self.net.return_embed = True
self.dyn_norm = self.config.dyn_norm
self.do_swa = False
self.distributed_mode = self.config.trainer.num_nodes > 1
if self.config.compile:
# pt 2 magic
print("\n\nCompiling the model pytorch 2... \n\n")
self.net = torch.compile(self.net)
# compile only the net, not the mel
#self.mel = torch.compile(self.mel)
def forward(self, x):
return self.net(x)
def mel_forward(self, x):
old_shape = x.size()
x = x.reshape(-1, old_shape[2])
x = self.mel(x)
x = x.reshape(old_shape[0], old_shape[1], x.shape[1], x.shape[2])
if self.dyn_norm:
if not hasattr(self, "tr_m") or not hasattr(self, "tr_std"):
tr_m, tr_std = get_dynamic_norm(self)
self.register_buffer('tr_m', tr_m)
self.register_buffer('tr_std', tr_std)
x = (x - self.tr_m) / self.tr_std
return x
def training_step(self, batch, batch_idx):
# REQUIRED
x, f, y = batch
if self.mel:
x = self.mel_forward(x)
if self.global_step < 5:
images = [ wandb.Image(
PIL.Image.fromarray((i * 255).astype(np.uint8)).convert("L"),
caption="spectrograms",
) for i in x[:, 0, :, :].cpu().numpy()]
wandb.log({"spectrograms": images})
# wandb_logger.log_image(key="spectrograms", images=[i for i in x[:,0,:,:].cpu().numpy()])
orig_x = x
batch_size = len(y)
rn_indices, lam = None, None
if self.use_mixup:
rn_indices, lam = my_mixup(batch_size, self.mixup_alpha)
lam = lam.to(x.device)
x = x * lam.reshape(batch_size, 1, 1, 1) + \
x[rn_indices] * (1. - lam.reshape(batch_size, 1, 1, 1))
y_hat, embed = self.forward(x)
if self.use_mixup:
y_mix = y * lam.reshape(batch_size, 1) + \
y[rn_indices] * (1. - lam.reshape(batch_size, 1))
samples_loss = F.binary_cross_entropy_with_logits(
y_hat, y_mix, reduction="none")
loss = samples_loss.mean()
samples_loss = samples_loss.detach()
else:
samples_loss = F.binary_cross_entropy_with_logits(
y_hat, y, reduction="none")
loss = samples_loss.mean()
samples_loss = samples_loss.detach()
results = {"loss": loss, }
self.log('trainer/lr', self.trainer.optimizers[0].param_groups[0]['lr'])
self.log('epoch', self.current_epoch)
self.log("training.loss", loss.detach())
return results
def training_epoch_end(self, outputs):
avg_loss = torch.stack([x['loss'] for x in outputs]).mean()
logs = {'train.loss': avg_loss, # 'step': self.current_epoch
}
self.log_dict(logs, sync_dist=True)
def predict(self, batch, batch_idx: int, dataloader_idx: int = None):
x, f, y = batch
if self.mel:
x = self.mel_forward(x)
y_hat, _ = self.forward(x)
return f, y_hat
def validation_step(self, batch, batch_idx):
x, f, y = batch
if self.mel:
x = self.mel_forward(x)
if self.global_step < 5:
images = [ wandb.Image(
PIL.Image.fromarray((i * 255).astype(np.uint8)).convert("L"),
caption="validation_spectrograms",
) for i in x[:, 0, :, :].cpu().numpy()]
wandb.log({"validation_spectrograms": images})
# wandb_logger.log_image(key="validation_spectrograms", images=[
# i for i in x[:, 0, :, :].cpu().numpy()])
results = {}
model_name = [("", self.net)]
if self.do_swa:
model_name = model_name + [("swa_", self.net_swa)]
for net_name, net in model_name:
y_hat, _ = net(x)
samples_loss = F.binary_cross_entropy_with_logits(y_hat, y)
loss = samples_loss.mean()
out = torch.sigmoid(y_hat.detach())
# self.log("validation.loss", loss, prog_bar=True, on_epoch=True, on_step=False)
results = {**results, net_name + "val_loss": loss,
net_name + "out": out, net_name + "target": y.detach()}
results = {k: v.cpu() for k, v in results.items()}
return results
def validation_epoch_end(self, outputs):
model_name = [("", self.net)]
if self.do_swa:
model_name = model_name + [("swa_", self.net_swa)]
for net_name, net in model_name:
avg_loss = torch.stack([x[net_name + 'val_loss']
for x in outputs]).mean()
out = torch.cat([x[net_name + 'out'] for x in outputs], dim=0)
target = torch.cat([x[net_name + 'target']
for x in outputs], dim=0)
try:
average_precision = metrics.average_precision_score(
target.float().numpy(), out.float().numpy(), average=None)
except ValueError:
average_precision = np.array([np.nan] * 527)
try:
roc = metrics.roc_auc_score(
target.numpy(), out.numpy(), average=None)
except ValueError:
roc = np.array([np.nan] * 527)
logs = {net_name + 'val.loss': torch.as_tensor(avg_loss).cuda(),
net_name + 'ap': torch.as_tensor(average_precision.mean()).cuda(),
net_name + 'roc': torch.as_tensor(roc.mean()).cuda(),}
#'step': torch.as_tensor(self.current_epoch).cuda()}
# torch.save(average_precision,
# f"ap_perclass_{average_precision.mean()}.pt")
# print(average_precision)
self.log_dict(logs, sync_dist=True)
if self.distributed_mode:
allout = self.all_gather(out)
alltarget = self.all_gather(target)
average_precision = metrics.average_precision_score(
alltarget.reshape(-1, alltarget.shape[-1]).cpu().numpy(),
allout.reshape(-1, allout.shape[-1]).cpu().numpy(), average=None)
if self.trainer.is_global_zero:
logs = {net_name + "allap": torch.as_tensor(average_precision.mean()).cuda(),
# 'step': torch.as_tensor(self.current_epoch).cuda()
}
self.log_dict(logs, sync_dist=False)
else:
self.log_dict(
{net_name + "allap": logs[net_name + 'ap'],
#'step': logs['step']
}
, sync_dist=True)
def configure_optimizers(self):
# REQUIRED
# can return multiple optimizers and learning_rate schedulers
# (LBFGS it is automatically supported, no need for closure function)
optimizer = get_optimizer(self.parameters())
# torch.optim.Adam(self.parameters(), lr=self.config.lr)
return {
'optimizer': optimizer,
'lr_scheduler': get_lr_scheduler(optimizer)
}
def configure_callbacks(self):
return get_extra_checkpoint_callback() + get_extra_swa_callback() + [LearningRateMonitor(logging_interval='epoch')]
@ex.command
def get_dynamic_norm(model, dyn_norm=False):
if not dyn_norm:
return None, None
raise RuntimeError('no dynamic norm supported yet.')
@ex.command
def get_extra_checkpoint_callback(save_last_n=None):
if save_last_n is None:
return []
return [ModelCheckpoint(monitor="step", verbose=True, save_top_k=save_last_n, mode='max')]
@ex.command
def get_extra_swa_callback(swa=True, swa_epoch_start=50,
swa_freq=5):
if not swa:
return []
print("\n Using swa!\n")
if pytorch_lightning.__version__ <= "1.5":
from helpers.swa_legacy import StochasticWeightAveraging
else:
from helpers.swa_callback import StochasticWeightAveraging
return [StochasticWeightAveraging(swa_epoch_start=swa_epoch_start, swa_freq=swa_freq)]
@ex.command
def main(_run, _config, _log, _rnd, _seed, watch_model=True):
global wandb_logger
wandb_logger = get_logger()
trainer = get_trainer(logger=wandb_logger)
train_loader = get_train_loader()
val_loader = get_validate_loader()
modul = M(ex)
if watch_model:
# change log frequency of gradients and parameters (100 steps by default)
wandb_logger.watch(modul, log_freq=1000, log="all" )
if pytorch_lightning.__version__ <= "1.5":
trainer.fit(
modul,
train_dataloader=train_loader,
val_dataloaders=val_loader,
)
else:
trainer.fit(
modul,
train_dataloaders=train_loader,
val_dataloaders=val_loader,
)
return {"done": True}
@ex.command
def model_speed_test(_run, _config, _log, _rnd, _seed, speed_test_batch_size=12):
'''
Test training speed of a model
@param _run:
@param _config:
@param _log:
@param _rnd:
@param _seed:
@param speed_test_batch_size: the batch size during the test
@return:
'''
modul = M(ex)
modul = modul.cuda()
batch_size = speed_test_batch_size
print(f"\nBATCH SIZE : {batch_size}\n")
test_length = 100
print(f"\ntest_length : {test_length}\n")
x = torch.ones([batch_size, 1, 128, 998]).cuda()
target = torch.ones([batch_size, 527]).cuda()
# one pass
net = modul.net
# net(x)
scaler = torch.cuda.amp.GradScaler()
torch.backends.cudnn.benchmark = True
# net = torch.jit.trace(net,(x,))
net = torch.compile(net)
optimizer = torch.optim.SGD(net.parameters(), lr=0.001)
print("warmup")
import time
torch.cuda.synchronize()
t1 = time.time()
for i in range(10):
with torch.cuda.amp.autocast():
y_hat, embed = net(x)
loss = F.binary_cross_entropy_with_logits(
y_hat, target, reduction="none").mean()
scaler.scale(loss).backward()
scaler.step(optimizer)
scaler.update()
torch.cuda.synchronize()
t2 = time.time()
print('warmup done:', (t2 - t1))
torch.cuda.synchronize()
t1 = time.time()
print("testing speed")
for i in range(test_length):
with torch.cuda.amp.autocast():
y_hat, embed = net(x)
loss = F.binary_cross_entropy_with_logits(
y_hat, target, reduction="none").mean()
scaler.scale(loss).backward()
scaler.step(optimizer)
scaler.update()
torch.cuda.synchronize()
t2 = time.time()
print('test done:', (t2 - t1))
print("average speed: ", (test_length * batch_size) /
(t2 - t1), " specs/second")
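The throughput figure printed above is simply iterations times batch size over elapsed wall-clock time. A minimal sketch of that calculation (the helper name is illustrative, not from this repo):

```python
def specs_per_second(test_length, batch_size, elapsed_seconds):
    # spectrograms processed per second over the timed loop
    return (test_length * batch_size) / elapsed_seconds

# e.g. 100 timed steps at batch size 12 finishing in 20 s -> 60 specs/second
```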
@ex.command
def evaluate_only(_run, _config, _log, _rnd, _seed):
# forces overriding the config; overrides are not logged, hence not recommended
trainer = get_trainer(logger=get_logger())
val_loader = get_validate_loader()
modul = M(ex)
modul.val_dataloader = None
#trainer.val_dataloaders = None
print(f"\n\nValidation len={len(val_loader)}\n")
res = trainer.validate(modul, dataloaders=val_loader)
print("\n\n Validation:")
print(res)
@ex.command
def test_loaders():
'''
get one sample from each loader for debugging
@return:
'''
for i, b in enumerate(ex.datasets.training.get_iter()):
print(b)
break
for i, b in enumerate(ex.datasets.test.get_iter()):
print(b)
break
def set_default_json_pickle(obj):
if isinstance(obj, set):
return list(obj)
raise TypeError
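`set_default_json_pickle` is shaped to be passed as the `default=` hook of `json.dumps`, which is invoked for objects json cannot serialize natively. A self-contained usage sketch (copying the helper so the example stands alone):

```python
import json

def set_default_json_pickle(obj):
    # json calls this hook for unserializable objects; sets become lists,
    # everything else still raises TypeError as the json module expects
    if isinstance(obj, set):
        return list(obj)
    raise TypeError

encoded = json.dumps({"labels": {3, 1, 2}}, default=set_default_json_pickle)
decoded = json.loads(encoded)  # the set round-trips as a JSON array
```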
@ex.command
def preload_mp3(all_y=CMD("/basedataset.preload_mp3")):
'''
read the dataset sequentially, useful if you have a network cache
@param all_y: the dataset preload command
@return:
'''
print(all_y.shape)
def multiprocessing_run(rank, word_size):
print("rank ", rank, os.getpid())
print("word_size ", word_size)
os.environ['NODE_RANK'] = str(rank)
os.environ['CUDA_VISIBLE_DEVICES'] = os.environ['CUDA_VISIBLE_DEVICES'].split(",")[
rank]
argv = sys.argv
if rank != 0:
print(f"Unobserved {os.getpid()} with rank {rank}")
argv = argv + ["-u"] # only rank 0 is observed
if "with" not in argv:
argv = argv + ["with"]
argv = argv + \
[f"trainer.num_nodes={word_size}", f"trainer.accelerator=ddp"]
print(argv)
@ex.main
def default_command():
return main()
ex.run_commandline(argv)
if __name__ == '__main__':
# setting DDP=2 forks two processes to run on two GPUs
# the environment variable "DDP" defines the number of processes to fork
# with 2x 2080ti GPUs you can train the full model to .47 in around 24 hours
# you may need to set NCCL_P2P_DISABLE=1
word_size = os.environ.get("DDP", None)
if word_size:
import random
word_size = int(word_size)
print(f"\n\nDDP TRAINING WITH WORD_SIZE={word_size}\n\n")
os.environ['MASTER_ADDR'] = '127.0.0.1'
# plz no collisions
os.environ['MASTER_PORT'] = f"{9999 + random.randint(0, 9999)}"
os.environ['PL_IN_DDP_SUBPROCESS'] = '1'
for rank in range(word_size):
pid = os.fork()
if pid == 0:
print("Child Forked ")
multiprocessing_run(rank, word_size)
exit(0)
pid, exit_code = os.wait()
print(pid, exit_code)
exit(0)
print("__main__ is running pid", os.getpid(), "in module main: ", __name__)
@ex.automain
def default_command():
return main()
================================================
FILE: ex_esc50.py
================================================
import os
import sys
import torch
from pytorch_lightning.callbacks import ModelCheckpoint
from sacred.config_helpers import DynamicIngredient, CMD
from torch.nn import functional as F
import numpy as np
from ba3l.experiment import Experiment
from ba3l.module import Ba3lModule
from torch.utils.data import DataLoader
from config_updates import add_configs
from helpers.mixup import my_mixup
from helpers.models_size import count_non_zero_params
from helpers.ramp import exp_warmup_linear_down, cosine_cycle
from helpers.workersinit import worker_init_fn
from sklearn import metrics
from pytorch_lightning import Trainer as plTrainer
from pytorch_lightning.loggers import WandbLogger
ex = Experiment("passt_esc50")
# Example call with all the default config:
# python ex_esc50.py with trainer.precision=16 -p -m mongodb_server:27000:audioset21_balanced -c "ESC50 PaSST base"
# with 2 gpus:
# DDP=2 python ex_esc50.py with trainer.precision=16 -p -m mongodb_server:27000:audioset21_balanced -c "ESC50 PaSST base"
# capture the config of the trainer with the prefix "trainer", this allows to use sacred to update PL trainer config
get_trainer = ex.command(plTrainer, prefix="trainer")
# capture the WandbLogger and prefix it with "wandb", this allows to use sacred to update WandbLogger config from the command line
get_logger = ex.command(WandbLogger, prefix="wandb")
# define datasets and loaders
get_train_loader = ex.datasets.training.iter(DataLoader, static_args=dict(worker_init_fn=worker_init_fn), train=True, batch_size=12,
num_workers=16, shuffle=None, dataset=CMD("/basedataset.get_training_set"),
)
get_validate_loader = ex.datasets.test.iter(DataLoader, static_args=dict(worker_init_fn=worker_init_fn),
validate=True, batch_size=20, num_workers=16,
dataset=CMD("/basedataset.get_test_set"))
@ex.config
def default_conf():
cmd = " ".join(sys.argv)
saque_cmd = os.environ.get("SAQUE_CMD", "").strip()
saque_id = os.environ.get("SAQUE_ID", "").strip()
slurm_job_id = os.environ.get("SLURM_JOB_ID", "").strip()
if os.environ.get("SLURM_ARRAY_JOB_ID", False):
slurm_job_id = os.environ.get("SLURM_ARRAY_JOB_ID", "").strip() + "_" + os.environ.get("SLURM_ARRAY_TASK_ID",
"").strip()
process_id = os.getpid()
models = {
"net": DynamicIngredient("models.passt.model_ing", n_classes=50, s_patchout_t=10, s_patchout_f=3),
"mel": DynamicIngredient("models.preprocess.model_ing",
instance_cmd="AugmentMelSTFT",
n_mels=128, sr=32000, win_length=800, hopsize=320, n_fft=1024, freqm=48,
timem=80,
htk=False, fmin=0.0, fmax=None, norm=1, fmin_aug_range=10,
fmax_aug_range=2000)
}
basedataset = DynamicIngredient("esc50.dataset.dataset")
trainer = dict(max_epochs=10, gpus=1, weights_summary='full', benchmark=True, num_sanity_val_steps=0,
reload_dataloaders_every_epoch=True)
wandb = dict(project="passt_esc50", log_model=True)
lr = 0.00001
use_mixup = True
mixup_alpha = 0.3
# register extra possible configs
add_configs(ex)
@ex.command
def get_scheduler_lambda(warm_up_len=5, ramp_down_start=50, ramp_down_len=50, last_lr_value=0.01,
schedule_mode="exp_lin"):
if schedule_mode == "exp_lin":
return exp_warmup_linear_down(warm_up_len, ramp_down_len, ramp_down_start, last_lr_value)
if schedule_mode == "cos_cyc":
return cosine_cycle(warm_up_len, ramp_down_start, last_lr_value)
raise RuntimeError(f"schedule_mode={schedule_mode} unknown for a lambda function.")
@ex.command
def get_lr_scheduler(optimizer, schedule_mode):
if schedule_mode in {"exp_lin", "cos_cyc"}:
return torch.optim.lr_scheduler.LambdaLR(optimizer, get_scheduler_lambda())
raise RuntimeError(f"schedule_mode={schedule_mode} Unknown.")
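`exp_warmup_linear_down` lives in helpers/ramp.py; below is a hedged sketch of the schedule shape its name and arguments suggest (exponential ramp-up, plateau, linear decay to a floor). Illustrative only — not the repo's implementation:

```python
def sketch_exp_warmup_linear_down(warm_up_len, ramp_down_len, ramp_down_start, last_lr_value):
    # returns an epoch -> lr-multiplier lambda, as consumed by LambdaLR
    def schedule(epoch):
        if epoch < warm_up_len:
            return 2.0 ** (epoch - warm_up_len)  # exponential ramp-up towards 1.0
        if epoch < ramp_down_start:
            return 1.0                           # plateau at the full learning rate
        decayed = 1.0 - (epoch - ramp_down_start) / ramp_down_len * (1.0 - last_lr_value)
        return max(decayed, last_lr_value)       # linear decay, clamped at the floor
    return schedule
```

With the esc50 defaults (`warm_up_len=5, ramp_down_start=50, ramp_down_len=50, last_lr_value=0.01`) the multiplier reaches 1.0 at epoch 5 and bottoms out at 0.01 from epoch 100 on.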
@ex.command
def get_optimizer(params, lr, adamw=True, weight_decay=0.0001):
if adamw:
print(f"\nUsing adamw weight_decay={weight_decay}!\n")
return torch.optim.AdamW(params, lr=lr, weight_decay=weight_decay)
return torch.optim.Adam(params, lr=lr)
class M(Ba3lModule):
def __init__(self, experiment):
self.mel = None
self.da_net = None
super(M, self).__init__(experiment)
self.use_mixup = self.config.use_mixup or False
self.mixup_alpha = self.config.mixup_alpha
desc, sum_params, sum_non_zero = count_non_zero_params(self.net)
self.experiment.info["start_sum_params"] = sum_params
self.experiment.info["start_sum_params_non_zero"] = sum_non_zero
# in case we need embeddings for the DA
self.net.return_embed = True
self.dyn_norm = self.config.dyn_norm
self.do_swa = False
self.distributed_mode = self.config.trainer.num_nodes > 1
def forward(self, x):
return self.net(x)
def mel_forward(self, x):
old_shape = x.size()
x = x.reshape(-1, old_shape[2])
x = self.mel(x)
x = x.reshape(old_shape[0], old_shape[1], x.shape[1], x.shape[2])
if self.dyn_norm:
if not hasattr(self, "tr_m") or not hasattr(self, "tr_std"):
tr_m, tr_std = get_dynamic_norm(self)
self.register_buffer('tr_m', tr_m)
self.register_buffer('tr_std', tr_std)
x = (x - self.tr_m) / self.tr_std
return x
def training_step(self, batch, batch_idx):
# REQUIRED
x, f, y = batch
if self.mel:
x = self.mel_forward(x)
orig_x = x
batch_size = len(y)
rn_indices, lam = None, None
if self.use_mixup:
rn_indices, lam = my_mixup(batch_size, self.mixup_alpha)
lam = lam.to(x.device)
x = x * lam.reshape(batch_size, 1, 1, 1) + x[rn_indices] * (1. - lam.reshape(batch_size, 1, 1, 1))
y_hat, embed = self.forward(x)
if self.use_mixup:
# y_mix = y * lam.reshape(batch_size, 1) + y[rn_indices] * (1. - lam.reshape(batch_size, 1))
samples_loss = (F.cross_entropy(y_hat, y, reduction="none") * lam.reshape(batch_size) +
F.cross_entropy(y_hat, y[rn_indices], reduction="none") * (1. - lam.reshape(batch_size)))
loss = samples_loss.mean()
samples_loss = samples_loss.detach()
else:
samples_loss = F.cross_entropy(y_hat, y, reduction="none")
loss = samples_loss.mean()
samples_loss = samples_loss.detach()
results = {"loss": loss, }
return results
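`my_mixup` (helpers/mixup.py) supplies the permutation/coefficient pair used in `training_step`. A stdlib-only sketch of the usual mixup recipe — one random partner index per sample and Beta(alpha, alpha) mixing coefficients — stated as an assumption, not the repo's exact helper:

```python
import random

def sketch_my_mixup(batch_size, alpha):
    # partner indices: a random permutation of the batch
    rn_indices = list(range(batch_size))
    random.shuffle(rn_indices)
    # one Beta(alpha, alpha) mixing coefficient per sample, each in (0, 1)
    lam = [random.betavariate(alpha, alpha) for _ in range(batch_size)]
    return rn_indices, lam
```

Each input then becomes `lam * x[i] + (1 - lam) * x[rn_indices[i]]`, and the loss is mixed with the same coefficients, matching the `training_step` above.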
def training_epoch_end(self, outputs):
avg_loss = torch.stack([x['loss'] for x in outputs]).mean()
logs = {'train.loss': avg_loss, 'step': self.current_epoch}
self.log_dict(logs, sync_dist=True)
def predict(self, batch, batch_idx: int, dataloader_idx: int = None):
x, f, y = batch
if self.mel:
x = self.mel_forward(x)
y_hat, _ = self.forward(x)
return f, y_hat
def validation_step(self, batch, batch_idx):
x, f, y = batch
if self.mel:
x = self.mel_forward(x)
results = {}
model_name = [("", self.net)]
if self.do_swa:
model_name = model_name + [("swa_", self.net_swa)]
for net_name, net in model_name:
y_hat, _ = net(x)
samples_loss = F.cross_entropy(y_hat, y, reduction="none")
loss = samples_loss.mean()
_, preds = torch.max(y_hat, dim=1)
n_correct_pred_per_sample = (preds == y)
n_correct_pred = n_correct_pred_per_sample.sum()
# self.log("validation.loss", loss, prog_bar=True, on_epoch=True, on_step=False)
results = {**results, net_name + "val_loss": loss,
net_name + "n_correct_pred": torch.as_tensor(n_correct_pred),
net_name + "n_pred": torch.as_tensor(len(y))}
results = {k: v.cpu() for k, v in results.items()}
return results
def validation_epoch_end(self, outputs):
model_name = [("", self.net)]
if self.do_swa:
model_name = model_name + [("swa_", self.net_swa)]
for net_name, net in model_name:
avg_loss = torch.stack([x[net_name + 'val_loss'] for x in outputs]).mean()
val_acc = sum([x[net_name + 'n_correct_pred'] for x in outputs]) * 1.0 / sum(x[net_name + 'n_pred'] for x in outputs)
logs = {net_name + 'val.loss': torch.as_tensor(avg_loss).cuda(),
net_name + 'acc': torch.as_tensor(val_acc).cuda(),
'step': torch.as_tensor(self.current_epoch).cuda()}
self.log_dict(logs, sync_dist=True)
def configure_optimizers(self):
# REQUIRED
# can return multiple optimizers and learning_rate schedulers
# (LBFGS it is automatically supported, no need for closure function)
optimizer = get_optimizer(self.parameters())
# torch.optim.Adam(self.parameters(), lr=self.config.lr)
return {
'optimizer': optimizer,
'lr_scheduler': get_lr_scheduler(optimizer)
}
def configure_callbacks(self):
return get_extra_checkpoint_callback() + get_extra_swa_callback()
@ex.command
def get_dynamic_norm(model, dyn_norm=False):
if not dyn_norm:
return None, None
raise RuntimeError('no dynamic norm supported yet.')
@ex.command
def get_extra_checkpoint_callback(save_last_n=None):
if save_last_n is None:
return []
return [ModelCheckpoint(monitor="step", verbose=True, save_top_k=save_last_n, mode='max')]
@ex.command
def get_extra_swa_callback(swa=True, swa_epoch_start=2,
swa_freq=1):
if not swa:
return []
print("\n Using swa!\n")
from helpers.swa_callback import StochasticWeightAveraging
return [StochasticWeightAveraging(swa_epoch_start=swa_epoch_start, swa_freq=swa_freq)]
@ex.command
def main(_run, _config, _log, _rnd, _seed):
trainer = get_trainer()
train_loader = get_train_loader()
val_loader = get_validate_loader()
modul = M(ex)
trainer.fit(
modul,
train_dataloaders=train_loader,
val_dataloaders=val_loader,
)
return {"done": True}
@ex.command
def model_speed_test(_run, _config, _log, _rnd, _seed, speed_test_batch_size=100):
'''
Test training speed of a model
@param _run:
@param _config:
@param _log:
@param _rnd:
@param _seed:
@param speed_test_batch_size: the batch size during the test
@return:
'''
modul = M(ex)
modul = modul.cuda()
batch_size = speed_test_batch_size
print(f"\nBATCH SIZE : {batch_size}\n")
test_length = 100
print(f"\ntest_length : {test_length}\n")
x = torch.ones([batch_size, 1, 128, 998]).cuda()
target = torch.ones([batch_size, 527]).cuda()
# one pass
net = modul.net
# net(x)
scaler = torch.cuda.amp.GradScaler()
torch.backends.cudnn.benchmark = True
# net = torch.jit.trace(net,(x,))
optimizer = torch.optim.SGD(net.parameters(), lr=0.001)
print("warmup")
import time
torch.cuda.synchronize()
t1 = time.time()
for i in range(10):
with torch.cuda.amp.autocast():
y_hat, embed = net(x)
loss = F.binary_cross_entropy_with_logits(y_hat, target, reduction="none").mean()
scaler.scale(loss).backward()
scaler.step(optimizer)
scaler.update()
torch.cuda.synchronize()
t2 = time.time()
print('warmup done:', (t2 - t1))
torch.cuda.synchronize()
t1 = time.time()
print("testing speed")
for i in range(test_length):
with torch.cuda.amp.autocast():
y_hat, embed = net(x)
loss = F.binary_cross_entropy_with_logits(y_hat, target, reduction="none").mean()
scaler.scale(loss).backward()
scaler.step(optimizer)
scaler.update()
torch.cuda.synchronize()
t2 = time.time()
print('test done:', (t2 - t1))
print("average speed: ", (test_length * batch_size) / (t2 - t1), " specs/second")
@ex.command
def evaluate_only(_run, _config, _log, _rnd, _seed):
# forces overriding the config; overrides are not logged, hence not recommended
trainer = get_trainer()
train_loader = get_train_loader()
val_loader = get_validate_loader()
modul = M(ex)
modul.val_dataloader = None
#trainer.val_dataloaders = None
print(f"\n\nValidation len={len(val_loader)}\n")
res = trainer.validate(modul, dataloaders=val_loader)
print("\n\n Validation:")
print(res)
@ex.command
def test_loaders():
'''
get one sample from each loader for debugging
@return:
'''
for i, b in enumerate(ex.datasets.training.get_iter()):
print(b)
break
for i, b in enumerate(ex.datasets.test.get_iter()):
print(b)
break
def set_default_json_pickle(obj):
if isinstance(obj, set):
return list(obj)
raise TypeError
def multiprocessing_run(rank, word_size):
print("rank ", rank, os.getpid())
print("word_size ", word_size)
os.environ['NODE_RANK'] = str(rank)
os.environ['CUDA_VISIBLE_DEVICES'] = os.environ['CUDA_VISIBLE_DEVICES'].split(",")[rank]
argv = sys.argv
if rank != 0:
print(f"Unobserved {os.getpid()} with rank {rank}")
argv = argv + ["-u"] # only rank 0 is observed
if "with" not in argv:
argv = argv + ["with"]
argv = argv + [f"trainer.num_nodes={word_size}", f"trainer.accelerator=ddp"]
print(argv)
@ex.main
def default_command():
return main()
ex.run_commandline(argv)
if __name__ == '__main__':
# setting DDP=2 forks two processes to run on two GPUs
# the environment variable "DDP" defines the number of processes to fork
# with 2x 2080ti GPUs you can train the full model to .47 in around 24 hours
# you may need to set NCCL_P2P_DISABLE=1
word_size = os.environ.get("DDP", None)
if word_size:
import random
word_size = int(word_size)
print(f"\n\nDDP TRAINING WITH WORD_SIZE={word_size}\n\n")
os.environ['MASTER_ADDR'] = '127.0.0.1'
os.environ['MASTER_PORT'] = f"{9999 + random.randint(0, 9999)}" # plz no collisions
os.environ['PL_IN_DDP_SUBPROCESS'] = '1'
for rank in range(word_size):
pid = os.fork()
if pid == 0:
print("Child Forked ")
multiprocessing_run(rank, word_size)
exit(0)
pid, exit_code = os.wait()
print(pid, exit_code)
exit(0)
print("__main__ is running pid", os.getpid(), "in module main: ", __name__)
@ex.automain
def default_command():
return main()
================================================
FILE: ex_fsd50k.py
================================================
import os
import sys
import torch
from pytorch_lightning.callbacks import ModelCheckpoint
from sacred.config_helpers import DynamicIngredient, CMD
from torch.nn import functional as F
import numpy as np
from ba3l.experiment import Experiment
from ba3l.module import Ba3lModule
from torch.utils.data import DataLoader
from config_updates import add_configs
from helpers.mixup import my_mixup
from helpers.models_size import count_non_zero_params
from helpers.ramp import exp_warmup_linear_down, cosine_cycle
from helpers.workersinit import worker_init_fn
from sklearn import metrics
from pytorch_lightning import Trainer as plTrainer
from pytorch_lightning.loggers import WandbLogger
ex = Experiment("fsd50k")
# capture the config of the trainer with the prefix "trainer", this allows to use sacred to update PL trainer config
get_trainer = ex.command(plTrainer, prefix="trainer")
# capture the WandbLogger and prefix it with "wandb", this allows to use sacred to update WandbLogger config from the command line
get_logger = ex.command(WandbLogger, prefix="wandb")
# Example call with all the default config:
# python ex_fsd50k.py with trainer.precision=16 models.net.arch=passt_deit_bd_p16_384 -p -m mongodb_server:27000:audioset21_balanced -c "PaSST base"
# with 2 gpus:
# DDP=2 python ex_fsd50k.py with trainer.precision=16 models.net.arch=passt_deit_bd_p16_384 -p -m mongodb_server:27000:audioset21_balanced -c "PaSST base 2 GPU"
# define datasets and loaders
get_train_loader = ex.datasets.training.iter(DataLoader, static_args=dict(worker_init_fn=worker_init_fn), train=True, batch_size=12,
num_workers=16, shuffle=True, dataset=CMD("/basedataset.get_training_set"),
)
get_validate_loader = ex.datasets.valid.iter(DataLoader, static_args=dict(worker_init_fn=worker_init_fn),
validate=True, batch_size=10, num_workers=16,
dataset=CMD("/basedataset.get_valid_set"))
get_eval_loader = ex.datasets.eval.iter(DataLoader, static_args=dict(worker_init_fn=worker_init_fn),
validate=True, batch_size=10, num_workers=16,
dataset=CMD("/basedataset.get_eval_set"))
@ex.named_config
def variable_eval():
basedataset = dict(variable_eval=True)
datasets = dict(valid=dict(batch_size=1), eval=dict(batch_size=1))
@ex.config
def default_conf():
cmd = " ".join(sys.argv) # command line arguments
saque_cmd = os.environ.get("SAQUE_CMD", "").strip()
saque_id = os.environ.get("SAQUE_ID", "").strip()
slurm_job_id = os.environ.get("SLURM_JOB_ID", "").strip()
if os.environ.get("SLURM_ARRAY_JOB_ID", False):
slurm_job_id = os.environ.get("SLURM_ARRAY_JOB_ID", "").strip() + "_" + os.environ.get("SLURM_ARRAY_TASK_ID",
"").strip()
process_id = os.getpid()
models = {
"net": DynamicIngredient("models.passt.model_ing",
n_classes=200, s_patchout_t=10, s_patchout_f=4), # network config
"mel": DynamicIngredient("models.preprocess.model_ing",
instance_cmd="AugmentMelSTFT",
n_mels=128, sr=32000, win_length=800, hopsize=320, n_fft=1024, freqm=0,
timem=0,
htk=False, fmin=0.0, fmax=None, norm=1, fmin_aug_range=10,
fmax_aug_range=2000)
}
# set the default name for wandb logger
wandb = dict(project="passt_fsd50k", log_model=True)
basedataset = DynamicIngredient("fsd50k.dataset.dataset", wavmix=1)
trainer = dict(max_epochs=50, gpus=1, weights_summary='full', benchmark=True, num_sanity_val_steps=0,
reload_dataloaders_every_epoch=True)
lr = 0.00001 # learning rate
use_mixup = True
mixup_alpha = 0.3
# register extra possible configs
add_configs(ex)
@ex.command
def get_scheduler_lambda(warm_up_len=5, ramp_down_start=10, ramp_down_len=10, last_lr_value=0.01,
schedule_mode="exp_lin"):
if schedule_mode == "exp_lin":
return exp_warmup_linear_down(warm_up_len, ramp_down_len, ramp_down_start, last_lr_value)
if schedule_mode == "cos_cyc":
return cosine_cycle(warm_up_len, ramp_down_start, last_lr_value)
raise RuntimeError(f"schedule_mode={schedule_mode} unknown for a lambda function.")
@ex.command
def get_lr_scheduler(optimizer, schedule_mode):
if schedule_mode in {"exp_lin", "cos_cyc"}:
return torch.optim.lr_scheduler.LambdaLR(optimizer, get_scheduler_lambda())
raise RuntimeError(f"schedule_mode={schedule_mode} Unknown.")
@ex.command
def get_optimizer(params, lr, adamw=True, weight_decay=0.0001):
if adamw:
print(f"\nUsing adamw weight_decay={weight_decay}!\n")
return torch.optim.AdamW(params, lr=lr, weight_decay=weight_decay)
return torch.optim.Adam(params, lr=lr)
class M(Ba3lModule):
def __init__(self, experiment):
self.mel = None
self.da_net = None
super(M, self).__init__(experiment)
self.use_mixup = self.config.use_mixup or False
self.mixup_alpha = self.config.mixup_alpha
desc, sum_params, sum_non_zero = count_non_zero_params(self.net)
self.experiment.info["start_sum_params"] = sum_params
self.experiment.info["start_sum_params_non_zero"] = sum_non_zero
# in case we need embeddings for the DA
self.net.return_embed = True
self.dyn_norm = self.config.dyn_norm
self.do_swa = False
self.valid_names = ["valid", "eval"]
self.distributed_mode = self.config.trainer.num_nodes > 1
def forward(self, x):
return self.net(x)
def mel_forward(self, x):
old_shape = x.size()
x = x.reshape(-1, old_shape[2])
x = self.mel(x)
x = x.reshape(old_shape[0], old_shape[1], x.shape[1], x.shape[2])
if self.dyn_norm:
if not hasattr(self, "tr_m") or not hasattr(self, "tr_std"):
tr_m, tr_std = get_dynamic_norm(self)
self.register_buffer('tr_m', tr_m)
self.register_buffer('tr_std', tr_std)
x = (x - self.tr_m) / self.tr_std
return x
def training_step(self, batch, batch_idx):
# REQUIRED
x, f, y = batch
if self.mel:
x = self.mel_forward(x)
orig_x = x
batch_size = len(y)
rn_indices, lam = None, None
if self.use_mixup:
rn_indices, lam = my_mixup(batch_size, self.mixup_alpha)
lam = lam.to(x.device)
x = x * lam.reshape(batch_size, 1, 1, 1) + x[rn_indices] * (1. - lam.reshape(batch_size, 1, 1, 1))
y_hat, embed = self.forward(x)
if self.use_mixup:
y_mix = y * lam.reshape(batch_size, 1) + y[rn_indices] * (1. - lam.reshape(batch_size, 1))
samples_loss = F.binary_cross_entropy_with_logits(
y_hat, y_mix, reduction="none")
loss = samples_loss.mean()
samples_loss = samples_loss.detach()
else:
samples_loss = F.binary_cross_entropy_with_logits(y_hat, y, reduction="none")
loss = samples_loss.mean()
samples_loss = samples_loss.detach()
results = {"loss": loss, }
return results
def training_epoch_end(self, outputs):
avg_loss = torch.stack([x['loss'] for x in outputs]).mean()
logs = {'train.loss': avg_loss, 'step': self.current_epoch}
self.log_dict(logs, sync_dist=True)
def predict(self, batch, batch_idx: int, dataloader_idx: int = None):
x, f, y = batch
if self.mel:
x = self.mel_forward(x)
y_hat, _ = self.forward(x)
return f, y_hat
def validation_step(self, batch, batch_idx, dataloader_idx):
x, f, y = batch
if self.mel:
x = self.mel_forward(x)
results = {}
model_name = [("", self.net)]
if self.do_swa:
model_name = model_name + [("swa_", self.net_swa)]
for net_name, net in model_name:
y_hat, _ = net(x)
samples_loss = F.binary_cross_entropy_with_logits(y_hat, y, reduction="none")
loss = samples_loss.mean()
out = torch.sigmoid(y_hat.detach())
# self.log("validation.loss", loss, prog_bar=True, on_epoch=True, on_step=False)
results = {**results, net_name + "val_loss": loss, net_name + "out": out, net_name + "target": y.detach()}
results = {k: v.cpu() for k, v in results.items()}
return results
def validation_epoch_end(self, outputs):
for idx, one_outputs in enumerate(outputs):
set_name = self.valid_names[idx] + "_"
model_name = [("", self.net)]
if self.do_swa:
model_name = model_name + [("swa_", self.net_swa)]
for net_name, net in model_name:
avg_loss = torch.stack([x[net_name + 'val_loss'] for x in one_outputs]).mean()
out = torch.cat([x[net_name + 'out'] for x in one_outputs], dim=0)
target = torch.cat([x[net_name + 'target'] for x in one_outputs], dim=0)
try:
average_precision = metrics.average_precision_score(
target.float().numpy(), out.float().numpy(), average=None)
except ValueError:
average_precision = np.array([np.nan] * 200)
try:
roc = metrics.roc_auc_score(target.numpy(), out.numpy(), average=None)
except ValueError:
roc = np.array([np.nan] * 200)
logs = {set_name + net_name + 'val.loss': torch.as_tensor(avg_loss).cuda(),
set_name + net_name + 'ap': torch.as_tensor(average_precision.mean()).cuda(),
set_name + net_name + 'roc': torch.as_tensor(roc.mean()).cuda(),
'step': torch.as_tensor(self.current_epoch).cuda()}
# torch.save(average_precision, f"ap_perclass_{average_precision.mean()}.pt")
# print(average_precision)
self.log_dict(logs, sync_dist=True)
if self.distributed_mode:
allout = self.all_gather(out)
alltarget = self.all_gather(target)
average_precision = metrics.average_precision_score(
alltarget.reshape(-1, alltarget.shape[-1]).cpu().numpy(),
allout.reshape(-1, allout.shape[-1]).cpu().numpy(), average=None)
if self.trainer.is_global_zero:
logs = {set_name + net_name + "allap": torch.as_tensor(average_precision.mean()).cuda(),
'step': torch.as_tensor(self.current_epoch).cuda()}
self.log_dict(logs, sync_dist=False)
else:
self.log_dict(
{set_name + net_name + "allap": logs[set_name + net_name + 'ap'], 'step': logs['step']},
sync_dist=True)
def configure_optimizers(self):
# REQUIRED
# can return multiple optimizers and learning_rate schedulers
# (LBFGS it is automatically supported, no need for closure function)
optimizer = get_optimizer(self.parameters())
# torch.optim.Adam(self.parameters(), lr=self.config.lr)
return {
'optimizer': optimizer,
'lr_scheduler': get_lr_scheduler(optimizer)
}
def configure_callbacks(self):
return get_extra_checkpoint_callback() + get_extra_swa_callback()
@ex.command
def get_dynamic_norm(model, dyn_norm=False):
if not dyn_norm:
return None, None
raise RuntimeError('no dynamic norm supported yet.')
model_checkpoint_callback = None
@ex.command
def get_extra_checkpoint_callback(save_best=None):
if save_best is None:
return []
global model_checkpoint_callback
model_checkpoint_callback = ModelCheckpoint(monitor="allap", verbose=True, save_top_k=save_best, mode='max',
every_n_val_epochs=1, every_n_train_steps=0)
return [model_checkpoint_callback]
@ex.command
def get_extra_swa_callback(swa=True, swa_epoch_start=10,
swa_freq=3):
if not swa:
return []
print("\n Using swa!\n")
from helpers.swa_callback import StochasticWeightAveraging
return [StochasticWeightAveraging(swa_epoch_start=swa_epoch_start, swa_freq=swa_freq)]
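The StochasticWeightAveraging callback keeps a running average of the network weights, updated every `swa_freq` epochs from `swa_epoch_start` on; validation then also reports the averaged model under the `swa_` prefix. The core update is a plain running mean. A minimal scalar sketch (illustrative, not the callback's code):

```python
def swa_update(avg_param, new_param, n_averaged):
    # running mean over snapshots: avg <- avg + (new - avg) / (n + 1)
    return avg_param + (new_param - avg_param) / (n_averaged + 1)

# averaging the weight snapshots 1.0, 2.0, 3.0 in sequence yields their mean, 2.0
avg = 1.0
for n, w in enumerate([2.0, 3.0], start=1):
    avg = swa_update(avg, w, n)
```

In the real callback the same update runs element-wise over every parameter tensor of the averaged copy of the network.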
@ex.command
def main(_run, _config, _log, _rnd, _seed):
trainer = get_trainer(logger=get_logger())
train_loader = get_train_loader()
val_loader = get_validate_loader()
eval_loader = get_eval_loader()
# eval_loader = get_eval_loader()
modul = M(ex)
trainer.fit(
modul,
train_dataloader=train_loader,
val_dataloaders=[val_loader, eval_loader],
)
## evaluate best model on eval set
#trainer.val_dataloaders = None
modul.val_dataloaders = None
return {"done": True}
@ex.command
def model_speed_test(_run, _config, _log, _rnd, _seed, speed_test_batch_size=100):
'''
Test training speed of a model
@param _run:
@param _config:
@param _log:
@param _rnd:
@param _seed:
@param speed_test_batch_size: the batch size during the test
@return:
'''
modul = M(ex)
modul = modul.cuda()
batch_size = speed_test_batch_size
print(f"\nBATCH SIZE : {batch_size}\n")
test_length = 100
print(f"\ntest_length : {test_length}\n")
x = torch.ones([batch_size, 1, 128, 998]).cuda()
target = torch.ones([batch_size, 527]).cuda()
# one pass
net = modul.net
# net(x)
scaler = torch.cuda.amp.GradScaler()
torch.backends.cudnn.benchmark = True
# net = torch.jit.trace(net,(x,))
optimizer = torch.optim.SGD(net.parameters(), lr=0.001)
print("warmup")
import time
torch.cuda.synchronize()
t1 = time.time()
for i in range(10):
with torch.cuda.amp.autocast():
y_hat, embed = net(x)
loss = F.binary_cross_entropy_with_logits(y_hat, target, reduction="none").mean()
scaler.scale(loss).backward()
scaler.step(optimizer)
scaler.update()
torch.cuda.synchronize()
t2 = time.time()
print('warmup done:', (t2 - t1))
torch.cuda.synchronize()
t1 = time.time()
print("testing speed")
for i in range(test_length):
with torch.cuda.amp.autocast():
y_hat, embed = net(x)
loss = F.binary_cross_entropy_with_logits(y_hat, target, reduction="none").mean()
scaler.scale(loss).backward()
scaler.step(optimizer)
scaler.update()
torch.cuda.synchronize()
t2 = time.time()
print('test done:', (t2 - t1))
print("average speed: ", (test_length * batch_size) / (t2 - t1), " specs/second")
@ex.command
def evaluate_only(_run, _config, _log, _rnd, _seed):
# forces overriding the config; overrides are not logged, hence not recommended
trainer = get_trainer()
train_loader = get_train_loader()
val_loader = get_validate_loader()
modul = M(ex)
modul.val_dataloader = None
#trainer.val_dataloaders = None
print(f"\n\nValidation len={len(val_loader)}\n")
res = trainer.validate(modul, dataloaders=val_loader)
print("\n\n Validation:")
print(res)
@ex.command
def test_loaders():
'''
get one sample from each loader for debugging
@return:
'''
for i, b in enumerate(ex.datasets.training.get_iter()):
print(b)
break
for i, b in enumerate(ex.datasets.test.get_iter()):
print(b)
break
def set_default_json_pickle(obj):
if isinstance(obj, set):
return list(obj)
raise TypeError
@ex.command
def preload_mp3(all_y=CMD("/basedataset.preload_mp3")):
'''
read the dataset sequentially, useful if you have a network cache
@param all_y: the dataset preload command
@return:
'''
print(all_y.shape)
def multiprocessing_run(rank, word_size):
print("rank ", rank, os.getpid())
print("word_size ", word_size)
os.environ['NODE_RANK'] = str(rank)
os.environ['CUDA_VISIBLE_DEVICES'] = os.environ['CUDA_VISIBLE_DEVICES'].split(",")[rank]
argv = sys.argv
if rank != 0:
print(f"Unobserved {os.getpid()} with rank {rank}")
argv = argv + ["-u"] # only rank 0 is observed
if "with" not in argv:
argv = argv + ["with"]
argv = argv + [f"trainer.num_nodes={word_size}", f"trainer.accelerator=ddp"]
print(argv)
@ex.main
def default_command():
return main()
ex.run_commandline(argv)
if __name__ == '__main__':
# setting DDP=2 forks two processes to run on two GPUs
# the environment variable "DDP" defines the number of processes to fork
# with 2x 2080ti GPUs you can train the full model to .47 in around 24 hours
# you may need to set NCCL_P2P_DISABLE=1
word_size = os.environ.get("DDP", None)
if word_size:
import random
word_size = int(word_size)
print(f"\n\nDDP TRAINING WITH WORD_SIZE={word_size}\n\n")
os.environ['MASTER_ADDR'] = '127.0.0.1'
os.environ['MASTER_PORT'] = f"{9999 + random.randint(0, 9999)}" # plz no collisions
os.environ['PL_IN_DDP_SUBPROCESS'] = '1'
for rank in range(word_size):
pid = os.fork()
if pid == 0:
print("Child Forked ")
multiprocessing_run(rank, word_size)
exit(0)
pid, exit_code = os.wait()
print(pid, exit_code)
exit(0)
print("__main__ is running pid", os.getpid(), "in module main: ", __name__)
@ex.automain
def default_command():
return main()
================================================
FILE: ex_openmic.py
================================================
import os
import sys

import torch
from pytorch_lightning.callbacks import ModelCheckpoint
from sacred.config_helpers import DynamicIngredient, CMD
from torch.nn import functional as F
import numpy as np
from ba3l.experiment import Experiment
from ba3l.module import Ba3lModule
from torch.utils.data import DataLoader
from config_updates import add_configs
from helpers.mixup import my_mixup
from helpers.models_size import count_non_zero_params
from helpers.ramp import exp_warmup_linear_down, cosine_cycle
from helpers.workersinit import worker_init_fn
from sklearn import metrics
from pytorch_lightning import Trainer as plTrainer
from pytorch_lightning.loggers import WandbLogger

ex = Experiment("openmic")

# capture the config of the trainer with the prefix "trainer"; this allows using sacred to update the PL trainer config
get_trainer = ex.command(plTrainer, prefix="trainer")
# capture the WandbLogger with the prefix "wandb"; this allows using sacred to update the WandbLogger config from the command line
get_logger = ex.command(WandbLogger, prefix="wandb")

# Example call with all the default config:
# python ex_openmic.py with trainer.precision=16 -p -m mongodb_server:27000:audioset21_balanced -c "OpenMIC PaSST base"
# with 2 gpus:
# DDP=2 python ex_openmic.py with trainer.precision=16 -p -m mongodb_server:27000:audioset21_balanced -c "OpenMIC PaSST base"

# define datasets and loaders
get_train_loader = ex.datasets.training.iter(DataLoader, static_args=dict(worker_init_fn=worker_init_fn),
                                             train=True, batch_size=6, num_workers=16, shuffle=None,
                                             dataset=CMD("/basedataset.get_training_set"))

get_validate_loader = ex.datasets.test.iter(DataLoader, static_args=dict(worker_init_fn=worker_init_fn),
                                            validate=True, batch_size=20, num_workers=16,
                                            dataset=CMD("/basedataset.get_test_set"))
@ex.config
def default_conf():
    cmd = " ".join(sys.argv)
    saque_cmd = os.environ.get("SAQUE_CMD", "").strip()
    saque_id = os.environ.get("SAQUE_ID", "").strip()
    slurm_job_id = os.environ.get("SLURM_JOB_ID", "").strip()
    if os.environ.get("SLURM_ARRAY_JOB_ID", False):
        slurm_job_id = os.environ.get("SLURM_ARRAY_JOB_ID", "").strip() + "_" + \
                       os.environ.get("SLURM_ARRAY_TASK_ID", "").strip()
    process_id = os.getpid()
    models = {
        "net": DynamicIngredient("models.passt.model_ing", n_classes=20, s_patchout_t=40, s_patchout_f=4),
        "mel": DynamicIngredient("models.preprocess.model_ing",
                                 instance_cmd="AugmentMelSTFT",
                                 n_mels=128, sr=32000, win_length=800, hopsize=320, n_fft=1024,
                                 freqm=48, timem=192,
                                 htk=False, fmin=0.0, fmax=None, norm=1,
                                 fmin_aug_range=10, fmax_aug_range=2000)
    }
    wandb = dict(project="passt_openmic", log_model=True)
    basedataset = DynamicIngredient("openmic.dataset.dataset", wavmix=1)
    # set the defaults for the trainer
    trainer = dict(max_epochs=10, gpus=1, weights_summary='full', benchmark=True, num_sanity_val_steps=0,
                   reload_dataloaders_every_epoch=True)
    lr = 0.00001
    use_mixup = True
    mixup_alpha = 0.3


# register extra possible configs
add_configs(ex)
@ex.command
def get_scheduler_lambda(warm_up_len=5, ramp_down_start=50, ramp_down_len=50, last_lr_value=0.01,
                         schedule_mode="exp_lin"):
    if schedule_mode == "exp_lin":
        return exp_warmup_linear_down(warm_up_len, ramp_down_len, ramp_down_start, last_lr_value)
    if schedule_mode == "cos_cyc":
        return cosine_cycle(warm_up_len, ramp_down_start, last_lr_value)
    raise RuntimeError(f"schedule_mode={schedule_mode} unknown for a lambda function.")
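helpers/ramp.py is outside this preview, so here is a minimal sketch of the shape the "exp_lin" lambda produces with the defaults above: an exponential warm-up, a plateau at 1.0, then a linear ramp down to `last_lr_value`. The exact warm-up curve of `exp_warmup_linear_down` is an assumption, as is the name `exp_warmup_linear_down_sketch`:

```python
def exp_warmup_linear_down_sketch(warm_up_len, ramp_down_len, ramp_down_start, last_value):
    # illustrative re-implementation; helpers.ramp.exp_warmup_linear_down is the real one
    def schedule(epoch):
        if epoch < warm_up_len:
            return 0.5 ** (warm_up_len - epoch)  # exponential warm-up (assumed base)
        if epoch < ramp_down_start:
            return 1.0                           # plateau at the full learning rate
        if epoch < ramp_down_start + ramp_down_len:
            frac = (epoch - ramp_down_start) / ramp_down_len
            return 1.0 - frac * (1.0 - last_value)  # linear ramp down
        return last_value
    return schedule
```

`LambdaLR` multiplies the base `lr` by this per-epoch factor, so with the defaults the rate peaks between epochs 5 and 50 and ends at 1% of its peak.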
@ex.command
def get_lr_scheduler(optimizer, schedule_mode):
    if schedule_mode in {"exp_lin", "cos_cyc"}:
        return torch.optim.lr_scheduler.LambdaLR(optimizer, get_scheduler_lambda())
    raise RuntimeError(f"schedule_mode={schedule_mode} unknown.")


@ex.command
def get_optimizer(params, lr, adamw=True, weight_decay=0.0001):
    if adamw:
        print(f"\nUsing AdamW with weight_decay={weight_decay}!\n")
        return torch.optim.AdamW(params, lr=lr, weight_decay=weight_decay)
    return torch.optim.Adam(params, lr=lr)
class M(Ba3lModule):
    def __init__(self, experiment):
        self.mel = None
        self.da_net = None
        super(M, self).__init__(experiment)
        self.use_mixup = self.config.use_mixup or False
        self.mixup_alpha = self.config.mixup_alpha
        desc, sum_params, sum_non_zero = count_non_zero_params(self.net)
        self.experiment.info["start_sum_params"] = sum_params
        self.experiment.info["start_sum_params_non_zero"] = sum_non_zero
        # in case we need embeddings for the DA
        self.net.return_embed = True
        self.dyn_norm = self.config.dyn_norm
        self.do_swa = False
        self.distributed_mode = self.config.trainer.num_nodes > 1

    def forward(self, x):
        return self.net(x)

    def mel_forward(self, x):
        old_shape = x.size()
        x = x.reshape(-1, old_shape[2])
        x = self.mel(x)
        x = x.reshape(old_shape[0], old_shape[1], x.shape[1], x.shape[2])
        if self.dyn_norm:
            if not hasattr(self, "tr_m") or not hasattr(self, "tr_std"):
                tr_m, tr_std = get_dynamic_norm(self)
                self.register_buffer('tr_m', tr_m)
                self.register_buffer('tr_std', tr_std)
            x = (x - self.tr_m) / self.tr_std
        return x

    def training_step(self, batch, batch_idx):
        # REQUIRED
        x, f, y = batch
        if self.mel:
            x = self.mel_forward(x)
        # the first 20 columns are the labels, the last 20 the per-label annotation mask
        y_mask = y[:, 20:]
        y = y[:, :20] > 0.5
        y = y.float()
        orig_x = x
        batch_size = len(y)
        rn_indices, lam = None, None
        if self.use_mixup:
            rn_indices, lam = my_mixup(batch_size, self.mixup_alpha)
            lam = lam.to(x.device)
            x = x * lam.reshape(batch_size, 1, 1, 1) + x[rn_indices] * (1. - lam.reshape(batch_size, 1, 1, 1))
        y_hat, embed = self.forward(x)
        if self.use_mixup:
            y_mix = y * lam.reshape(batch_size, 1) + y[rn_indices] * (1. - lam.reshape(batch_size, 1))
            samples_loss = F.binary_cross_entropy_with_logits(y_hat, y_mix, reduction="none")
            y_mix_mask = ((y_mask > 0.5) | (y_mask[rn_indices] > 0.5)).float()
            # note: the original (unmixed) y_mask is applied below; y_mix_mask is left unused
            samples_loss = y_mask.float() * samples_loss
            loss = samples_loss.mean()
            samples_loss = samples_loss.detach()
        else:
            samples_loss = F.binary_cross_entropy_with_logits(y_hat, y, reduction="none")
            samples_loss = y_mask.float() * samples_loss
            loss = samples_loss.mean()
            samples_loss = samples_loss.detach()
        results = {"loss": loss}
        return results
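helpers/mixup.py is not shown in this preview. A torch-free sketch of what `my_mixup` plausibly computes, judging from its use above (a batch permutation plus one Beta(alpha, alpha) coefficient per sample, folded so `lam >= 0.5` and the original sample keeps the larger share), is below; the name `my_mixup_sketch` and the fold are assumptions, and the real helper returns tensors:

```python
import random

def my_mixup_sketch(size, alpha):
    # hypothetical stand-in for helpers.mixup.my_mixup: pair each sample with a
    # random partner and draw a mixing coefficient from Beta(alpha, alpha),
    # folded into [0.5, 1.0]
    rn_indices = list(range(size))
    random.shuffle(rn_indices)
    lam = [max(l, 1.0 - l) for l in (random.betavariate(alpha, alpha) for _ in range(size))]
    return rn_indices, lam
```

With a small alpha such as 0.3 the Beta distribution is U-shaped, so most folded coefficients land near 1.0 and mixing is usually mild.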
    def training_epoch_end(self, outputs):
        avg_loss = torch.stack([x['loss'] for x in outputs]).mean()
        logs = {'train.loss': avg_loss, 'step': self.current_epoch}
        self.log_dict(logs, sync_dist=True)

    def predict(self, batch, batch_idx: int, dataloader_idx: int = None):
        x, f, y = batch
        if self.mel:
            x = self.mel_forward(x)
        y_hat, _ = self.forward(x)
        return f, y_hat

    def validation_step(self, batch, batch_idx):
        x, f, y = batch
        if self.mel:
            x = self.mel_forward(x)
        y_mask = y[:, 20:]
        y = y[:, :20] > 0.5
        y = y.float()
        results = {}
        model_name = [("", self.net)]
        if self.do_swa:
            model_name = model_name + [("swa_", self.net_swa)]
        for net_name, net in model_name:
            y_hat, _ = net(x)
            # elementwise loss so the annotation mask can be applied per label
            samples_loss = F.binary_cross_entropy_with_logits(y_hat, y, reduction="none")
            samples_loss = y_mask.float() * samples_loss
            loss = samples_loss.mean()
            out = torch.sigmoid(y_hat.detach())
            # self.log("validation.loss", loss, prog_bar=True, on_epoch=True, on_step=False)
            results = {**results, net_name + "val_loss": loss, net_name + "out": out,
                       net_name + "target": y.detach(), net_name + "mask": y_mask.detach()}
        results = {k: v.cpu() for k, v in results.items()}
        return results
    def validation_epoch_end(self, outputs):
        model_name = [("", self.net)]
        if self.do_swa:
            model_name = model_name + [("swa_", self.net_swa)]
        for net_name, net in model_name:
            avg_loss = torch.stack([x[net_name + 'val_loss'] for x in outputs]).mean()
            out = torch.cat([x[net_name + 'out'] for x in outputs], dim=0)
            target = torch.cat([x[net_name + 'target'] for x in outputs], dim=0)
            mask = torch.cat([x[net_name + 'mask'] for x in outputs], dim=0)
            y_true = target.float().numpy()
            y_pred = out.float().numpy()
            y_mask = mask.float().numpy()
            try:
                average_precision = np.array([metrics.average_precision_score(
                    y_true[:, i], y_pred[:, i], sample_weight=y_mask[:, i]) for i in range(y_true.shape[1])])
            except ValueError:
                average_precision = np.array([np.nan] * y_true.shape[1])
            # torch.save(average_precision, f"ap_openmic_perclass_{average_precision.mean()}.pt")
            try:
                roc = np.array([metrics.roc_auc_score(
                    y_true[:, i], y_pred[:, i], sample_weight=y_mask[:, i]) for i in range(y_true.shape[1])])
            except ValueError:
                roc = np.array([np.nan] * y_true.shape[1])
            logs = {net_name + 'val.loss': torch.as_tensor(avg_loss).cuda(),
                    net_name + 'ap': torch.as_tensor(average_precision.mean()).cuda(),
                    net_name + 'roc': torch.as_tensor(roc.mean()).cuda(),
                    'step': torch.as_tensor(self.current_epoch).cuda()}
            self.log_dict(logs, sync_dist=True)
            if self.distributed_mode:
                allout = self.all_gather(out)
                alltarget = self.all_gather(target)
                all_mask = self.all_gather(mask)
                y_true = alltarget.reshape(-1, alltarget.shape[-1]).cpu().float().numpy()
                y_pred = allout.reshape(-1, alltarget.shape[-1]).cpu().float().numpy()
                y_mask = all_mask.reshape(-1, alltarget.shape[-1]).cpu().float().numpy()
                average_precision = np.array([metrics.average_precision_score(
                    y_true[:, i], y_pred[:, i], sample_weight=y_mask[:, i]) for i in range(y_true.shape[1])])
                if self.trainer.is_global_zero:
                    logs = {net_name + "allap": torch.as_tensor(average_precision.mean()).cuda(),
                            'step': torch.as_tensor(self.current_epoch).cuda()}
                    self.log_dict(logs, sync_dist=False)
            else:
                self.log_dict({net_name + "allap": logs[net_name + 'ap'], 'step': logs['step']}, sync_dist=True)
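The per-class metric above is `sklearn.metrics.average_precision_score` with `sample_weight` carrying the annotation mask, so unannotated labels contribute nothing. Its effect can be sketched without sklearn; this is a simplified re-implementation under the assumptions of distinct scores (no tie handling) and at least one positively weighted positive:

```python
def weighted_average_precision(y_true, y_score, weights):
    # step-wise AP = sum over descending-score thresholds of (R_n - R_{n-1}) * P_n,
    # with true/false positives accumulated by sample weight
    order = sorted(range(len(y_score)), key=lambda i: -y_score[i])
    total_pos = sum(w for y, w in zip(y_true, weights) if y)
    tp = fp = ap = prev_recall = 0.0
    for i in order:
        if y_true[i]:
            tp += weights[i]
        else:
            fp += weights[i]
        if tp + fp == 0:
            continue  # fully masked samples contribute nothing
        recall = tp / total_pos
        ap += (recall - prev_recall) * (tp / (tp + fp))
        prev_recall = recall
    return ap
```

Setting a sample's weight to 0 removes it entirely: a highly-ranked but masked negative no longer drags precision down, which is exactly why the OpenMIC mask is passed as `sample_weight` per class.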
    def configure_optimizers(self):
        # REQUIRED
        # can return multiple optimizers and learning-rate schedulers
        # (LBFGS is automatically supported, no need for a closure function)
        optimizer = get_optimizer(self.parameters())
        return {
            'optimizer': optimizer,
            'lr_scheduler': get_lr_scheduler(optimizer)
        }

    def configure_callbacks(self):
        return get_extra_checkpoint_callback() + get_extra_swa_callback()
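The per-label masking used in `training_step` and `validation_step` (multiplying the elementwise BCE by `y_mask`, since only some of OpenMIC's 20 instrument labels are annotated per clip) can be illustrated in plain Python; `masked_bce` is an illustrative name, not a repository function:

```python
import math

def masked_bce(logits, targets, mask):
    # elementwise BCE-with-logits, zeroed where mask == 0 (label not annotated),
    # then averaged over all entries -- mirroring
    # (y_mask.float() * samples_loss).mean() above
    losses = []
    for z, y, m in zip(logits, targets, mask):
        p = 1.0 / (1.0 + math.exp(-z))
        losses.append(m * -(y * math.log(p) + (1.0 - y) * math.log(1.0 - p)))
    return sum(losses) / len(losses)
```

Note that, like the code above, the mean divides by the total number of entries rather than by the number of unmasked ones, so the loss scale depends on how many labels are annotated.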
@ex.command
def get_dynamic_norm(model, dyn_norm=False):
    if not dyn_norm:
SYMBOL INDEX (461 symbols across 34 files)
FILE: audioset/dataset.py
function default_config (line 26) | def default_config():
function LMODE_default_config (line 48) | def LMODE_default_config():
function decode_mp3 (line 55) | def decode_mp3(mp3_arr):
function pad_or_truncate (line 73) | def pad_or_truncate(x, audio_length):
function get_ir_sample (line 85) | def get_ir_sample(ir_path, _run, ir_augment, cut_irs_offset=None):
function pydub_augment (line 104) | def pydub_augment(waveform, gain_augment=7, ir_augment=0):
class MixupDataset (line 115) | class MixupDataset(TorchDataset):
method __init__ (line 119) | def __init__(self, dataset, beta=2, rate=0.5):
method __getitem__ (line 125) | def __getitem__(self, index):
method __len__ (line 139) | def __len__(self):
class AudioSetDataset (line 143) | class AudioSetDataset(TorchDataset):
method __init__ (line 144) | def __init__(self, hdf5_file, sample_rate=32000, classes_num=527, clip...
method open_hdf5 (line 164) | def open_hdf5(self):
method __len__ (line 167) | def __len__(self):
method __del__ (line 170) | def __del__(self):
method __getitem__ (line 175) | def __getitem__(self, index):
method resample (line 202) | def resample(self, waveform):
function get_base_training_set (line 220) | def get_base_training_set(balanced_train_hdf5):
function get_unbalanced_training_set (line 226) | def get_unbalanced_training_set(unbalanced_train_hdf5):
function get_norms_dataset (line 233) | def get_norms_dataset(unbalanced_train_hdf5, balanced_train_hdf5):
function get_base_full_training_set (line 240) | def get_base_full_training_set():
function preload_mp3 (line 247) | def preload_mp3(balanced_train_hdf5, unbalanced_train_hdf5, num_of_class...
function get_ft_cls_balanced_sample_weights (line 258) | def get_ft_cls_balanced_sample_weights(balanced_train_hdf5, unbalanced_t...
function get_ft_weighted_sampler (line 294) | def get_ft_weighted_sampler(samples_weights=CMD(".get_ft_cls_balanced_sa...
function get_base_test_set (line 310) | def get_base_test_set(eval_hdf5):
function get_roll_func (line 316) | def get_roll_func(axis=1, shift=None, shift_range=50):
function get_training_set (line 333) | def get_training_set(normalize, roll, wavmix=False):
function get_full_training_set (line 349) | def get_full_training_set(normalize, roll, wavmix=False):
function get_test_set (line 365) | def get_test_set(normalize):
function print_conf (line 375) | def print_conf(_config):
class DistributedSamplerWrapper (line 381) | class DistributedSamplerWrapper(DistributedSampler):
method __init__ (line 382) | def __init__(
method __iter__ (line 392) | def __iter__(self):
function default_command (line 410) | def default_command():
FILE: audioset/prepare_scripts/convert_to_mp3.py
function process_folder (line 15) | def process_folder(fol="balanced_train_segments"):
function process_one (line 28) | def process_one(i, f1, f2):
FILE: audioset/prepare_scripts/create_h5pymp3_dataset.py
function read_metadata (line 18) | def read_metadata(csv_path, classes_num, id_to_ix):
function check_available (line 75) | def check_available(balanced_csv,balanced_audio_path,prefix=None):
FILE: ba3l/experiment.py
function ingredients_recursive_apply (line 19) | def ingredients_recursive_apply(ing, fn):
function config_recursive_apply (line 24) | def config_recursive_apply(conf, fn):
function get_loggers (line 32) | def get_loggers(use_tensorboard_logger=False, use_sacred_logger=False):
class Experiment (line 43) | class Experiment(Sacred_Experiment):
method __init__ (line 48) | def __init__(
method get_run_identifier (line 116) | def get_run_identifier(self):
method get_dataloaders (line 121) | def get_dataloaders(self, filter={}):
method get_train_dataloaders (line 130) | def get_train_dataloaders(self):
method get_val_dataloaders (line 133) | def get_val_dataloaders(self):
method _create_run (line 136) | def _create_run(
method command (line 169) | def command(
method add_default_args_config (line 205) | def add_default_args_config(self, function, prefix, extra_args={}, sta...
FILE: ba3l/ingredients/datasets.py
function raise_ (line 14) | def raise_(ex):
class Dataset (line 18) | class Dataset(Ingredient):
method __init__ (line 29) | def __init__(
method dataset (line 99) | def dataset(
method iter (line 135) | def iter(
method __getattr__ (line 174) | def __getattr__(self, k):
class Datasets (line 183) | class Datasets(Ingredient, Munch):
method get_instance (line 193) | def get_instance(cls):
method __init__ (line 198) | def __init__(
method __getattr__ (line 261) | def __getattr__(self, k):
method __setattr__ (line 268) | def __setattr__(self, k, v):
method __getitem__ (line 280) | def __getitem__(self, k):
method __hash__ (line 294) | def __hash__(self):
method get_datasets (line 297) | def get_datasets(self, config_conditions={}, return_datasets_names=Fal...
FILE: ba3l/ingredients/ingredient.py
function raise_ (line 14) | def raise_(ex):
class Ingredient (line 18) | class Ingredient(sacred_Ingredient):
method __init__ (line 26) | def __init__(
method add_default_args_config (line 84) | def add_default_args_config(self, function, prefix, extra_args={}, sta...
method command (line 102) | def command(
FILE: ba3l/ingredients/models.py
class Models (line 24) | class Models(Ingredient):
method get_instance (line 34) | def get_instance(cls):
method __init__ (line 39) | def __init__(
function raise_ (line 98) | def raise_(ex):
class Model (line 102) | class Model(Ingredient):
method __init__ (line 112) | def __init__(
method instance (line 176) | def instance(
method __getattr__ (line 210) | def __getattr__(self, k):
FILE: ba3l/ingredients/trainer.py
class Trainer (line 14) | class Trainer(Ingredient):
FILE: ba3l/module.py
class Ba3lModule (line 33) | class Ba3lModule(pl.LightningModule):
method __init__ (line 34) | def __init__(self, experiment):
FILE: ba3l/plutils/lr_monitor.py
class LearningRateMonitor (line 30) | class LearningRateMonitor(Callback):
method __init__ (line 70) | def __init__(self, logging_interval: Optional[str] = None, log_momentu...
method on_train_start (line 79) | def on_train_start(self, trainer, *args, **kwargs):
method on_train_batch_start (line 119) | def on_train_batch_start(self, trainer, *args, **kwargs):
method on_train_epoch_start (line 130) | def on_train_epoch_start(self, trainer, *args, **kwargs):
method _extract_stats (line 138) | def _extract_stats(self, trainer, interval: str) -> Dict[str, float]:
method _extract_lr (line 158) | def _extract_lr(self, param_group, name: str) -> Dict[str, float]:
method _extract_momentum (line 163) | def _extract_momentum(self, param_group, name: str, use_betas: bool) -...
method _find_names (line 171) | def _find_names(self, lr_schedulers) -> List[str]:
method _should_log (line 204) | def _should_log(trainer) -> bool:
FILE: ba3l/plutils/progress_bar.py
class ProgressBar (line 17) | class ProgressBar(PlProgressBar):
method init_validation_tqdm (line 66) | def init_validation_tqdm(self) -> tqdm:
method on_epoch_start (line 83) | def on_epoch_start(self, trainer, pl_module):
method on_train_batch_end (line 97) | def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch...
method on_validation_start (line 103) | def on_validation_start(self, trainer, pl_module):
method on_validation_batch_end (line 113) | def on_validation_batch_end(self, trainer, pl_module, outputs, batch, ...
method on_validation_end (line 119) | def on_validation_end(self, trainer, pl_module):
function convert_inf (line 127) | def convert_inf(x: Optional[Union[int, float]]) -> Optional[Union[int, f...
function reset (line 134) | def reset(bar: tqdm, total: Optional[int] = None) -> None:
FILE: ba3l/util/functions.py
function get_default_kwargs_dict (line 5) | def get_default_kwargs_dict(f):
FILE: ba3l/util/sacred_logger.py
class SacredLogger (line 23) | class SacredLogger(LightningLoggerBase):
method __init__ (line 24) | def __init__(self, sacred_experiment):
method experiment (line 37) | def experiment(self):
method run_id (line 41) | def run_id(self):
method log_hyperparams (line 48) | def log_hyperparams(self, params):
method log_metrics (line 53) | def log_metrics(self, metrics, step=None):
method name (line 61) | def name(self):
method version (line 65) | def version(self):
method save (line 69) | def save(self):
method finalize (line 76) | def finalize(self, status):
FILE: config_updates.py
function add_configs (line 4) | def add_configs(ex):
FILE: esc50/dataset.py
function default_config (line 29) | def default_config():
function decode_mp3 (line 47) | def decode_mp3(mp3_arr):
function pad_or_truncate (line 65) | def pad_or_truncate(x, audio_length):
function get_ir_sample (line 77) | def get_ir_sample(ir_path, _run, ir_augment, cut_irs_offset=None):
function pydub_augment (line 96) | def pydub_augment(waveform, gain_augment=7, ir_augment=0):
class MixupDataset (line 107) | class MixupDataset(TorchDataset):
method __init__ (line 111) | def __init__(self, dataset, beta=2, rate=0.5):
method __getitem__ (line 117) | def __getitem__(self, index):
method __len__ (line 131) | def __len__(self):
class AudioSetDataset (line 137) | class AudioSetDataset(TorchDataset):
method __init__ (line 138) | def __init__(self, meta_csv, audiopath, fold, train=False, sample_rat...
method __len__ (line 163) | def __len__(self):
method __getitem__ (line 166) | def __getitem__(self, index):
method resample (line 190) | def resample(self, waveform):
function get_base_training_set (line 209) | def get_base_training_set(meta_csv, audio_path, fold=1):
function get_ft_weighted_sampler (line 215) | def get_ft_weighted_sampler(samples_weights=CMD(".get_ft_cls_balanced_sa...
function get_base_test_set (line 231) | def get_base_test_set(meta_csv, audio_path, fold=1):
function get_roll_func (line 238) | def get_roll_func(axis=1, shift=None, shift_range=50):
function get_training_set (line 255) | def get_training_set(normalize, roll, wavmix=False):
function get_test_set (line 271) | def get_test_set(normalize):
function print_conf (line 281) | def print_conf(_config):
class DistributedSamplerWrapper (line 287) | class DistributedSamplerWrapper(DistributedSampler):
method __init__ (line 288) | def __init__(
method __iter__ (line 298) | def __iter__(self):
function default_command (line 316) | def default_command():
FILE: ex_audioset.py
function default_conf (line 52) | def default_conf():
function get_scheduler_lambda (line 87) | def get_scheduler_lambda(warm_up_len=5, ramp_down_start=50, ramp_down_le...
function get_lr_scheduler (line 98) | def get_lr_scheduler(optimizer, schedule_mode):
function get_optimizer (line 105) | def get_optimizer(params, lr, adamw=True, weight_decay=0.0001):
class M (line 112) | class M(Ba3lModule):
method __init__ (line 113) | def __init__(self, experiment):
method forward (line 139) | def forward(self, x):
method mel_forward (line 142) | def mel_forward(self, x):
method training_step (line 155) | def training_step(self, batch, batch_idx):
method training_epoch_end (line 200) | def training_epoch_end(self, outputs):
method predict (line 208) | def predict(self, batch, batch_idx: int, dataloader_idx: int = None):
method validation_step (line 216) | def validation_step(self, batch, batch_idx):
method validation_epoch_end (line 245) | def validation_epoch_end(self, outputs):
method configure_optimizers (line 293) | def configure_optimizers(self):
method configure_callbacks (line 304) | def configure_callbacks(self):
function get_dynamic_norm (line 309) | def get_dynamic_norm(model, dyn_norm=False):
function get_extra_checkpoint_callback (line 316) | def get_extra_checkpoint_callback(save_last_n=None):
function get_extra_swa_callback (line 323) | def get_extra_swa_callback(swa=True, swa_epoch_start=50,
function main (line 336) | def main(_run, _config, _log, _rnd, _seed, watch_model=True):
function model_speed_test (line 365) | def model_speed_test(_run, _config, _log, _rnd, _seed, speed_test_batch_...
function evaluate_only (line 430) | def evaluate_only(_run, _config, _log, _rnd, _seed):
function test_loaders (line 445) | def test_loaders():
function set_default_json_pickle (line 459) | def set_default_json_pickle(obj):
function preload_mp3 (line 466) | def preload_mp3(all_y=CMD("/basedataset.preload_mp3")):
function multiprocessing_run (line 475) | def multiprocessing_run(rank, word_size):
function default_command (line 530) | def default_command():
FILE: ex_esc50.py
function default_conf (line 50) | def default_conf():
function get_scheduler_lambda (line 82) | def get_scheduler_lambda(warm_up_len=5, ramp_down_start=50, ramp_down_le...
function get_lr_scheduler (line 92) | def get_lr_scheduler(optimizer, schedule_mode):
function get_optimizer (line 99) | def get_optimizer(params, lr, adamw=True, weight_decay=0.0001):
class M (line 106) | class M(Ba3lModule):
method __init__ (line 107) | def __init__(self, experiment):
method forward (line 126) | def forward(self, x):
method mel_forward (line 129) | def mel_forward(self, x):
method training_step (line 142) | def training_step(self, batch, batch_idx):
method training_epoch_end (line 175) | def training_epoch_end(self, outputs):
method predict (line 182) | def predict(self, batch, batch_idx: int, dataloader_idx: int = None):
method validation_step (line 190) | def validation_step(self, batch, batch_idx):
method validation_epoch_end (line 212) | def validation_epoch_end(self, outputs):
method configure_optimizers (line 224) | def configure_optimizers(self):
method configure_callbacks (line 235) | def configure_callbacks(self):
function get_dynamic_norm (line 240) | def get_dynamic_norm(model, dyn_norm=False):
function get_extra_checkpoint_callback (line 247) | def get_extra_checkpoint_callback(save_last_n=None):
function get_extra_swa_callback (line 254) | def get_extra_swa_callback(swa=True, swa_epoch_start=2,
function main (line 264) | def main(_run, _config, _log, _rnd, _seed):
function model_speed_test (line 281) | def model_speed_test(_run, _config, _log, _rnd, _seed, speed_test_batch_...
function evaluate_only (line 342) | def evaluate_only(_run, _config, _log, _rnd, _seed):
function test_loaders (line 358) | def test_loaders():
function set_default_json_pickle (line 372) | def set_default_json_pickle(obj):
function multiprocessing_run (line 379) | def multiprocessing_run(rank, word_size):
function default_command (line 431) | def default_command():
FILE: ex_fsd50k.py
function variable_eval (line 54) | def variable_eval():
function default_conf (line 60) | def default_conf():
function get_scheduler_lambda (line 94) | def get_scheduler_lambda(warm_up_len=5, ramp_down_start=10, ramp_down_le...
function get_lr_scheduler (line 104) | def get_lr_scheduler(optimizer, schedule_mode):
function get_optimizer (line 111) | def get_optimizer(params, lr, adamw=True, weight_decay=0.0001):
class M (line 118) | class M(Ba3lModule):
method __init__ (line 119) | def __init__(self, experiment):
method forward (line 138) | def forward(self, x):
method mel_forward (line 141) | def mel_forward(self, x):
method training_step (line 154) | def training_step(self, batch, batch_idx):
method training_epoch_end (line 186) | def training_epoch_end(self, outputs):
method predict (line 193) | def predict(self, batch, batch_idx: int, dataloader_idx: int = None):
method validation_step (line 201) | def validation_step(self, batch, batch_idx, dataloader_idx):
method validation_epoch_end (line 220) | def validation_epoch_end(self, outputs):
method configure_optimizers (line 262) | def configure_optimizers(self):
method configure_callbacks (line 273) | def configure_callbacks(self):
function get_dynamic_norm (line 278) | def get_dynamic_norm(model, dyn_norm=False):
function get_extra_checkpoint_callback (line 288) | def get_extra_checkpoint_callback(save_best=None):
function get_extra_swa_callback (line 298) | def get_extra_swa_callback(swa=True, swa_epoch_start=10,
function main (line 308) | def main(_run, _config, _log, _rnd, _seed):
function model_speed_test (line 331) | def model_speed_test(_run, _config, _log, _rnd, _seed, speed_test_batch_...
function evaluate_only (line 392) | def evaluate_only(_run, _config, _log, _rnd, _seed):
function test_loaders (line 407) | def test_loaders():
function set_default_json_pickle (line 421) | def set_default_json_pickle(obj):
function preload_mp3 (line 428) | def preload_mp3(all_y=CMD("/basedataset.preload_mp3")):
function multiprocessing_run (line 437) | def multiprocessing_run(rank, word_size):
function default_command (line 489) | def default_command():
FILE: ex_openmic.py
function default_conf (line 54) | def default_conf():
function get_scheduler_lambda (line 88) | def get_scheduler_lambda(warm_up_len=5, ramp_down_start=50, ramp_down_le...
function get_lr_scheduler (line 98) | def get_lr_scheduler(optimizer, schedule_mode):
function get_optimizer (line 105) | def get_optimizer(params, lr, adamw=True, weight_decay=0.0001):
class M (line 112) | class M(Ba3lModule):
method __init__ (line 113) | def __init__(self, experiment):
method forward (line 135) | def forward(self, x):
method mel_forward (line 138) | def mel_forward(self, x):
method training_step (line 151) | def training_step(self, batch, batch_idx):
method training_epoch_end (line 190) | def training_epoch_end(self, outputs):
method predict (line 197) | def predict(self, batch, batch_idx: int, dataloader_idx: int = None):
method validation_step (line 205) | def validation_step(self, batch, batch_idx):
method validation_epoch_end (line 228) | def validation_epoch_end(self, outputs):
method configure_optimizers (line 272) | def configure_optimizers(self):
method configure_callbacks (line 283) | def configure_callbacks(self):
function get_dynamic_norm (line 288) | def get_dynamic_norm(model, dyn_norm=False):
function get_extra_checkpoint_callback (line 295) | def get_extra_checkpoint_callback(save_last_n=None):
function get_extra_swa_callback (line 302) | def get_extra_swa_callback(swa=True, swa_epoch_start=2,
function main (line 312) | def main(_run, _config, _log, _rnd, _seed):
function model_speed_test (line 330) | def model_speed_test(_run, _config, _log, _rnd, _seed, speed_test_batch_...
function evaluate_only (line 391) | def evaluate_only(_run, _config, _log, _rnd, _seed):
function test_loaders (line 406) | def test_loaders():
function set_default_json_pickle (line 420) | def set_default_json_pickle(obj):
function preload_mp3 (line 427) | def preload_mp3(all_y=CMD("/basedataset.preload_mp3")):
function multiprocessing_run (line 436) | def multiprocessing_run(rank, word_size):
function default_command (line 488) | def default_command():
FILE: fsd50k/dataset.py
function default_config (line 25) | def default_config():
function LMODE_default_config (line 48) | def LMODE_default_config():
function decode_mp3 (line 52) | def decode_mp3(mp3_arr):
function pad_or_truncate (line 70) | def pad_or_truncate(x, audio_length):
function get_ir_sample (line 86) | def get_ir_sample(ir_path, _run, ir_augment, cut_irs_offset=None):
function pydub_augment (line 105) | def pydub_augment(waveform, gain_augment=7, ir_augment=0):
class MixupDataset (line 116) | class MixupDataset(TorchDataset):
method __init__ (line 120) | def __init__(self, dataset, beta=2, rate=0.5):
method __getitem__ (line 126) | def __getitem__(self, index):
method __len__ (line 140) | def __len__(self):
class AudioSetDataset (line 144) | class AudioSetDataset(TorchDataset):
method __init__ (line 145) | def __init__(self, hdf5_file, sample_rate=32000, classes_num=200, clip...
method open_hdf5 (line 167) | def open_hdf5(self):
method __len__ (line 170) | def __len__(self):
method __del__ (line 173) | def __del__(self):
method __getitem__ (line 178) | def __getitem__(self, index):
method resample (line 205) | def resample(self, waveform):
function get_base_training_set (line 223) | def get_base_training_set(balanced_train_hdf5, clip_length=10):
function preload_mp3 (line 229) | def preload_mp3(balanced_train_hdf5, unbalanced_train_hdf5, num_of_class...
function get_ft_weighted_sampler (line 242) | def get_ft_weighted_sampler(samples_weights=CMD(".get_ft_cls_balanced_sa...
function get_base_eval_set (line 258) | def get_base_eval_set(eval_hdf5, variable_eval=None):
function get_base_valid_set (line 268) | def get_base_valid_set(valid_hdf5, variable_eval=None):
function get_roll_func (line 278) | def get_roll_func(axis=1, shift=None, shift_range=50):
function get_training_set (line 295) | def get_training_set(normalize, roll, wavmix=False):
function get_valid_set (line 311) | def get_valid_set(normalize):
function get_eval_set (line 321) | def get_eval_set(normalize):
function print_conf (line 331) | def print_conf(_config):
class DistributedSamplerWrapper (line 337) | class DistributedSamplerWrapper(DistributedSampler):
method __init__ (line 338) | def __init__(
method __iter__ (line 348) | def __iter__(self):
function default_command (line 366) | def default_command():
FILE: fsd50k/prepare_scripts/convert_to_mp3.py
function process_folder (line 19) | def process_folder(fol="balanced_train_segments"):
function process_one (line 32) | def process_one(i, f1, f2):
FILE: fsd50k/prepare_scripts/create_h5pymp3_dataset.py
function get_labels (line 49) | def get_labels(df):
FILE: helpers/audiodatasets.py
function h6 (line 12) | def h6(w):
class AudioPreprocessDataset (line 15) | class AudioPreprocessDataset(Dataset):
method __init__ (line 23) | def __init__(self, files, labels, label_encoder, base_dir, preprocesso...
method __getitem__ (line 36) | def __getitem__(self, index):
method get_ordered_ids (line 42) | def get_ordered_ids(self):
method get_ordered_labels (line 45) | def get_ordered_labels(self):
method __len__ (line 48) | def __len__(self):
class ObjectCacher (line 51) | class ObjectCacher:
method __init__ (line 52) | def __init__(self, get_obj_func, dataset_name, obj_name="",
method get_obj_cache_path (line 91) | def get_obj_cache_path(self):
method get (line 94) | def get(self):
class PreprocessDataset (line 99) | class PreprocessDataset(Dataset):
method __init__ (line 106) | def __init__(self, dataset, preprocessor):
method __getitem__ (line 112) | def __getitem__(self, index):
method __len__ (line 114) | def __len__(self):
class FilesCachedDataset (line 119) | class FilesCachedDataset(Dataset):
method __init__ (line 120) | def __init__(self, get_dataset_func, dataset_name, x_name="",
method __getitem__ (line 148) | def __getitem__(self, index):
method get_ordered_labels (line 157) | def get_ordered_labels(self):
method get_ordered_ids (line 160) | def get_ordered_ids(self):
method get_xcache_path (line 163) | def get_xcache_path(self):
method get_ycache_path (line 166) | def get_ycache_path(self):
method get_sidcache_path (line 169) | def get_sidcache_path(self):
method __len__ (line 172) | def __len__(self):
class SelectionDataset (line 177) | class SelectionDataset(Dataset):
method __init__ (line 184) | def __init__(self, dataset, sample_ids):
method reselect (line 190) | def reselect(self, sample_ids):
method get_ordered_ids (line 194) | def get_ordered_ids(self):
method get_ordered_labels (line 197) | def get_ordered_labels(self):
method __getitem__ (line 202) | def __getitem__(self, index):
method __len__ (line 205) | def __len__(self):
class SimpleSelectionDataset (line 208) | class SimpleSelectionDataset(Dataset):
method __init__ (line 215) | def __init__(self, dataset, available_indexes ):
method __getitem__ (line 219) | def __getitem__(self, index):
method __len__ (line 222) | def __len__(self):
FILE: helpers/mixup.py
function my_mixup (line 5) | def my_mixup(size, alpha):
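The file's preview shows `my_mixup(size, alpha)` drawing a random permutation (`torch.randperm`) and Beta-distributed mixing weights (`np.random.beta`). A stdlib-only sketch of that pattern, with the max-of-(λ, 1−λ) trick assumed so the primary sample always dominates the mix:

```python
import random

def my_mixup_sketch(size, alpha):
    """Hedged sketch of helpers/mixup.py:my_mixup using stdlib in place of
    torch/numpy. Returns a random pairing of batch indices and per-sample
    mixing coefficients drawn from Beta(alpha, alpha)."""
    rn_indices = list(range(size))
    random.shuffle(rn_indices)  # mixing partner for each sample
    lambd = [random.betavariate(alpha, alpha) for _ in range(size)]
    # assumption: keep the larger coefficient so lambd[i] >= 0.5
    lambd = [max(l, 1.0 - l) for l in lambd]
    return rn_indices, lambd

# usage: x_mixed[i] = lambd[i] * x[i] + (1 - lambd[i]) * x[rn_indices[i]]
```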
FILE: helpers/models_size.py
function count_non_zero_params (line 7) | def count_non_zero_params(model):
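`count_non_zero_params` reports model sparsity (useful after pruning). A minimal sketch, assuming a dict of flat weight lists stands in for a torch model:

```python
def count_non_zero_params_sketch(params):
    """Hedged sketch of helpers/models_size.py:count_non_zero_params:
    returns (total parameter count, count of non-zero parameters)."""
    total = sum(len(w) for w in params.values())
    non_zero = sum(1 for w in params.values() for v in w if v != 0)
    return total, non_zero
```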
FILE: helpers/ramp.py
function pseudo_rampup (line 8) | def pseudo_rampup(T1, T2):
function exp_rampup (line 21) | def exp_rampup(rampup_length):
function linear_rampup (line 33) | def linear_rampup(rampup_length):
function linear_rampdown (line 45) | def linear_rampdown(rampdown_length, start=0, last_value=0):
function exp_rampdown (line 57) | def exp_rampdown(rampdown_length, num_epochs):
function cosine_rampdown (line 70) | def cosine_rampdown(rampdown_length, num_epochs):
function exp_warmup (line 83) | def exp_warmup(rampup_length, rampdown_length, num_epochs):
function exp_warmup_linear_down (line 93) | def exp_warmup_linear_down(warmup, rampdown_length, start_rampdown, last...
function test_warmup (line 101) | def test_warmup():
function test_warmupl (line 107) | def test_warmupl():
function cosine_cycle (line 113) | def cosine_cycle(cycle_len=20,ramp_down_start=100,last_lr_value=0.01):
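The ramp functions above compose learning-rate schedules; `exp_warmup_linear_down` is the one the PaSST experiments reference most. A sketch of the plausible shape, under the assumption that it follows the standard mean-teacher-style ramps the file credits: exponential warm-up, flat plateau, linear decay to a floor.

```python
import math

def exp_warmup_linear_down_sketch(warmup, rampdown_length, start_rampdown, last_value):
    """Hedged sketch of helpers/ramp.py:exp_warmup_linear_down.
    Returns epoch -> LR multiplier."""
    def wrapper(epoch):
        if epoch < warmup:
            phase = 1.0 - epoch / warmup           # exponential ramp-up
            return math.exp(-5.0 * phase * phase)
        if epoch < start_rampdown:
            return 1.0                             # plateau at full LR
        if epoch < start_rampdown + rampdown_length:
            frac = (epoch - start_rampdown) / rampdown_length
            return 1.0 - (1.0 - last_value) * frac  # linear decay
        return last_value
    return wrapper
```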
FILE: helpers/swa_callback.py
class StochasticWeightAveraging (line 35) | class StochasticWeightAveraging(Callback):
method __init__ (line 37) | def __init__(
method swa_start (line 127) | def swa_start(self) -> int:
method swa_end (line 131) | def swa_end(self) -> int:
method pl_module_contains_batch_norm (line 135) | def pl_module_contains_batch_norm(pl_module: 'pl.LightningModule'):
method setup (line 138) | def setup(self, trainer: 'pl.Trainer', pl_module: 'pl.LightningModule'...
method on_fit_start (line 142) | def on_fit_start(self, trainer: 'pl.Trainer', pl_module: 'pl.Lightning...
method on_train_epoch_start (line 161) | def on_train_epoch_start(self, trainer: 'pl.Trainer', pl_module: 'pl.L...
method on_train_epoch_end (line 200) | def on_train_epoch_end(self, trainer: 'pl.Trainer',pl_module: 'pl.Ligh...
method on_train_end (line 204) | def on_train_end(self, trainer: 'pl.Trainer', pl_module: 'pl.Lightning...
method on_validation_epoch_start (line 208) | def on_validation_epoch_start(self, trainer, pl_module) -> None:
method transfer_weights (line 217) | def transfer_weights(src_pl_module: 'pl.LightningModule', dst_pl_modul...
method reset_batch_norm_and_save_state (line 221) | def reset_batch_norm_and_save_state(self, pl_module):
method reset_momenta (line 239) | def reset_momenta(self):
method update_parameters (line 247) | def update_parameters(
method avg_fn (line 262) | def avg_fn(
FILE: helpers/swa_legacy.py
class StochasticWeightAveraging (line 37) | class StochasticWeightAveraging(Callback):
method __init__ (line 39) | def __init__(
method swa_start (line 132) | def swa_start(self) -> int:
method swa_end (line 136) | def swa_end(self) -> int:
method pl_module_contains_batch_norm (line 140) | def pl_module_contains_batch_norm(pl_module: 'pl.LightningModule'):
method on_before_accelerator_backend_setup (line 143) | def on_before_accelerator_backend_setup(self, trainer: 'pl.Trainer', p...
method on_fit_start (line 147) | def on_fit_start(self, trainer: 'pl.Trainer', pl_module: 'pl.Lightning...
method on_train_epoch_start (line 172) | def on_train_epoch_start(self, trainer: 'pl.Trainer', pl_module: 'pl.L...
method on_train_epoch_end (line 211) | def on_train_epoch_end(self, trainer: 'pl.Trainer', pl_module: 'pl.Lig...
method on_train_end (line 214) | def on_train_end(self, trainer: 'pl.Trainer', pl_module: 'pl.Lightning...
method on_validation_epoch_start (line 217) | def on_validation_epoch_start(self, trainer, pl_module) -> None:
method transfer_weights (line 225) | def transfer_weights(src_pl_module: 'pl.LightningModule', dst_pl_modul...
method reset_batch_norm_and_save_state (line 229) | def reset_batch_norm_and_save_state(self, pl_module):
method reset_momenta (line 247) | def reset_momenta(self):
method update_parameters (line 255) | def update_parameters(
method avg_fn (line 271) | def avg_fn(
FILE: helpers/workersinit.py
function worker_init_fn (line 6) | def worker_init_fn(x):
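The preview of this file shows `worker_init_fn` deriving a seed as `(torch.initial_seed() + x * 1000) % 2…` (the modulus is truncated in the extraction). A stdlib sketch of the pattern, with `2**31` assumed for the modulus and a `base_seed` argument standing in for `torch.initial_seed()`:

```python
import random

def worker_init_fn_sketch(worker_id, base_seed=0):
    """Hedged sketch of helpers/workersinit.py:worker_init_fn: gives each
    DataLoader worker a distinct seed so augmentations differ per worker."""
    seed = (base_seed + worker_id * 1000) % 2**31  # assumed modulus
    random.seed(seed)  # the original seeds numpy/random the same way
    return seed
```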
FILE: models/helpers/vit_helpers.py
function adapt_input_conv (line 27) | def adapt_input_conv(in_chans, conv_weight):
function load_pretrained (line 54) | def load_pretrained(
function overlay_external_default_cfg (line 144) | def overlay_external_default_cfg(default_cfg, kwargs):
function filter_kwargs (line 153) | def filter_kwargs(kwargs, names):
function set_default_kwargs (line 160) | def set_default_kwargs(kwargs, names, default_cfg):
function update_default_cfg_and_kwargs (line 180) | def update_default_cfg_and_kwargs(default_cfg, kwargs, kwargs_filter):
function drop_path (line 203) | def drop_path(x, drop_prob: float = 0.0, training: bool = False):
class DropPath (line 225) | class DropPath(nn.Module):
method __init__ (line 228) | def __init__(self, drop_prob=None):
method forward (line 232) | def forward(self, x):
function _no_grad_trunc_normal_ (line 239) | def _no_grad_trunc_normal_(tensor, mean, std, a, b):
function trunc_normal_ (line 277) | def trunc_normal_(tensor, mean=0.0, std=1.0, a=-2.0, b=2.0):
function variance_scaling_ (line 297) | def variance_scaling_(tensor, scale=1.0, mode="fan_in", distribution="no...
function lecun_normal_ (line 320) | def lecun_normal_(tensor):
function build_model_with_cfg (line 324) | def build_model_with_cfg(
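`trunc_normal_` above (line 277) initialises weights from a normal distribution clipped to `[a, b]`. The timm-derived code uses an inverse-CDF method; a simpler rejection-sampling sketch shows the same effect:

```python
import random

def trunc_normal_sketch(n, mean=0.0, std=1.0, a=-2.0, b=2.0):
    """Hedged sketch of the effect of vit_helpers.py:trunc_normal_,
    via rejection sampling rather than the inverse-CDF method."""
    out = []
    while len(out) < n:
        v = random.gauss(mean, std)
        if a <= v <= b:          # resample anything outside the bounds
            out.append(v)
    return out
```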
FILE: models/passt.py
function _ntuple (line 26) | def _ntuple(n):
function _cfg (line 42) | def _cfg(url='', **kwargs):
function adapt_input_conv (line 246) | def adapt_input_conv(in_chans, conv_weight):
class Mlp (line 271) | class Mlp(nn.Module):
method __init__ (line 275) | def __init__(self, in_features, hidden_features=None, out_features=Non...
method forward (line 284) | def forward(self, x):
class PatchEmbed (line 298) | class PatchEmbed(nn.Module):
method __init__ (line 302) | def __init__(self, img_size=224, patch_size=16, stride=16, in_chans=3,...
method forward (line 318) | def forward(self, x):
class Attention (line 331) | class Attention(nn.Module):
method __init__ (line 332) | def __init__(self, dim, num_heads=8, qkv_bias=False, attn_drop=0., pro...
method forward (line 343) | def forward(self, x):
class Block (line 364) | class Block(nn.Module):
method __init__ (line 366) | def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, drop=...
method forward (line 377) | def forward(self, x):
class PaSST (line 383) | class PaSST(nn.Module):
method __init__ (line 391) | def __init__(self, u_patchout=0, s_patchout_t=0, s_patchout_f=0, img_s...
method init_weights (line 471) | def init_weights(self, mode=''):
method _init_weights (line 486) | def _init_weights(self, m):
method no_weight_decay (line 491) | def no_weight_decay(self):
method get_classifier (line 494) | def get_classifier(self):
method reset_classifier (line 500) | def reset_classifier(self, num_classes, global_pool=''):
method forward_features (line 506) | def forward_features(self, x):
method forward (line 576) | def forward(self, x):
function _init_vit_weights (line 598) | def _init_vit_weights(module: nn.Module, name: str = '', head_bias: floa...
function resize_pos_embed (line 633) | def resize_pos_embed(posemb, posemb_new, num_tokens=1, gs_new=(), mode='...
function adapt_image_pos_embed_to_passt (line 656) | def adapt_image_pos_embed_to_passt(posemb, num_tokens=1, gs_new=(), mode...
function checkpoint_filter_fn (line 679) | def checkpoint_filter_fn(state_dict, model):
function _create_vision_transformer (line 709) | def _create_vision_transformer(variant, pretrained=False, default_cfg=No...
function vit_huge_patch14_224_in21k (line 734) | def vit_huge_patch14_224_in21k(pretrained=False, **kwargs):
function deit_base_distilled_patch16_384 (line 745) | def deit_base_distilled_patch16_384(pretrained=False, **kwargs):
function passt_s_swa_p16_128_ap476 (line 756) | def passt_s_swa_p16_128_ap476(pretrained=False, **kwargs):
function passt_s_kd_p16_128_ap486 (line 769) | def passt_s_kd_p16_128_ap486(pretrained=False, **kwargs):
function passt_l_kd_p16_128_ap47 (line 782) | def passt_l_kd_p16_128_ap47(pretrained=False, **kwargs):
function passt_s_swa_p16_128_ap4761 (line 795) | def passt_s_swa_p16_128_ap4761(pretrained=False, **kwargs):
function passt_s_p16_128_ap472 (line 808) | def passt_s_p16_128_ap472(pretrained=False, **kwargs):
function passt_s_p16_s12_128_ap470 (line 821) | def passt_s_p16_s12_128_ap470(pretrained=False, **kwargs):
function passt_s_f128_20sec_p16_s10_ap474_swa (line 834) | def passt_s_f128_20sec_p16_s10_ap474_swa(pretrained=False, **kwargs):
function passt_s_f128_30sec_p16_s10_ap473_swa (line 842) | def passt_s_f128_30sec_p16_s10_ap473_swa(pretrained=False, **kwargs):
function passt_s_swa_p16_s12_128_ap473 (line 850) | def passt_s_swa_p16_s12_128_ap473(pretrained=False, **kwargs):
function passt_s_p16_s14_128_ap469 (line 863) | def passt_s_p16_s14_128_ap469(pretrained=False, **kwargs):
function passt_s_swa_p16_s14_128_ap471 (line 876) | def passt_s_swa_p16_s14_128_ap471(pretrained=False, **kwargs):
function passt_s_swa_p16_s16_128_ap473 (line 889) | def passt_s_swa_p16_s16_128_ap473(pretrained=False, **kwargs):
function passt_s_p16_s16_128_ap468 (line 902) | def passt_s_p16_s16_128_ap468(pretrained=False, **kwargs):
function fix_embedding_layer (line 923) | def fix_embedding_layer(model, embed="default"):
function lighten_model (line 933) | def lighten_model(model, cut_depth=0):
function get_model (line 958) | def get_model(arch="passt_s_kd_p16_128_ap486", pretrained=True, n_classe...
class EnsembelerModel (line 1021) | class EnsembelerModel(nn.Module):
method __init__ (line 1022) | def __init__(self, models):
method forward (line 1026) | def forward(self, x):
function get_ensemble_model (line 1040) | def get_ensemble_model(arch_list=[]):
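`EnsembelerModel` (the name is as it appears in the source) wraps several PaSST variants and combines their predictions. A sketch of the forward pass, assuming plain callables returning equal-length logit lists and a simple mean of logits (not of probabilities):

```python
def ensemble_forward_sketch(models, x):
    """Hedged sketch of models/passt.py:EnsembelerModel.forward:
    average the member models' outputs elementwise."""
    outputs = [m(x) for m in models]
    n = len(outputs)
    return [sum(vals) / n for vals in zip(*outputs)]
```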
FILE: models/preprocess.py
class AugmentMelSTFT (line 19) | class AugmentMelSTFT(nn.Module):
method __init__ (line 20) | def __init__(self, n_mels=128, sr=32000, win_length=800, hopsize=320, ...
method forward (line 57) | def forward(self, x):
method extra_repr (line 88) | def extra_repr(self):
FILE: openmic/dataset.py
function default_config (line 29) | def default_config():
function decode_mp3 (line 47) | def decode_mp3(mp3_arr):
function pad_or_truncate (line 65) | def pad_or_truncate(x, audio_length):
function get_ir_sample (line 77) | def get_ir_sample(ir_path, _run, ir_augment, cut_irs_offset=None):
function pydub_augment (line 96) | def pydub_augment(waveform, gain_augment=7, ir_augment=0):
class MixupDataset (line 107) | class MixupDataset(TorchDataset):
method __init__ (line 111) | def __init__(self, dataset, beta=2, rate=0.5):
method __getitem__ (line 117) | def __getitem__(self, index):
method __len__ (line 140) | def __len__(self):
class AudioSetDataset (line 145) | class AudioSetDataset(TorchDataset):
method __init__ (line 146) | def __init__(self, hdf5_file, sample_rate=32000, classes_num=527, clip...
method open_hdf5 (line 166) | def open_hdf5(self):
method __len__ (line 169) | def __len__(self):
method __del__ (line 172) | def __del__(self):
method __getitem__ (line 177) | def __getitem__(self, index):
method resample (line 203) | def resample(self, waveform):
function get_base_training_set (line 221) | def get_base_training_set(openmic_train_hdf5):
function get_ft_weighted_sampler (line 227) | def get_ft_weighted_sampler(samples_weights=CMD(".get_ft_cls_balanced_sa...
function get_base_test_set (line 243) | def get_base_test_set(openmic_test_hdf5):
function get_roll_func (line 249) | def get_roll_func(axis=1, shift=None, shift_range=50):
function get_training_set (line 266) | def get_training_set(normalize, roll, wavmix=False):
function get_test_set (line 282) | def get_test_set(normalize):
function print_conf (line 292) | def print_conf(_config):
class DistributedSamplerWrapper (line 298) | class DistributedSamplerWrapper(DistributedSampler):
method __init__ (line 299) | def __init__(
method __iter__ (line 309) | def __iter__(self):
function default_command (line 327) | def default_command():
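The `pad_or_truncate` entry above (openmic/dataset.py, line 65) normalises clip lengths before batching. A minimal sketch, assuming a Python list stands in for the waveform array:

```python
def pad_or_truncate_sketch(x, audio_length):
    """Hedged sketch of the dataset's pad_or_truncate: zero-pad short
    clips and cut long ones to exactly `audio_length` samples."""
    if len(x) <= audio_length:
        return x + [0] * (audio_length - len(x))
    return x[:audio_length]
```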
FILE: openmic/prepare_scripts/download_preprocess.py
function download (line 21) | def download(force=False):
function untar (line 29) | def untar():
function process_folder (line 36) | def process_folder(fol="balanced_train_segments"):
function process_one (line 50) | def process_one(i, f1, f2):
function make_mp3 (line 56) | def make_mp3():
function read_metadata (line 60) | def read_metadata(csv_path, classes_num, id_to_ix, openmicf):
function get_files_labels (line 95) | def get_files_labels(balanced_csv, balanced_audio_path, d_files, openmic...
function pack (line 118) | def pack():
function preprocess (line 164) | def preprocess():
Condensed preview — 52 files, each showing path, character count, and a content snippet.
[
{
"path": ".github/workflows/pylint.yml",
"chars": 553,
"preview": "name: Pylint\n\non: [push]\n\njobs:\n build:\n runs-on: ubuntu-latest\n strategy:\n matrix:\n python-version: "
},
{
"path": ".gitignore",
"chars": 1540,
"preview": "\n\n# project\n/output/\n\n# Byte-compiled / optimized / DLL files\n__pycache__/\n*.py[cod]\n*$py.class\nexample.py\ntimit_data/\nL"
},
{
"path": ".markdownlint.json",
"chars": 119,
"preview": "{\n \"MD033\": {\n \"allowed_elements\": [\n \"p\",\n \"img\"\n ]\n },\n \"MD013\": false\n}"
},
{
"path": "LICENSE",
"chars": 11357,
"preview": " Apache License\n Version 2.0, January 2004\n "
},
{
"path": "README.md",
"chars": 14253,
"preview": "# PaSST: Efficient Training of Audio Transformers with Patchout\n\nThis is the implementation for [Efficient Training of A"
},
{
"path": "__init__.py",
"chars": 0,
"preview": ""
},
{
"path": "audioset/README.md",
"chars": 2323,
"preview": "# Experiments on Audioset\nAudioset has around 2M segments. The total size of the dataset with wav files with `32khz` sam"
},
{
"path": "audioset/dataset.py",
"chars": 13659,
"preview": "import io\nimport os\nimport pathlib\nimport random\n\nimport av\nimport librosa\nimport torchaudio\nfrom torch.utils.data impor"
},
{
"path": "audioset/prepare_scripts/convert_to_mp3.py",
"chars": 1955,
"preview": "import argparse\nimport multiprocessing\nimport glob\nimport os\n\n# Replace this with the dataset downloaded using PANN scri"
},
{
"path": "audioset/prepare_scripts/create_h5pymp3_dataset.py",
"chars": 5839,
"preview": "# %%\nimport h5py\nimport pandas as pd\nimport numpy as np\nimport csv\nimport os\n\n\n# %%\nbase_dir = \"../../audioset_hdf5s/\"\nb"
},
{
"path": "ba3l/__init__.py",
"chars": 331,
"preview": "\"\"\"Package info\"\"\"\n\n__version__ = \"0.0.2\"\n__author__ = \"Koutini et al.\"\n__license__ = \"Apache-2.0\"\n__copyright__ = \"Copy"
},
{
"path": "ba3l/experiment.py",
"chars": 8033,
"preview": "import inspect\nfrom importlib import import_module\n\nfrom ba3l.ingredients.datasets import Datasets\nfrom ba3l.ingredients"
},
{
"path": "ba3l/ingredients/__init__.py",
"chars": 0,
"preview": ""
},
{
"path": "ba3l/ingredients/datasets.py",
"chars": 11133,
"preview": "import inspect\nimport os\nfrom functools import partial\n\nfrom ba3l.util.functions import get_default_kwargs_dict\nfrom sac"
},
{
"path": "ba3l/ingredients/ingredient.py",
"chars": 4965,
"preview": "import inspect\nimport os\nfrom functools import partial\n\nfrom ba3l.util.functions import get_default_kwargs_dict\nfrom sac"
},
{
"path": "ba3l/ingredients/models.py",
"chars": 6967,
"preview": "import inspect\nimport os\n\nfrom .ingredient import Ingredient\n\nfrom typing import Sequence, Optional, List\n\nfrom sacred.u"
},
{
"path": "ba3l/ingredients/trainer.py",
"chars": 440,
"preview": "import inspect\nimport os\n\nfrom ba3l.util.functions import get_default_kwargs_dict\nfrom .datasets import Datasets, raise_"
},
{
"path": "ba3l/module.py",
"chars": 1158,
"preview": "import pytorch_lightning as pl\n\nimport warnings\nfrom abc import ABC\n\nimport torch.distributed as dist\nfrom munch import "
},
{
"path": "ba3l/plutils/__init__.py",
"chars": 0,
"preview": ""
},
{
"path": "ba3l/plutils/lr_monitor.py",
"chars": 8110,
"preview": "# Copyright The PyTorch Lightning team.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may no"
},
{
"path": "ba3l/plutils/progress_bar.py",
"chars": 6006,
"preview": "\nimport importlib\nimport io\nimport os\nimport sys\n\n# check if ipywidgets is installed before importing tqdm.auto\n# to ens"
},
{
"path": "ba3l/util/__init__.py",
"chars": 0,
"preview": ""
},
{
"path": "ba3l/util/functions.py",
"chars": 286,
"preview": "import inspect\nfrom collections import OrderedDict\n\n\ndef get_default_kwargs_dict(f):\n sig = inspect.signature(f)\n "
},
{
"path": "ba3l/util/sacred_logger.py",
"chars": 2507,
"preview": "from pytorch_lightning.utilities import rank_zero_only\ntry:\n from pytorch_lightning.loggers import Logger as Lightnin"
},
{
"path": "config_updates.py",
"chars": 11551,
"preview": "from sacred.config_helpers import DynamicIngredient, CMD\n\n\ndef add_configs(ex):\n '''\n This functions add generic c"
},
{
"path": "environment.yml",
"chars": 2413,
"preview": "name: ba3l\nchannels:\n - defaults\ndependencies:\n - _libgcc_mutex=0.1\n - _openmp_mutex=5.1\n - ca-certificates=2022.10."
},
{
"path": "environment_old_2021.yml",
"chars": 3487,
"preview": "name: ba3l\nchannels:\n - conda-forge\n - defaults\ndependencies:\n - _libgcc_mutex=0.1\n - _openmp_mutex=4.5\n - _pytorch"
},
{
"path": "esc50/README.md",
"chars": 1975,
"preview": "# Experiments on ESC-50 Environmental Sound Classification\n[ESC-50](https://github.com/karolpiczak/ESC-50) consist of 20"
},
{
"path": "esc50/dataset.py",
"chars": 10201,
"preview": "import io\nimport os\nimport pathlib\nimport random\n\nimport av\nimport librosa\nimport torchaudio\nfrom torch.utils.data impor"
},
{
"path": "ex_audioset.py",
"chars": 19493,
"preview": "import os\nimport sys\nimport PIL\nimport pytorch_lightning\nimport torch\nfrom pytorch_lightning.callbacks import ModelCheck"
},
{
"path": "ex_esc50.py",
"chars": 15106,
"preview": "import os\nimport sys\n\nimport torch\nfrom pytorch_lightning.callbacks import ModelCheckpoint\nfrom sacred.config_helpers im"
},
{
"path": "ex_fsd50k.py",
"chars": 18025,
"preview": "import os\nimport sys\n\nimport torch\nfrom pytorch_lightning.callbacks import ModelCheckpoint\nfrom sacred.config_helpers im"
},
{
"path": "ex_openmic.py",
"chars": 17683,
"preview": "import os\nimport sys\n\nimport torch\nfrom pytorch_lightning.callbacks import ModelCheckpoint\nfrom sacred.config_helpers im"
},
{
"path": "fsd50k/README.md",
"chars": 3060,
"preview": "# Experiments on FSD50K\n The FSD50K dataset ([Zenodo](https://zenodo.org/record/4060432)) consists of 51K audio clips a"
},
{
"path": "fsd50k/dataset.py",
"chars": 12238,
"preview": "import io\nimport os\nimport pathlib\nimport random\n\nimport av\nimport librosa\nimport torchaudio\nfrom torch.utils.data impor"
},
{
"path": "fsd50k/prepare_scripts/convert_to_mp3.py",
"chars": 1305,
"preview": "import multiprocessing\nimport glob\nimport os\nimport sys\n\nif len(sys.argv) > 1:\n FSD50K_base = sys.argv[1] # the path "
},
{
"path": "fsd50k/prepare_scripts/create_h5pymp3_dataset.py",
"chars": 2753,
"preview": "# %%\nimport sys\n\nimport h5py\nimport pandas as pd\nimport numpy as np\nimport csv\nimport os\n\n# %%\nfrom numpy import dtype\n\n"
},
{
"path": "helpers/audiodatasets.py",
"chars": 7509,
"preview": "import hashlib\nimport os\nimport time\n\nimport torch\nfrom torch.utils.data import Dataset\n\nfrom os.path import expanduser\n"
},
{
"path": "helpers/mixup.py",
"chars": 400,
"preview": "\nimport numpy as np\nimport torch\n\ndef my_mixup(size, alpha):\n rn_indices = torch.randperm(size)\n lambd = np.random"
},
{
"path": "helpers/models_size.py",
"chars": 1069,
"preview": "\n\n\n\n\n\ndef count_non_zero_params(model):\n sum_params = 0\n sum_non_zero = 0\n desc = \"\"\n\n def calc_params(model"
},
{
"path": "helpers/ramp.py",
"chars": 3635,
"preview": "import numpy as np\nfrom ba3l.ingredients.ingredient import Ingredient\n\n\n# credit: https://github.com/iBelieveCJM/Tricks-"
},
{
"path": "helpers/swa_callback.py",
"chars": 11325,
"preview": "# Adapted from PyTorch Lightning so that it only does the averaging\n# Copyright The PyTorch Lightning team.\n#\n# Licensed"
},
{
"path": "helpers/swa_legacy.py",
"chars": 11761,
"preview": "# Adapted from PyTorch Lightning so that it only does the averaging\n# Copyright The PyTorch Lightning team.\n#\n# Licensed"
},
{
"path": "helpers/workersinit.py",
"chars": 251,
"preview": "import torch\nimport numpy as np\nimport random\n\n\ndef worker_init_fn(x):\n seed = (torch.initial_seed() + x * 1000) % 2 "
},
{
"path": "models/helpers/vit_helpers.py",
"chars": 15583,
"preview": "\"\"\"\nAdapted from https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py\nCredit "
},
{
"path": "models/passt.py",
"chars": 52057,
"preview": "\"\"\"\nMost of this code comes from the timm library.\nWe tried to disentangle from the timm library version.\n\nAdapted from"
},
{
"path": "models/preprocess.py",
"chars": 3638,
"preview": "import torch.nn as nn\nimport torchaudio\nfrom torch.nn.functional import conv1d, conv2d\n\nimport torch\n\n\nfrom ba3l.ingredi"
},
{
"path": "openmic/README.md",
"chars": 884,
"preview": "# Experiments on OpenMIC-2018\n\n[OpenMIC-2018](https://github.com/cosmir/openmic-2018) ([zenodo](https://zenodo.org/recor"
},
{
"path": "openmic/dataset.py",
"chars": 10659,
"preview": "import io\nimport os\nimport pathlib\nimport random\n\nimport av\nimport librosa\nimport torchaudio\nfrom torch.utils.data impor"
},
{
"path": "openmic/prepare_scripts/download_preprocess.py",
"chars": 6389,
"preview": "import os\nimport tarfile\nimport multiprocessing\nimport glob\nimport h5py\nimport numpy as np\n\nfrom torch.hub import downlo"
},
{
"path": "pip_list.txt",
"chars": 3141,
"preview": "Package Version\n----------------------- ------------\nabsl-py 1.4.0\naiohttp "
},
{
"path": "requirements.txt",
"chars": 151,
"preview": "av>=10.0.0\nh5py>=3.8.0\njsonpickle>=3.0.1\nkk-sacred>=0.8.4\nlibrosa>=0.10.0.post2\ntimm>=0.4.12\ntorchmetrics>=0.2.0\npytorch"
}
]