Repository: kkoutini/PaSST
Branch: main
Commit: 2a5c818afcc2
Files: 52
Total size: 341.1 KB

Directory structure:
gitextract_blhd6h9_/

├── .github/
│   └── workflows/
│       └── pylint.yml
├── .gitignore
├── .markdownlint.json
├── LICENSE
├── README.md
├── __init__.py
├── audioset/
│   ├── README.md
│   ├── dataset.py
│   └── prepare_scripts/
│       ├── convert_to_mp3.py
│       └── create_h5pymp3_dataset.py
├── ba3l/
│   ├── __init__.py
│   ├── experiment.py
│   ├── ingredients/
│   │   ├── __init__.py
│   │   ├── datasets.py
│   │   ├── ingredient.py
│   │   ├── models.py
│   │   └── trainer.py
│   ├── module.py
│   ├── plutils/
│   │   ├── __init__.py
│   │   ├── lr_monitor.py
│   │   └── progress_bar.py
│   └── util/
│       ├── __init__.py
│       ├── functions.py
│       └── sacred_logger.py
├── config_updates.py
├── environment.yml
├── environment_old_2021.yml
├── esc50/
│   ├── README.md
│   └── dataset.py
├── ex_audioset.py
├── ex_esc50.py
├── ex_fsd50k.py
├── ex_openmic.py
├── fsd50k/
│   ├── README.md
│   ├── dataset.py
│   └── prepare_scripts/
│       ├── convert_to_mp3.py
│       └── create_h5pymp3_dataset.py
├── helpers/
│   ├── audiodatasets.py
│   ├── mixup.py
│   ├── models_size.py
│   ├── ramp.py
│   ├── swa_callback.py
│   ├── swa_legacy.py
│   └── workersinit.py
├── models/
│   ├── helpers/
│   │   └── vit_helpers.py
│   ├── passt.py
│   └── preprocess.py
├── openmic/
│   ├── README.md
│   ├── dataset.py
│   └── prepare_scripts/
│       └── download_preprocess.py
├── pip_list.txt
└── requirements.txt

================================================
FILE CONTENTS
================================================

================================================
FILE: .github/workflows/pylint.yml
================================================
name: Pylint

on: [push]

jobs:
  build:
    runs-on: ubuntu-latest
    strategy:
      matrix:
        python-version: ["3.8", "3.9", "3.10"]
    steps:
    - uses: actions/checkout@v3
    - name: Set up Python ${{ matrix.python-version }}
      uses: actions/setup-python@v3
      with:
        python-version: ${{ matrix.python-version }}
    - name: Install dependencies
      run: |
        python -m pip install --upgrade pip
        pip install pylint
    - name: Analysing the code with pylint
      run: |
        pylint $(git ls-files '*.py')


================================================
FILE: .gitignore
================================================


# project
/output/

# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class
example.py
timit_data/
LJSpeech-1.1/

# C extensions
*.so

.idea/

# Distribution / packaging
.Python
env/
ide_layouts/
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
*.egg-info/
.installed.cfg
*.egg

# PyInstaller
#  Usually these files are written by a python script from a template
#  before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
.hypothesis/

# Translations
*.mo
*.pot

# Django stuff:
*.log
local_settings.py

# Flask stuff:
instance/
.webassets-cache

# Scrapy stuff:
.scrapy

# Sphinx documentation
docs/_build/

# PyBuilder
target/

# Jupyter Notebook
.ipynb_checkpoints

# pyenv
.python-version

# celery beat schedule file
celerybeat-schedule

# SageMath parsed files
*.sage.py

# dotenv
.env

# virtualenv
.venv
venv/
ENV/

# Spyder project settings
.spyderproject
.spyproject

# Rope project settings
.ropeproject

# mkdocs documentation
/site

# mypy
.mypy_cache/
/openmic/prepare_scripts/openmic-2018-v1.0.0/
/openmic/prepare_scripts/openmic-2018-v1.0.0.tgz
/audioset_hdf5s/mp3/openmic_test.csv_mp3.hdf
/audioset_hdf5s/mp3/openmic_train.csv_mp3.hdf
environment_builds.yml

# Output
lightning_logs/*
audioset_hdf5s/*
.vscode/settings.json
wandb/*
.vscode/launch.json


================================================
FILE: .markdownlint.json
================================================
{
    "MD033": {
        "allowed_elements": [
            "p",
            "img"
        ]
    },
    "MD013": false
}

================================================
FILE: LICENSE
================================================
                                 Apache License
                           Version 2.0, January 2004
                        http://www.apache.org/licenses/

   TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION

   1. Definitions.

      "License" shall mean the terms and conditions for use, reproduction,
      and distribution as defined by Sections 1 through 9 of this document.

      "Licensor" shall mean the copyright owner or entity authorized by
      the copyright owner that is granting the License.

      "Legal Entity" shall mean the union of the acting entity and all
      other entities that control, are controlled by, or are under common
      control with that entity. For the purposes of this definition,
      "control" means (i) the power, direct or indirect, to cause the
      direction or management of such entity, whether by contract or
      otherwise, or (ii) ownership of fifty percent (50%) or more of the
      outstanding shares, or (iii) beneficial ownership of such entity.

      "You" (or "Your") shall mean an individual or Legal Entity
      exercising permissions granted by this License.

      "Source" form shall mean the preferred form for making modifications,
      including but not limited to software source code, documentation
      source, and configuration files.

      "Object" form shall mean any form resulting from mechanical
      transformation or translation of a Source form, including but
      not limited to compiled object code, generated documentation,
      and conversions to other media types.

      "Work" shall mean the work of authorship, whether in Source or
      Object form, made available under the License, as indicated by a
      copyright notice that is included in or attached to the work
      (an example is provided in the Appendix below).

      "Derivative Works" shall mean any work, whether in Source or Object
      form, that is based on (or derived from) the Work and for which the
      editorial revisions, annotations, elaborations, or other modifications
      represent, as a whole, an original work of authorship. For the purposes
      of this License, Derivative Works shall not include works that remain
      separable from, or merely link (or bind by name) to the interfaces of,
      the Work and Derivative Works thereof.

      "Contribution" shall mean any work of authorship, including
      the original version of the Work and any modifications or additions
      to that Work or Derivative Works thereof, that is intentionally
      submitted to Licensor for inclusion in the Work by the copyright owner
      or by an individual or Legal Entity authorized to submit on behalf of
      the copyright owner. For the purposes of this definition, "submitted"
      means any form of electronic, verbal, or written communication sent
      to the Licensor or its representatives, including but not limited to
      communication on electronic mailing lists, source code control systems,
      and issue tracking systems that are managed by, or on behalf of, the
      Licensor for the purpose of discussing and improving the Work, but
      excluding communication that is conspicuously marked or otherwise
      designated in writing by the copyright owner as "Not a Contribution."

      "Contributor" shall mean Licensor and any individual or Legal Entity
      on behalf of whom a Contribution has been received by Licensor and
      subsequently incorporated within the Work.

   2. Grant of Copyright License. Subject to the terms and conditions of
      this License, each Contributor hereby grants to You a perpetual,
      worldwide, non-exclusive, no-charge, royalty-free, irrevocable
      copyright license to reproduce, prepare Derivative Works of,
      publicly display, publicly perform, sublicense, and distribute the
      Work and such Derivative Works in Source or Object form.

   3. Grant of Patent License. Subject to the terms and conditions of
      this License, each Contributor hereby grants to You a perpetual,
      worldwide, non-exclusive, no-charge, royalty-free, irrevocable
      (except as stated in this section) patent license to make, have made,
      use, offer to sell, sell, import, and otherwise transfer the Work,
      where such license applies only to those patent claims licensable
      by such Contributor that are necessarily infringed by their
      Contribution(s) alone or by combination of their Contribution(s)
      with the Work to which such Contribution(s) was submitted. If You
      institute patent litigation against any entity (including a
      cross-claim or counterclaim in a lawsuit) alleging that the Work
      or a Contribution incorporated within the Work constitutes direct
      or contributory patent infringement, then any patent licenses
      granted to You under this License for that Work shall terminate
      as of the date such litigation is filed.

   4. Redistribution. You may reproduce and distribute copies of the
      Work or Derivative Works thereof in any medium, with or without
      modifications, and in Source or Object form, provided that You
      meet the following conditions:

      (a) You must give any other recipients of the Work or
          Derivative Works a copy of this License; and

      (b) You must cause any modified files to carry prominent notices
          stating that You changed the files; and

      (c) You must retain, in the Source form of any Derivative Works
          that You distribute, all copyright, patent, trademark, and
          attribution notices from the Source form of the Work,
          excluding those notices that do not pertain to any part of
          the Derivative Works; and

      (d) If the Work includes a "NOTICE" text file as part of its
          distribution, then any Derivative Works that You distribute must
          include a readable copy of the attribution notices contained
          within such NOTICE file, excluding those notices that do not
          pertain to any part of the Derivative Works, in at least one
          of the following places: within a NOTICE text file distributed
          as part of the Derivative Works; within the Source form or
          documentation, if provided along with the Derivative Works; or,
          within a display generated by the Derivative Works, if and
          wherever such third-party notices normally appear. The contents
          of the NOTICE file are for informational purposes only and
          do not modify the License. You may add Your own attribution
          notices within Derivative Works that You distribute, alongside
          or as an addendum to the NOTICE text from the Work, provided
          that such additional attribution notices cannot be construed
          as modifying the License.

      You may add Your own copyright statement to Your modifications and
      may provide additional or different license terms and conditions
      for use, reproduction, or distribution of Your modifications, or
      for any such Derivative Works as a whole, provided Your use,
      reproduction, and distribution of the Work otherwise complies with
      the conditions stated in this License.

   5. Submission of Contributions. Unless You explicitly state otherwise,
      any Contribution intentionally submitted for inclusion in the Work
      by You to the Licensor shall be under the terms and conditions of
      this License, without any additional terms or conditions.
      Notwithstanding the above, nothing herein shall supersede or modify
      the terms of any separate license agreement you may have executed
      with Licensor regarding such Contributions.

   6. Trademarks. This License does not grant permission to use the trade
      names, trademarks, service marks, or product names of the Licensor,
      except as required for reasonable and customary use in describing the
      origin of the Work and reproducing the content of the NOTICE file.

   7. Disclaimer of Warranty. Unless required by applicable law or
      agreed to in writing, Licensor provides the Work (and each
      Contributor provides its Contributions) on an "AS IS" BASIS,
      WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
      implied, including, without limitation, any warranties or conditions
      of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
      PARTICULAR PURPOSE. You are solely responsible for determining the
      appropriateness of using or redistributing the Work and assume any
      risks associated with Your exercise of permissions under this License.

   8. Limitation of Liability. In no event and under no legal theory,
      whether in tort (including negligence), contract, or otherwise,
      unless required by applicable law (such as deliberate and grossly
      negligent acts) or agreed to in writing, shall any Contributor be
      liable to You for damages, including any direct, indirect, special,
      incidental, or consequential damages of any character arising as a
      result of this License or out of the use or inability to use the
      Work (including but not limited to damages for loss of goodwill,
      work stoppage, computer failure or malfunction, or any and all
      other commercial damages or losses), even if such Contributor
      has been advised of the possibility of such damages.

   9. Accepting Warranty or Additional Liability. While redistributing
      the Work or Derivative Works thereof, You may choose to offer,
      and charge a fee for, acceptance of support, warranty, indemnity,
      or other liability obligations and/or rights consistent with this
      License. However, in accepting such obligations, You may act only
      on Your own behalf and on Your sole responsibility, not on behalf
      of any other Contributor, and only if You agree to indemnify,
      defend, and hold each Contributor harmless for any liability
      incurred by, or claims asserted against, such Contributor by reason
      of your accepting any such warranty or additional liability.

   END OF TERMS AND CONDITIONS

   APPENDIX: How to apply the Apache License to your work.

      To apply the Apache License to your work, attach the following
      boilerplate notice, with the fields enclosed by brackets "[]"
      replaced with your own identifying information. (Don't include
      the brackets!)  The text should be enclosed in the appropriate
      comment syntax for the file format. We also recommend that a
      file or class name and description of purpose be included on the
      same "printed page" as the copyright notice for easier
      identification within third-party archives.

   Copyright [yyyy] [name of copyright owner]

   Licensed under the Apache License, Version 2.0 (the "License");
   you may not use this file except in compliance with the License.
   You may obtain a copy of the License at

       http://www.apache.org/licenses/LICENSE-2.0

   Unless required by applicable law or agreed to in writing, software
   distributed under the License is distributed on an "AS IS" BASIS,
   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
   See the License for the specific language governing permissions and
   limitations under the License.


================================================
FILE: README.md
================================================
# PaSST: Efficient Training of Audio Transformers with Patchout

This is the implementation of [Efficient Training of Audio Transformers with Patchout](https://arxiv.org/abs/2110.05069).

Patchout significantly reduces the training time and GPU memory requirements to train transformers on audio spectrograms, while improving their performance.

<p align="center"><img src="https://github.com/kkoutini/PaSST/blob/main/.github/speed_mem_map.png?raw=true" width="600"/></p>

Patchout works by dropping out some of the input patches during training, either in an unstructured way (randomly, similar to dropout) or by dropping entire time frames or frequency bins of the extracted patches (similar to SpecAugment), which correspond to rows/columns in step 3 of the figure below.

<p align="center"><img src="https://github.com/kkoutini/PaSST/raw/main/.github/passt_diag.png?raw=true" width="600"/></p>
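To make the idea concrete, here is a minimal PyTorch sketch of the three patchout variants applied to a grid of patch embeddings. It is not the repo's implementation (that lives in `models/passt.py`); the function name `patchout_sketch` and the `[batch, freq_patches, time_patches, dim]` layout are assumptions, and the parameter names only mirror the `u_patchout`, `s_patchout_f` and `s_patchout_t` config options.

```python
import torch


def patchout_sketch(x: torch.Tensor, s_patchout_f: int = 0, s_patchout_t: int = 0,
                    u_patchout: int = 0) -> torch.Tensor:
    """Drop patches from a grid of patch embeddings (training-time only).

    x: [batch, freq_patches, time_patches, dim]  (layout assumed for illustration)
    returns: [batch, remaining_patches, dim]
    """
    _, n_f, n_t, _ = x.shape
    # Structured patchout over time: remove s_patchout_t whole time columns.
    if s_patchout_t:
        keep_t = torch.randperm(n_t)[: n_t - s_patchout_t].sort().values
        x = x[:, :, keep_t, :]
    # Structured patchout over frequency: remove s_patchout_f whole frequency rows.
    if s_patchout_f:
        keep_f = torch.randperm(n_f)[: n_f - s_patchout_f].sort().values
        x = x[:, keep_f, :, :]
    # Flatten the remaining grid into a token sequence for the transformer.
    x = x.flatten(1, 2)
    # Unstructured patchout: remove u_patchout randomly chosen tokens.
    if u_patchout:
        n = x.shape[1]
        keep = torch.randperm(n)[: n - u_patchout].sort().values
        x = x[:, keep, :]
    # Note: in this sketch the same positions are dropped for every clip in the batch.
    return x


grid = torch.randn(2, 12, 99, 768)  # example: 2 clips, 12 x 99 patches, 768-dim embeddings
print(patchout_sketch(grid, s_patchout_f=4, s_patchout_t=40).shape)  # torch.Size([2, 472, 768])
```

Dropping 4 of 12 frequency rows and 40 of 99 time columns leaves 8 x 59 = 472 of the 1188 patches. Since self-attention cost grows quadratically with sequence length, shortening the sequence this way is what saves compute and memory during training.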

## Table of contents

- [Pre-trained models for Inference and embeddings extractions](#pre-trained-models-for-inference-and-embeddings-extractions)
  - [Getting the logits from the pretrained models](#getting-the-logits-from-the-pretrained-models)
  - [Getting a pre-trained model for fine-tuning](#getting-a-pre-trained-model-for-fine-tuning)
- [Development environment](#development-environment)
  - [Setting up the development experiments environment](#setting-up-the-development-experiments-environment)
  - [Setting up using the exported conda environment](#setting-up-using-the-exported-conda-environment)
  - [Checking the environment](#checking-the-environment)
- [Getting started](#getting-started)
  - [General information](#general-information)
  - [Configuring the experiment](#configuring-the-experiment)
- [Training on Audioset](#training-on-audioset)
- [Examples with Pre-trained models](#examples-with-pre-trained-models)
- [Examples of fine-tuning on downstream datasets](#examples-of-fine-tuning-on-downstream-datasets)
- [Citation](#citation)
- [Contact](#contact)

## Pre-trained models for Inference and embeddings extractions

If you only want to use the embeddings generated by the pre-trained models, want to use your own fine-tuning framework, or only need the models for inference, you can find a stripped-down version of this repo [here](https://github.com/kkoutini/passt_hear21).
The package follows the [HEAR 2021 NeurIPS Challenge](https://neuralaudio.ai/hear2021-results.html) API and can be installed with:

```shell
pip install hear21passt
```

This repo is a complete framework for training the models and for fine-tuning models pre-trained on Audioset on downstream tasks.

### Getting the logits from the pretrained models

```python
import torch
from hear21passt.base import get_basic_model

# get the PaSST model wrapper, which includes the mel-spectrogram front end and the default pre-trained transformer
model = get_basic_model(mode="logits")
print(model.mel)  # extracts mel spectrograms from raw waveforms
print(model.net)  # the transformer network

# example inference
device = "cuda" if torch.cuda.is_available() else "cpu"
model = model.to(device)
model.eval()
with torch.no_grad():
    # audio_wave has shape [batch, seconds * 32000]; the sampling rate is 32 kHz
    # example: a batch of 3 clips, 10 seconds each
    audio_wave = (torch.ones((3, 32000 * 10)) * 0.5).to(device)
    logits = model(audio_wave)
```
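AudioSet tagging is multi-label, so each logit is typically turned into an independent probability with a sigmoid rather than a softmax over classes. Below is a minimal sketch for reading off the top-scoring classes, assuming `logits` has shape `[batch, 527]` (the `n_classes` of the AudioSet models); `top_k_tags` is a hypothetical helper, and mapping indices to label names requires AudioSet's class-label CSV, which is not covered here.

```python
import torch


def top_k_tags(logits: torch.Tensor, k: int = 5):
    """Return (scores, class_indices) of the k highest-scoring classes per clip."""
    # Independent per-class probabilities for multi-label tagging.
    probs = torch.sigmoid(logits)
    return probs.topk(k, dim=-1)


# Stand-in for real model output: 3 clips x 527 AudioSet classes.
fake_logits = torch.randn(3, 527)
scores, indices = top_k_tags(fake_logits, k=5)
print(indices.shape)  # torch.Size([3, 5])
```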

### Getting a pre-trained model for fine tuning

```python
import torch
from hear21passt.base import get_basic_model, get_model_passt

# get the PaSST model wrapper, which includes the mel-spectrogram front end and the default pre-trained transformer
model = get_basic_model(mode="logits")
print(model.mel)  # extracts mel spectrograms from raw waveforms

# optionally, replace the transformer with one that has the required number of classes, e.g. 50
model.net = get_model_passt(arch="passt_s_swa_p16_128_ap476", n_classes=50)
print(model.net)  # the transformer network

# `model` now contains the mel front end + the pre-trained transformer, ready to be fine-tuned.
# It still expects input of shape [batch, seconds * 32000]; the sampling rate is 32 kHz.
device = "cuda" if torch.cuda.is_available() else "cpu"
model = model.to(device)
model.train()
```

## Development environment

If you want to use the same environment as in the paper, you can follow the instructions below.

### Setting up the development experiments environment

For training models from scratch or fine-tuning using the same setup as in the paper:

1. If needed, create a new environment with Python 3.8 and activate it:

```bash
conda create -n passt python=3.8
conda activate passt
```

2. Install a PyTorch build that suits your system. For example:

```bash
conda install pytorch==1.11.0 torchvision==0.12.0 torchaudio==0.11.0 cudatoolkit=11.3 -c pytorch
```

3. Install the requirements:

```bash
pip install -r requirements.txt
```

### Setting up using the exported conda environment

Alternatively, you can use the exported conda environment `environment.yml` to create the environment.

For the setup, [Mamba](https://github.com/mamba-org/mamba) is recommended since it is faster than `conda`:

```shell
conda install mamba -n base -c conda-forge
```

Now you can import the environment from `environment.yml`

```shell
mamba env create -f environment.yml
```

Now you have an environment named `ba3l`.

### Checking the environment

To check whether your environment matches the one we used in our runs, compare it against the `environment.yml` and `pip_list.txt` files, which were exported using:

```shell
conda env export --no-builds | grep -v "prefix" > environment.yml
pip list > pip_list.txt
```
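For a quick programmatic comparison, the sketch below (not part of the repo) parses the tabular `pip list` format (a header row, a dashed separator, then one `name version` pair per line) and reports packages whose installed version differs or that are missing. It uses only the standard library's `importlib.metadata` (Python 3.8+); the helper names are hypothetical, and differences in package-name normalisation may occasionally report a package as missing.

```python
from importlib import metadata
from pathlib import Path


def read_pip_list(path):
    """Parse `pip list` output into {package_name: version}."""
    packages = {}
    for line in Path(path).read_text().splitlines():
        parts = line.split()
        # Skip the header row, the dashed separator, and blank lines.
        if len(parts) < 2 or parts[0] == "Package" or set(parts[0]) == {"-"}:
            continue
        packages[parts[0]] = parts[1]
    return packages


def diff_environment(path="pip_list.txt"):
    """Return {package: (expected_version, installed_version_or_None)} for every mismatch."""
    mismatches = {}
    for name, expected in read_pip_list(path).items():
        try:
            installed = metadata.version(name)
        except metadata.PackageNotFoundError:
            installed = None
        if installed != expected:
            mismatches[name] = (expected, installed)
    return mismatches
```

Running `print(diff_environment("pip_list.txt"))` from the repo root lists every package that does not match the exported environment.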

## Getting started

If you want to use your own setup and only need the models from this repo, you can get the models and train them from scratch or fine-tune them on your own dataset, as explained above in [Pre-trained models for Inference and embeddings extractions](#pre-trained-models-for-inference-and-embeddings-extractions). The rest of this section explains how to use this repo for training and fine-tuning the models; for that, you first need to set up the development environment as explained above.

### General information

The repo is built using [sacred](https://sacred.readthedocs.io/en/) for experiment management and configuration, pytorch-lightning for training, and wandb for logging.

Each dataset has a main experiment file, such as `ex_audioset.py` or `ex_openmic.py`, and a dataset folder. The experiment file contains the main training and validation logic. The dataset folder contains the dataset-specific code needed to download, preprocess and load the dataset for training.

In general, you can ask the experiment file for help; this prints the available commands and basic options:

```shell
python ex_audioset.py help
```

### Configuring the experiment

Each experiment has a set of default configuration options, defined in the experiment file, e.g. `ex_audioset.py`. You can override any of these options using the [sacred syntax](https://sacred.readthedocs.io/en/stable/command_line.html). You can use the `print_config` command to print the configuration values without training a model:

```shell
python ex_audioset.py print_config
```

You can then use the command line interface to override any configuration option ([sacred syntax](https://sacred.readthedocs.io/en/stable/command_line.html)) by passing `with`, e.g.:

```shell
python ex_audioset.py with trainer.precision=16 
```

This will train on Audioset using 16-bit precision.
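Each `key=value` pair after `with` is a dotted path into the nested configuration. The helper below is a hypothetical sketch of that mapping, not sacred's own parser: values are read as Python literals when possible and kept as plain strings otherwise, which is similar in spirit to how sacred treats command-line values.

```python
import ast


def parse_overrides(args):
    """Turn ["trainer.precision=16", ...] into a nested dict of config updates."""
    updates = {}
    for arg in args:
        path, _, raw = arg.partition("=")
        try:
            value = ast.literal_eval(raw)  # "16" -> 16, "False" -> False
        except (ValueError, SyntaxError):
            value = raw  # not a literal, e.g. passt_s_swa_p16_128_ap476 stays a string
        node = updates
        *parents, leaf = path.split(".")
        for key in parents:
            node = node.setdefault(key, {})
        node[leaf] = value
    return updates


print(parse_overrides(["trainer.precision=16", "models.net.u_patchout=400",
                       "models.net.s_patchout_f=0"]))
# {'trainer': {'precision': 16}, 'models': {'net': {'u_patchout': 400, 's_patchout_f': 0}}}
```

A nested dict of this shape is the kind of value that sacred's `Experiment.run(config_updates=...)` takes when a run is started from Python instead of the shell.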

The overall configuration looks like this:

```yaml
  ...
  seed = 542198583                  # the random seed for this experiment
  slurm_job_id = ''
  speed_test_batch_size = 100
  swa = True
  swa_epoch_start = 50
  swa_freq = 5
  use_mixup = True
  warm_up_len = 5
  weight_decay = 0.0001
  basedataset:
    base_dir = 'audioset_hdf5s/'     # base directory of the dataset, change it or make a link
    eval_hdf5 = 'audioset_hdf5s/mp3/eval_segments_mp3.hdf'
    wavmix = 1
    ....
    roll_conf:
      axis = 1
      shift = None
      shift_range = 50
  datasets:
    test:
      batch_size = 20
      dataset = {CMD!}'/basedataset.get_test_set'
      num_workers = 16
      validate = True
    training:
      batch_size = 12
      dataset = {CMD!}'/basedataset.get_full_training_set'
      num_workers = 16
      sampler = {CMD!}'/basedataset.get_ft_weighted_sampler'
      shuffle = None
      train = True
  models:
    mel:
      freqm = 48
      timem = 192
      hopsize = 320
      htk = False
      n_fft = 1024
      n_mels = 128
      norm = 1
      sr = 32000
      ...
    net:
      arch = 'passt_s_swa_p16_128_ap476'
      fstride = 10
      in_channels = 1
      input_fdim = 128
      input_tdim = 998
      n_classes = 527
      s_patchout_f = 4
      s_patchout_t = 40
      tstride = 10
      u_patchout = 0
      ...
  trainer:
    accelerator = None
    accumulate_grad_batches = 1
    amp_backend = 'native'
    amp_level = 'O2'
    auto_lr_find = False
    auto_scale_batch_size = False
    ...
```

There are many things that can be updated from the command line.
In short:

- All the configuration options under `trainer` are PyTorch Lightning Trainer [API](https://pytorch-lightning.readthedocs.io/en/1.4.1/common/trainer.html#trainer-class-api) arguments. For example, to turn off cuDNN benchmarking, add `trainer.benchmark=False` to the command line.
- `wandb` is the wandb configuration. For example, to change the wandb project, add `wandb.project="test_project"` to the command line.
- `models.net` holds the PaSST (or the chosen network's) options. For example, `models.net.u_patchout`, `models.net.s_patchout_f` and `models.net.s_patchout_t` control unstructured patchout and structured patchout over frequency and time, respectively. `input_fdim` and `input_tdim` are the input spectrogram dimensions over frequency and time. `models.net.fstride` and `models.net.tstride` are the strides of the input patches over frequency and time; setting these to 16 means no patch overlap.
- `models.mel` holds the preprocessing (mel spectrogram) options: `mel.sr` is the sampling rate, `mel.hopsize` is the hop size of the STFT window, `mel.n_mels` is the number of mel bins, and `mel.freqm` and `mel.timem` are the frequency and time masking parameters of SpecAugment.
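To see how these options interact, the sketch below counts the patches the transformer sees. It is an illustration, not code from the repo, and assumes 16x16 patches (the `p16` in the model names) extracted without padding; under that assumption the default config (a 128 x 998 input with strides of 10) gives a 12 x 99 patch grid. The helper `mel_frames` simply divides `sr * seconds` by `hopsize`; the exact frame count also depends on STFT padding.

```python
def patch_grid(fdim=128, tdim=998, fstride=10, tstride=10, patch=16):
    """Number of patch rows (frequency) and columns (time) for a spectrogram."""
    n_f = (fdim - patch) // fstride + 1
    n_t = (tdim - patch) // tstride + 1
    return n_f, n_t


def sequence_length(fdim=128, tdim=998, fstride=10, tstride=10,
                    s_patchout_f=0, s_patchout_t=0, u_patchout=0, patch=16):
    """Patches left after structured (rows/columns) and unstructured patchout."""
    n_f, n_t = patch_grid(fdim, tdim, fstride, tstride, patch)
    return (n_f - s_patchout_f) * (n_t - s_patchout_t) - u_patchout


def mel_frames(seconds=10, sr=32000, hopsize=320):
    """Approximate number of STFT frames (one per hop) for a clip."""
    return seconds * sr // hopsize


print(patch_grid())                                      # (12, 99)
print(sequence_length(s_patchout_f=4, s_patchout_t=40))  # 472
print(patch_grid(fstride=16, tstride=16))                # (8, 62)
print(mel_frames())                                      # 1000
```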

There are many predefined configuration bundles (called named configs) in `config_updates.py`. These include different models, setups, etc.
You can list these configurations with:

```shell
python ex_audioset.py print_named_configs
```

For example, `passt_s_20sec` is a configuration bundle that sets the model to PaSST-S pre-trained on Audioset and accepts clips of up to 20 seconds.

## Training on Audioset

Download and prepare the dataset as explained in the [audioset page](audioset/).

The base PaSST model can be trained, for example, like this:

```bash
python ex_audioset.py with trainer.precision=16  models.net.arch=passt_deit_bd_p16_384 -p
```

For example, using only unstructured patchout (`u_patchout=400`) and no structured patchout:

```bash
python ex_audioset.py with trainer.precision=16  models.net.arch=passt_deit_bd_p16_384  models.net.u_patchout=400  models.net.s_patchout_f=0 models.net.s_patchout_t=0 -p
```

Multi-GPU training can be enabled by setting the environment variable `DDP`, for example, with 2 GPUs:

```shell
 DDP=2 python ex_audioset.py with trainer.precision=16  models.net.arch=passt_deit_bd_p16_384 -p -c "PaSST base 2 GPU"
```

## Examples with Pre-trained models

Please check the [releases page](https://github.com/kkoutini/PaSST/releases/) to download pre-trained models.
In general, you can get a model pre-trained on Audioset using:

```python
from models.passt import get_model
model  = get_model(arch="passt_s_swa_p16_128_ap476", pretrained=True, n_classes=527, in_channels=1,
                   fstride=10, tstride=10,input_fdim=128, input_tdim=998,
                   u_patchout=0, s_patchout_t=40, s_patchout_f=4)
```

This will automatically download PaSST pre-trained on Audioset, with an mAP of `0.476`. The model was trained with `s_patchout_t=40, s_patchout_f=4`, but you can change these to better fit your task or computational needs.

There are several pre-trained models available with different strides (overlap) and with or without SWA: `passt_s_p16_s16_128_ap468, passt_s_swa_p16_s16_128_ap473, passt_s_swa_p16_s14_128_ap471, passt_s_p16_s14_128_ap469, passt_s_swa_p16_s12_128_ap473, passt_s_p16_s12_128_ap470`.
For example, in `passt_s_swa_p16_s16_128_ap473`: `swa` means the model was trained with SWA, `p16` means the patch size is `16x16`, `s16` means no overlap (stride=16), `128` is the number of mel bands, and `ap473` refers to the performance of this model on Audioset (mAP=0.473).
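The naming convention can be decoded mechanically. The parser below is a hypothetical helper, not part of the repo; when a name has no `s<N>` stride token (as in `passt_s_swa_p16_128_ap476`), it returns `None` for the stride, since the stride is then not encoded in the name (the default config pairs that model with `fstride=10, tstride=10`).

```python
import re

_NAME_RE = re.compile(
    r"^passt_s(?P<swa>_swa)?_p(?P<patch>\d+)(?:_s(?P<stride>\d+))?_(?P<mels>\d+)_ap(?P<ap>\d+)$"
)


def parse_passt_name(arch):
    """Decode a pre-trained PaSST name such as 'passt_s_swa_p16_s16_128_ap473'."""
    m = _NAME_RE.match(arch)
    if m is None:
        raise ValueError(f"unrecognised model name: {arch}")
    return {
        "swa": m.group("swa") is not None,
        "patch_size": int(m.group("patch")),
        "stride": int(m.group("stride")) if m.group("stride") else None,
        "n_mels": int(m.group("mels")),
        "audioset_map": int(m.group("ap")) / 1000,  # 'ap473' -> 0.473
    }


print(parse_passt_name("passt_s_swa_p16_s16_128_ap473"))
# {'swa': True, 'patch_size': 16, 'stride': 16, 'n_mels': 128, 'audioset_map': 0.473}
```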

Any of these models can be loaded by name, with the strides set to match the name:

```python
from models.passt import get_model
passt = get_model(arch="passt_s_swa_p16_s16_128_ap473", fstride=16, tstride=16)
```

With this framework, you can evaluate the model using:

```shell
python ex_audioset.py evaluate_only with  trainer.precision=16  passt_s_swa_p16_s16_128_ap473 -p
```

Ensembles of these models are provided as well.
A large ensemble gives `mAP=.4956`:

```shell
python ex_audioset.py evaluate_only with  trainer.precision=16 ensemble_many
```

An ensemble of 2 models with `stride=14` and `stride=16` gives `mAP=.4858`:

```shell
python ex_audioset.py evaluate_only with  trainer.precision=16 ensemble_s16_14
```

Other ensembles, `ensemble_4` and `ensemble_5`, are available as well.

## Examples of fine-tuning on downstream datasets

1. [ESC-50: Dataset for Environmental Sound Classification](esc50/)
2. [OpenMIC-2018 dataset](openmic/)
3. [FSD50K](fsd50k/)

## Citation

The citation for the paper accepted at Interspeech 2022:

```bib
@inproceedings{koutini22passt,
  author       = {Khaled Koutini and
                  Jan Schl{\"{u}}ter and
                  Hamid Eghbal{-}zadeh and
                  Gerhard Widmer},
  title        = {Efficient Training of Audio Transformers with Patchout},
  booktitle    = {Interspeech 2022, 23rd Annual Conference of the International Speech
                  Communication Association, Incheon, Korea, 18-22 September 2022},
  pages        = {2753--2757},
  publisher    = {{ISCA}},
  year         = {2022},
  url          = {https://doi.org/10.21437/Interspeech.2022-227},
  doi          = {10.21437/Interspeech.2022-227},
}
```

## Contact

The repo will be updated. In the meantime, if you have any questions or problems, feel free to open an issue on GitHub or contact the authors directly.


================================================
FILE: __init__.py
================================================


================================================
FILE: audioset/README.md
================================================
# Experiments on Audioset
Audioset has around 2M segments. The total size of the dataset with wav files with `32khz` sampling rate is around 1.2 TB. In our setup, this results in a huge IO bottleneck that slows down the training process significantly.
Therefore, we encode the dataset to mp3, pack the mp3s into HDF5 files, and decode them on the fly. If you have enough CPU cores (10-16 data-loading workers), you should not notice any slowdown.

In `dataset.py`, we read the samples from the HDF5 files, decode the mp3s, apply waveform augmentations, and return the raw waveform to the model.
`AudioSetDataset` is the main class for reading from the HDF5 files.



## Preparing the dataset

### Downloading Audioset

We used the scripts provided by [PANNS](https://github.com/qiuqiangkong/audioset_tagging_cnn) to download the dataset.

### Converting to mp3

Once the datasets are downloaded, we convert all the files to mp3 using the script
`prepare_scripts/convert_to_mp3.py`:

```bash
python convert_to_mp3.py --source pann_download_folder --out mp3_folder
```

This significantly reduces the size of the dataset and overcomes the IO bottleneck in our setup. The trade-off is that more CPU is needed during training to decode the mp3s.
We use the [av](https://pypi.org/project/av/) library (see `decode_mp3` in `dataset.py`) to decode the mp3s in the data-loading workers, which is much faster than calling ffmpeg.
As a result, approximately 10 decoding threads should be enough to keep a 2080 Ti busy.

You can test how long it takes to load and decode one epoch on your system:
```bash
python ex_audioset.py test_loaders_train_speed
```

This step is not necessary if you have a more powerful setup. Note that `decode_mp3` also supports other ffmpeg codecs.
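
If you want a similar measurement for an arbitrary loader outside the sacred command, a minimal timing loop could look like the following. `time_loader` is a hypothetical helper, not the repo's `test_loaders_train_speed` implementation:

```python
import time


def time_loader(loader, max_batches=None):
    """Iterate over `loader` and return (num_batches, seconds, batches_per_second)."""
    start = time.perf_counter()
    n = 0
    for n, _batch in enumerate(loader, start=1):
        # Stop early once `max_batches` batches have been consumed.
        if max_batches is not None and n >= max_batches:
            break
    elapsed = time.perf_counter() - start
    rate = n / elapsed if elapsed > 0 else float("inf")
    return n, elapsed, rate
```

With a PyTorch `DataLoader`, pass the loader directly and vary its `num_workers` to see how many decoding workers your machine needs.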

### Packing to HDF5 files

Finally, you need to pack the mp3 files into HDF5 files using `create_h5pymp3_dataset.py`.
You just need to set the paths in the script to match your local paths. The script goes through the csv files, checks whether the corresponding mp3 file exists, and if so stores it in the HDF5 file.
The output of this step should be 3 files: `balanced_train_segments_mp3.hdf`, `eval_segments_mp3.hdf` and `unbalanced_train_segments_mp3.hdf`.
Make sure the paths of these files match the default config in `dataset.py`.
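
`create_h5pymp3_dataset.py` stores the 527-dimensional multi-hot targets bit-packed with `np.packbits`, and `AudioSetDataset.__getitem__` restores them with `np.unpackbits(..., count=classes_num)`. A small round trip with made-up labels illustrates the layout:

```python
import numpy as np

NUM_CLASSES = 527  # AudioSet label count used in dataset.py

# Two hypothetical clips with a few active labels each (multi-hot, boolean).
targets = np.zeros((2, NUM_CLASSES), dtype=bool)
targets[0, [0, 137, 526]] = True
targets[1, [5, 300]] = True

# Pack 8 labels per byte along the class axis: 527 bits -> ceil(527 / 8) = 66 bytes.
packed = np.packbits(targets, axis=-1)

# Unpack, dropping the single padding bit of the last byte via `count`, as dataset.py does.
restored = np.unpackbits(packed, axis=-1, count=NUM_CLASSES).astype(np.float32)

print(packed.shape, packed.dtype)   # (2, 66) uint8
print(restored.shape)               # (2, 527)
```

Packing shrinks the target column by a factor of 8 compared to one byte per label, and `count` is what keeps the unpacked vector at exactly 527 entries.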

================================================
FILE: audioset/dataset.py
================================================
import io
import os
import pathlib
import random

import av
import librosa
import torchaudio
from torch.utils.data import Dataset as TorchDataset, ConcatDataset, DistributedSampler, WeightedRandomSampler

import torch
from ba3l.ingredients.datasets import Dataset
from sacred.config import DynamicIngredient, CMD
from scipy.signal import convolve
import numpy as np
from helpers.audiodatasets import  PreprocessDataset
import h5py


LMODE = os.environ.get("LMODE", False)
#$TMPDIR
dataset = Dataset('audiodataset')


@dataset.config
def default_config():
    name = 'audioset'  # dataset name
    normalize = False  # normalize dataset
    subsample = False  # subsample squares from the dataset
    roll = True  # apply roll augmentation
    fold = 1
    base_dir = "audioset_hdf5s/"  # base directory of the dataset, change it or make a link
    if LMODE:
        base_dir = "/system/user/publicdata/CP/audioset/audioset_hdf5s/"

    balanced_train_hdf5 = base_dir + "mp3/balanced_train_segments_mp3.hdf"
    eval_hdf5 = base_dir + "mp3/eval_segments_mp3.hdf"
    unbalanced_train_hdf5 = base_dir + "mp3/unbalanced_train_segments_mp3.hdf"
    if LMODE:
        balanced_train_hdf5 = balanced_train_hdf5.replace(base_dir, os.environ.get("TMPDIR", base_dir)+"/")
        unbalanced_train_hdf5 = unbalanced_train_hdf5.replace(base_dir, os.environ.get("TMPDIR", base_dir)+"/")
        eval_hdf5 = eval_hdf5.replace(base_dir, os.environ.get("TMPDIR", base_dir)+"/")
    ir_path = base_dir + "irs/"
    num_of_classes = 527

if LMODE:
    @dataset.config
    def LMODE_default_config():
        cache_root_path = "/system/user/publicdata/CP/DCASE/cached_datasets/"





def decode_mp3(mp3_arr):
    """
    Decodes an array of uint8 representing an mp3 file.
    :rtype: np.array
    """
    container = av.open(io.BytesIO(mp3_arr.tobytes()))
    stream = next(s for s in container.streams if s.type == 'audio')
    # print(stream)
    a = []
    for i, packet in enumerate(container.demux(stream)):
        for frame in packet.decode():
            a.append(frame.to_ndarray().reshape(-1))
    waveform = np.concatenate(a)
    if waveform.dtype != 'float32':
        raise RuntimeError("Unexpected wave type")
    return waveform


def pad_or_truncate(x, audio_length):
    """Pad all audio to specific length."""
    if len(x) <= audio_length:
        return np.concatenate((x, np.zeros(audio_length - len(x), dtype=np.float32)), axis=0)
    else:
        return x[0: audio_length]


irs_arr = None


@dataset.command
def get_ir_sample(ir_path, _run, ir_augment, cut_irs_offset=None):
    if not ir_augment:
        return
    global irs_arr
    if irs_arr is None:
        all_paths = [path for path in pathlib.Path(os.path.expanduser(ir_path)).rglob('*.wav')]
        all_paths = sorted(all_paths)
        if cut_irs_offset is not None:
            all_paths = all_paths[cut_irs_offset:cut_irs_offset + 10]
        all_paths_name = [str(p).rsplit("/", 1)[-1] for p in all_paths]
        print("will use these IRs:")
        for i in range(len(all_paths_name)):
            print(i, ": ", all_paths_name[i])
        _run.info["ir_devices"] = all_paths_name
        irs_arr = [librosa.load(p, sr=32000)[0] for p in all_paths]
    return irs_arr[int(np.random.randint(0, len(irs_arr)))]


@dataset.command
def pydub_augment(waveform, gain_augment=7, ir_augment=0):
    if ir_augment and torch.rand(1) < ir_augment:
        ir = get_ir_sample()
        waveform = convolve(waveform, ir, 'full')
    if gain_augment:
        gain = torch.randint(gain_augment * 2, (1,)).item() - gain_augment
        amp = 10 ** (gain / 20)
        waveform = waveform * amp
    return waveform
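

# Worked example for pydub_augment (illustrative comment, not executed):
# with gain_augment=7, torch.randint(14, (1,)) draws an integer in 0..13, so the
# gain lies in -7..6 dB (the upper bound is exclusive). The dB value is turned
# into a linear amplitude factor with amp = 10 ** (gain / 20), e.g.
# gain = 6 -> amp ~= 1.995 and gain = -7 -> amp ~= 0.447.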


class MixupDataset(TorchDataset):
    """ Mixing Up wave forms
    """

    def __init__(self, dataset, beta=2, rate=0.5):
        self.beta = beta
        self.rate = rate
        self.dataset = dataset
        print(f"Mixing up waveforms from dataset of len {len(dataset)}")

    def __getitem__(self, index):
        if torch.rand(1) < self.rate:
            x1, f1, y1 = self.dataset[index]
            idx2 = torch.randint(len(self.dataset), (1,)).item()
            x2, f2, y2 = self.dataset[idx2]
            l = np.random.beta(self.beta, self.beta)
            l = max(l, 1. - l)
            x1 = x1-x1.mean()
            x2 = x2-x2.mean()
            x = (x1 * l + x2 * (1. - l))
            x = x - x.mean()
            return x, f1, (y1 * l + y2 * (1. - l))
        return self.dataset[index]

    def __len__(self):
        return len(self.dataset)
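

# Worked example for MixupDataset (illustrative comment, not executed):
# with probability `rate`, l is drawn from Beta(beta, beta) and set to
# max(l, 1 - l), so the clip at `index` always gets the larger weight, l in [0.5, 1].
# With l = 0.7, both waveforms are mean-centred, mixed as 0.7 * x1 + 0.3 * x2 and
# centred again; the target becomes 0.7 * y1 + 0.3 * y2, so a label present only
# in the second clip gets 0.3 and a label present in both clips stays at 1.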


class AudioSetDataset(TorchDataset):
    def __init__(self, hdf5_file, sample_rate=32000, classes_num=527, clip_length=10, augment=False, in_mem=False):
        """
        Reads the mp3 bytes from an HDF5 file, decodes them using av, and returns a fixed-length audio waveform.
        """
        self.sample_rate = sample_rate
        self.hdf5_file = hdf5_file
        if in_mem:
            print("\nPreloading in memory\n")
            with open(hdf5_file, 'rb') as f:
                self.hdf5_file = io.BytesIO(f.read())
        with h5py.File(hdf5_file, 'r') as f:
            self.length = len(f['audio_name'])
            print(f"Dataset from {hdf5_file} with length {self.length}.")
        self.dataset_file = None  # lazy init
        self.clip_length = clip_length * sample_rate
        self.classes_num = classes_num
        self.augment = augment
        if augment:
            print(f"Will augment data from {hdf5_file}")

    def open_hdf5(self):
        self.dataset_file = h5py.File(self.hdf5_file, 'r')

    def __len__(self):
        return self.length

    def __del__(self):
        if self.dataset_file is not None:
            self.dataset_file.close()
            self.dataset_file = None

    def __getitem__(self, index):
        """Load waveform and target of an audio clip.

        Args:
          index: int, position of the clip in the HDF5 file.
        Returns:
          tuple (waveform, audio_name, target):
            waveform: (1, clip_samples) array,
            audio_name: str,
            target: (classes_num,) float32 multi-hot array.
        """
        if self.dataset_file is None:
            self.open_hdf5()

        audio_name = self.dataset_file['audio_name'][index].decode()
        waveform = decode_mp3(self.dataset_file['mp3'][index])
        if self.augment:
            waveform = pydub_augment(waveform)
        waveform = pad_or_truncate(waveform, self.clip_length)
        waveform = self.resample(waveform)
        target = self.dataset_file['target'][index]
        target = np.unpackbits(target, axis=-1,
                               count=self.classes_num).astype(np.float32)
        return waveform.reshape(1, -1), audio_name, target

    def resample(self, waveform):
        """Resample.
        Args:
          waveform: (clip_samples,)
        Returns:
          (resampled_clip_samples,)
        """
        if self.sample_rate == 32000:
            return waveform
        elif self.sample_rate == 16000:
            return waveform[0:: 2]
        elif self.sample_rate == 8000:
            return waveform[0:: 4]
        else:
            raise Exception('Incorrect sample rate!')


@dataset.command
def get_base_training_set(balanced_train_hdf5):
    ds = AudioSetDataset(balanced_train_hdf5, augment=True)
    return ds


@dataset.command
def get_unbalanced_training_set(unbalanced_train_hdf5):
    ds = AudioSetDataset(unbalanced_train_hdf5, augment=True)
    return ds



@dataset.command
def get_norms_dataset(unbalanced_train_hdf5, balanced_train_hdf5):
    ds = ConcatDataset(
        [AudioSetDataset(balanced_train_hdf5, augment=False), AudioSetDataset(unbalanced_train_hdf5, augment=False)])
    return ds


@dataset.command
def get_base_full_training_set():
    sets = [get_base_training_set(), get_unbalanced_training_set()]
    ds = ConcatDataset(sets)
    return ds


@dataset.command
def preload_mp3(balanced_train_hdf5, unbalanced_train_hdf5, num_of_classes):
    for hdf5_file in [balanced_train_hdf5, unbalanced_train_hdf5]:
        print(f"\n \n will now preload {hdf5_file} \n\n ")
        with h5py.File(hdf5_file, 'r') as dataset_file:
            target = dataset_file['mp3'][:]
            print(len(target))
            print(f"\n \n done with  {hdf5_file} \n\n ")
    return target[1000]


@dataset.command
def get_ft_cls_balanced_sample_weights(balanced_train_hdf5, unbalanced_train_hdf5, num_of_classes,
                                       sample_weight_offset=100, sample_weight_sum=True):
    """
    :return: float tensor of shape len(full_training_set) representing the weights of each sample.
    """
    # the order of balanced_train_hdf5,unbalanced_train_hdf5 is important.
    # should match get_full_training_set
    all_y = []
    for hdf5_file in [balanced_train_hdf5, unbalanced_train_hdf5]:
        with h5py.File(hdf5_file, 'r') as dataset_file:
            target = dataset_file['target']
            target = np.unpackbits(target, axis=-1,
                                   count=num_of_classes)
            all_y.append(target)
    all_y = np.concatenate(all_y, axis=0)
    all_y = torch.as_tensor(all_y)
    per_class = all_y.long().sum(0).float().reshape(1, -1)  # frequencies per class

    per_class = sample_weight_offset + per_class  # offset low freq classes
    if sample_weight_offset > 0:
        print(f"Warning: sample_weight_offset={sample_weight_offset} minnow={per_class.min()}")
    per_class_weights = 1000. / per_class
    all_weight = all_y * per_class_weights
    # print(all_weight.shape)
    # print(all_weight[1510])
    if sample_weight_sum:
        print("\nsample_weight_sum\n")
        all_weight = all_weight.sum(dim=1)
    else:
        all_weight, _ = all_weight.max(dim=1)
    # print(all_weight.shape)
    # print(all_weight[1510])
    return all_weight
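

# Worked example for the sample weights above (illustrative comment, hypothetical counts):
# with sample_weight_offset=100, a class with 900 positive clips gets weight
# 1000 / (100 + 900) = 1.0 and a class with 100 clips gets 1000 / 200 = 5.0.
# A clip labelled with both classes then has weight 1.0 + 5.0 = 6.0 when
# sample_weight_sum=True, or max(1.0, 5.0) = 5.0 otherwise, so clips of rare
# classes are drawn more often by the WeightedRandomSampler.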


@dataset.command
def get_ft_weighted_sampler(samples_weights=CMD(".get_ft_cls_balanced_sample_weights"),
                            epoch_len=100000, sampler_replace=False):
    num_nodes = int(os.environ.get('num_nodes', 1))
    ddp = int(os.environ.get('DDP', 1))
    num_nodes = max(ddp, num_nodes)
    print("num_nodes= ", num_nodes)
    rank = int(os.environ.get('NODE_RANK', 0))
    return DistributedSamplerWrapper(sampler=WeightedRandomSampler(samples_weights,
                                                                   num_samples=epoch_len, replacement=sampler_replace),
                                     dataset=range(epoch_len),
                                     num_replicas=num_nodes,
                                     rank=rank,
                                     )


@dataset.command
def get_base_test_set(eval_hdf5):
    ds = AudioSetDataset(eval_hdf5)
    return ds


@dataset.command(prefix='roll_conf')
def get_roll_func(axis=1, shift=None, shift_range=50):
    print("rolling...")

    def roll_func(b):
        x, i, y = b
        x = torch.as_tensor(x)
        sf = shift
        if shift is None:
            # uniform integer shift in [-shift_range, shift_range], both ends inclusive
            sf = int(np.random.randint(-shift_range, shift_range + 1))

        return x.roll(sf, axis), i, y

    return roll_func


@dataset.command
def get_training_set(normalize, roll, wavmix=False):
    ds = get_base_training_set()
    get_ir_sample()
    if normalize:
        print("normalized train!")
        fill_norms()
        ds = PreprocessDataset(ds, norm_func)
    if roll:
        ds = PreprocessDataset(ds, get_roll_func())
    if wavmix:
        ds = MixupDataset(ds)

    return ds


@dataset.command
def get_full_training_set(normalize, roll, wavmix=False):
    ds = get_base_full_training_set()
    get_ir_sample()
    if normalize:
        print("normalized train!")
        fill_norms()
        ds = PreprocessDataset(ds, norm_func)
    if roll:
        ds = PreprocessDataset(ds, get_roll_func())
    if wavmix:
        ds = MixupDataset(ds)
    return ds



@dataset.command
def get_test_set(normalize):
    ds = get_base_test_set()
    if normalize:
        print("normalized test!")
        fill_norms()
        ds = PreprocessDataset(ds, norm_func)
    return ds


@dataset.command
def print_conf(_config):
    print("Config of ", dataset.path, id(dataset))
    print(_config)
    print()


class DistributedSamplerWrapper(DistributedSampler):
    def __init__(
            self, sampler, dataset,
            num_replicas=None,
            rank=None,
            shuffle: bool = True):
        super(DistributedSamplerWrapper, self).__init__(
            dataset, num_replicas, rank, shuffle)
        # source: @awaelchli https://github.com/PyTorchLightning/pytorch-lightning/issues/3238
        self.sampler = sampler

    def __iter__(self):
        if self.sampler.generator is None:
            self.sampler.generator = torch.Generator()
        self.sampler.generator.manual_seed(self.seed + self.epoch)
        indices = list(self.sampler)
        if self.epoch == 0:
            print(f"\n DistributedSamplerWrapper :  {indices[:10]} \n\n")
        indices = indices[self.rank:self.total_size:self.num_replicas]
        return iter(indices)


if __name__ == "__main__":
    from sacred import Experiment

    ex = Experiment("test_dataset", ingredients=[dataset])


    @ex.automain
    def default_command():
        ex.current_run.get_command_function("print_config")()
        get_base_training_set()
        ds = get_test_set()
        print(ds[0])
        ds = get_training_set()
        print(ds[0])
        print("get_base_training_set", len(get_base_training_set()))
        print("get_base_test_set", len(get_base_test_set()))
        print("get_training_set", len(get_training_set()))
        print("get_test_set", len(get_test_set()))


================================================
FILE: audioset/prepare_scripts/convert_to_mp3.py
================================================
import argparse
import multiprocessing
import glob
import os

# Replace this with the path of the dataset downloaded using the PANNs scripts
source_path = "/share/cp/datasets/full_audioset/audioset201906/audios/"

# Replace with the output directory
out_path = "./mp3_audio/"

all_num = 0


def process_folder(fol="balanced_train_segments"):
    print("now working on ", fol)
    os.makedirs(out_path + fol, exist_ok=True)
    all_files = list(glob.glob(source_path + fol + "/*.wav"))
    print(f"it has {len(all_files)}")
    global all_num
    all_num = len(all_files)
    cmds = [(i, file, out_path + fol + "/" + os.path.basename(file)[:-3]) for i, file in enumerate(all_files)]
    print(cmds[0])
    with multiprocessing.Pool(processes=20) as pool:
        pool.starmap(process_one, cmds)


def process_one(i, f1, f2):
    if i % 100 == 0:
        print(f"{i}/{all_num} \t", f1)
    os.system(f"ffmpeg  -hide_banner -nostats -loglevel error -n -i {f1} -codec:a mp3 -ar 32000 {f2}mp3")


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--source', type=str, required=False, default=None,
                        help='Path of folder containing the wave files, expected format to be as downloaded '
                             'using the PANNs script, containing balanced_train_segments, eval_segments, '
                             'unbalanced_train_segments folders.')
    parser.add_argument('--out', type=str, required=False, default=None,
                        help='Directory to save out the converted mp3s.')

    args = parser.parse_args()

    source_path = args.source or source_path
    out_path = args.out or out_path
    folders = ['balanced_train_segments', 'eval_segments'] + ["unbalanced_train_segments/" + x for x in sorted(
        os.listdir(source_path + "unbalanced_train_segments/"))]

    print("I will work on these folders:")
    print(folders)

    for fol in folders:
        process_folder(fol)


================================================
FILE: audioset/prepare_scripts/create_h5pymp3_dataset.py
================================================
# %%
import h5py
import pandas as pd
import numpy as np
import csv
import os


# %%
base_dir = "../../audioset_hdf5s/"
balanced_csv= base_dir+ "metadata/balanced_train_segments.csv"
eval_csv= base_dir+ "metadata/eval_segments.csv"
mp3_path = "../../mp3_audio/"


# %%

def read_metadata(csv_path, classes_num, id_to_ix):
    """Read metadata of AudioSet from a csv file.
    source: https://github.com/qiuqiangkong/audioset_tagging_cnn/blob/d2f4b8c18eab44737fcc0de1248ae21eb43f6aa4/utils/utilities.py#L59
    Args:
      csv_path: str
    Returns:
      meta_dict: {'audio_name': (audios_num,), 'target': (audios_num, classes_num)}
    """

    with open(csv_path, 'r') as fr:
        lines = fr.readlines()
        lines = lines[3:]   # Remove heads

    audios_num = len(lines)
    targets = np.zeros((audios_num, classes_num), dtype=bool)
    audio_names = []
 
    for n, line in enumerate(lines):
        items = line.split(', ')
        """items: ['--4gqARaEJE', '0.000', '10.000', '"/m/068hy,/m/07q6cd_,/m/0bt9lr,/m/0jbk"\n']"""

        audio_name = 'Y{}.mp3'.format(items[0])   # Audios are started with an extra 'Y' when downloading
        label_ids = items[3].split('"')[1].split(',')

        audio_names.append(audio_name)

        # Target
        for id in label_ids:
            ix = id_to_ix[id]
            targets[n, ix] = 1

    meta_dict = {'audio_name': np.array(audio_names), 'target': targets}
    return meta_dict

# Load label
with open(base_dir+'metadata/class_labels_indices.csv', 'r') as f:
    reader = csv.reader(f, delimiter=',')
    lines = list(reader)

labels = []
ids = []    # Each label has a unique id such as "/m/068hy"
for i1 in range(1, len(lines)):
    id = lines[i1][1]
    label = lines[i1][2]
    ids.append(id)
    labels.append(label)

classes_num = len(labels)

lb_to_ix = {label : i for i, label in enumerate(labels)}
ix_to_lb = {i : label for i, label in enumerate(labels)}

id_to_ix = {id : i for i, id in enumerate(ids)}
ix_to_id = {i : id for i, id in enumerate(ids)}

# %%

def check_available(balanced_csv,balanced_audio_path,prefix=None):
    meta_csv = read_metadata(balanced_csv,classes_num,id_to_ix)
    audios_num = len(meta_csv['audio_name'])
    found=0
    notfound=0
    available_files=[]
    available_targets=[]
    if prefix is None:
        prefix = os.path.basename(balanced_csv)[:-4]
    for n in range(audios_num):
        audio_path =  meta_csv['audio_name'][n]
        #print(balanced_audio_path + f"{prefix}/{audio_path}")
        if os.path.isfile(balanced_audio_path + f"{prefix}/{audio_path}" ):
            found+=1
            available_files.append(meta_csv['audio_name'][n])
            available_targets.append(meta_csv['target'][n])
        else:
            notfound+=1
    print(f"Found {found} . not found {notfound}")
    return available_files,available_targets
# %%

# %%

# %%



for read_file,prefix in [(balanced_csv,"balanced_train_segments/"), (eval_csv,"eval_segments/"),]:
    print("now working on ",read_file,prefix)
    #files, y = torch.load(read_file+".pth")
    files, y = check_available(read_file, mp3_path)
    y = np.packbits(y, axis=-1)
    packed_len = y.shape[1]
    print(files[0], "classes: ",packed_len, y.dtype)
    available_size = len(files)
    f = files[0][:-3]+"mp3"
    a = np.fromfile(mp3_path+prefix + "/"+f, dtype='uint8')

    dt = h5py.vlen_dtype(np.dtype('uint8'))
    save_file = prefix.split("/")[0]
    with h5py.File(base_dir+ "mp3/" + save_file+"_mp3.hdf", 'w') as hf:
        audio_name = hf.create_dataset('audio_name', shape=((available_size,)), dtype='S20')
        waveform = hf.create_dataset('mp3', shape=((available_size,)), dtype=dt)
        target = hf.create_dataset('target', shape=((available_size, packed_len)), dtype=y.dtype)
        for i,file in enumerate(files):
            if i%1000==0:
                print(f"{i}/{available_size}")
            f = file[:-3] + "mp3"
            a = np.fromfile(mp3_path + prefix  + f, dtype='uint8')
            audio_name[i]=f
            waveform[i] = a
            target[i] = y[i]

    print(a.shape)
    print("Done!" , prefix)


# %%
print("working on unbalanced...")



all_x,all_y = None, None
for idx in  range(41):
    print("working on ",idx)
    tmp_csv = base_dir+ f"metadata/unbalanced_partial_csvs/unbalanced_train_segments_part{idx:02}.csv"
    prefix = f"unbalanced_train_segments/unbalanced_train_segments_part{idx:02}"
    x,y = check_available(tmp_csv,mp3_path,prefix=prefix)
    x = np.array([f"{idx:02}/"+one for one in x])
    y=np.packbits(y, axis=-1)
    print("x,y",x.shape, y.shape)
    if all_x is None:
        all_x = x
        all_y = y
    else:
        all_x = np.concatenate((all_x,x))
        all_y = np.concatenate((all_y,y))
    print(f"done {idx}! all x,y",all_x.shape, all_y.shape)



print("now working on packing  unbalanced")
prefix = "unbalanced_train_segments/unbalanced_train_segments_part"
files = all_x
y = all_y
packed_len = y.shape[1]
print(files[0], "classes: ",packed_len, y.dtype)
available_size = len(files)
f = files[0][:-3]+"mp3"
a = np.fromfile(mp3_path+prefix + f, dtype='uint8')

dt = h5py.vlen_dtype(np.dtype('uint8'))
save_file = prefix.split("/")[0]
with h5py.File(base_dir+ "mp3/" + save_file+"_mp3.hdf", 'w') as hf:
    audio_name = hf.create_dataset('audio_name', shape=((available_size,)), dtype='S20')
    waveform = hf.create_dataset('mp3', shape=((available_size,)), dtype=dt)
    target = hf.create_dataset('target', shape=((available_size, packed_len)), dtype=y.dtype)
    for i,file in enumerate(files):
        if i%1000==0:
            print(f"{i}/{available_size}")
        f = file[:-3] + "mp3"
        a = np.fromfile(mp3_path + prefix  + f, dtype='uint8')
        audio_name[i]=f
        waveform[i] = a
        target[i] = y[i]

print(a.shape)
print("Done!" , prefix)




================================================
FILE: ba3l/__init__.py
================================================
"""Package info"""

__version__ = "0.0.2"
__author__ = "Koutini et al."
__license__ = "Apache-2.0"
__copyright__ = "Copyright (c) 2019-, %s." % __author__
__homepage__ = "https://github.com//"
__docs__ = (
    "The researcher friendly pytorch environment. Ba3l= sacred+pytorch-lightning."
)
__author_email__ = "first.last@jku.at"


================================================
FILE: ba3l/experiment.py
================================================
import inspect
from importlib import import_module

from ba3l.ingredients.datasets import Datasets
from ba3l.ingredients.models import Models, Model
#from ba3l.trainer import Trainer
from ba3l.util.sacred_logger import SacredLogger
from sacred import Experiment as Sacred_Experiment, Ingredient
from typing import Sequence, Optional, List

from sacred.commandline_options import CLIOption
from sacred.config import CMD
from sacred.host_info import HostInfoGetter
from sacred.utils import PathType, optional_kwargs_decorator
from pytorch_lightning import loggers as pl_loggers
from ba3l.util.functions import get_default_kwargs_dict


def ingredients_recursive_apply(ing, fn):
    fn(ing)
    for kid in ing.ingredients:
        ingredients_recursive_apply(kid, fn)

def config_recursive_apply(conf, fn):
    for k,v in conf.items():
        if isinstance(v, dict):
            config_recursive_apply(v,fn)
        else:
            fn(k,v)


def get_loggers(use_tensorboard_logger=False, use_sacred_logger=False):
    loggers = []
    if use_sacred_logger:
        loggers.append( SacredLogger(expr))
    if use_tensorboard_logger:
        loggers.append(pl_loggers.TensorBoardLogger(sacred_logger.name))
    
    return loggers



class Experiment(Sacred_Experiment):
    """
    Main Ba3l Experiment class overrides sacred experiments.
    """

    def __init__(
        self,
        name: Optional[str] = None,
        ingredients: Sequence[Ingredient] = (),
        datasets: Optional[Ingredient] = None,
        models: Optional[Ingredient] = None,
        interactive: bool = False,
        base_dir: Optional[PathType] = None,
        additional_host_info: Optional[List[HostInfoGetter]] = None,
        additional_cli_options: Optional[Sequence[CLIOption]] = None,
        save_git_info: bool = True,
    ):
        """
        Create a new experiment with the given name and optional ingredients. (from Sacred)


        Parameters
        ----------
        name
            Optional name of this experiment, defaults to the filename.
            (Required in interactive mode)

        ingredients : list[sacred.Ingredient], optional
            A list of ingredients to be used with this experiment.

        interactive
            If set to True will allow the experiment to be run in interactive
            mode (e.g. IPython or Jupyter notebooks).
            However, this mode is discouraged since it won't allow storing the
            source-code or reliable reproduction of the runs.

        base_dir
            Optional full path to the base directory of this experiment. This
            will set the scope for automatic source file discovery.

        additional_host_info
            Optional dictionary containing as keys the names of the pieces of
            host info you want to collect, and as
            values the functions collecting those pieces of information.

        save_git_info:
            Optionally save the git commit hash and the git state
            (clean or dirty) for all source files. This requires the GitPython
            package.
        """
        if models is None:
            models = Models.get_instance()
        self.models = models
        if datasets is None:
            datasets = Datasets.get_instance()
        self.datasets = datasets
        if ingredients is None:
            ingredients = []
        ingredients = list(ingredients) + [models, datasets]
        caller_globals = inspect.stack()[1][0].f_globals
        self.last_default_configuration_position = 0
        super().__init__(
            name=name,
            ingredients=ingredients,
            interactive=interactive,
            base_dir=base_dir,
            additional_host_info=additional_host_info,
            additional_cli_options=additional_cli_options,
            save_git_info=save_git_info,
            caller_globals=caller_globals
        )


    def get_run_identifier(self):
        return str(self.current_run.db_identifier) \
            + "_" + str(self.current_run._id)


    def get_dataloaders(self, filter={}):
        results = {}
        for ds in self.datasets.get_datasets(filter):
            results[ds.name] = ds.get_iterator()
        if len(results) == 1:
            for k, v in results.items():
                return v
        return results

    def get_train_dataloaders(self):
        return self.get_dataloaders(dict(train=True))

    def get_val_dataloaders(self):
        return self.get_dataloaders(dict(validate=True))

    def _create_run(
        self,
        command_name=None,
        config_updates=None,
        named_configs=(),
        info=None,
        meta_info=None,
        options=None,
        dry_run=False,
    ):
        if self.current_run is not None:
            # @todo replace with logger
            print("Warning: multiple runs are not yet supported")


        run = super()._create_run(
            command_name,
            config_updates,
            named_configs,
            info,
            meta_info,
            options,
            dry_run=False,
        )
        # self.current_run=run
        # def update_current_run(ing):
        #     ing.current_run = run
        #
        # ingredients_recursive_apply(self, update_current_run)

        return run

    @optional_kwargs_decorator
    def command(
            self, function=None, prefix=None, unobserved=False, add_default_args_config=True, static_args={},
            **extra_args
    ):
        """
        Decorator to define a new Command.

        a command is a function whose parameters are filled automatically by sacred.

        The command can be given a prefix, to restrict its configuration space
        to a subtree. (see ``capture`` for more information)

        A command can be made unobserved (i.e. ignoring all observers) by
        passing the unobserved=True keyword argument.
        :param add_default_args_config: whether to add the default arguments of the function to the config automatically.
        :param function: the function to register as a command
        :param prefix: sacred configuration prefix
        :param unobserved: sacred unobserved
        :param static_args: static args to be passed to the function; these args need not be serializable and
         are not stored in the config
        :param extra_args: explicit arguments to be added to the config; you can use these to override the function's
        default values. For example, if a value is wrapped with CMD, the parameter will be filled by executing the
        command specified by the CMD string value. CMD strings have a special context.
        :return:


        """
        add_default_args_config = (not unobserved) and add_default_args_config
        if add_default_args_config:
            self.add_default_args_config(function, prefix, extra_args, static_args=static_args)
        captured_f = self.capture(function, prefix=prefix, static_args=static_args)
        captured_f.unobserved = unobserved
        self.commands[function.__name__] = captured_f
        return captured_f


    def add_default_args_config(self, function, prefix, extra_args={}, static_args={}):
        """
        Adds the default parameters of a function to the ingredient config at the lowest priority!
        The default args config is meant to remove the need to declare all the configurations manually.
        :param function: the function
        """
        # @todo get the doc of the params as well
        config_candidate = {**get_default_kwargs_dict(function), **extra_args}
        # remove "static_args" from config
        for k in static_args:
            config_candidate.pop(k, None)
        # respect the prefix for the added default parameters
        if prefix is not None:
            for pr in prefix.split('.')[::-1]:
                config_candidate={pr: config_candidate}

        self.configurations.insert(self.last_default_configuration_position, self._create_config_dict(config_candidate, None))
        self.last_default_configuration_position += 1
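
    # Worked example of the prefix nesting above (illustrative comment, not executed),
    # assuming get_default_kwargs_dict returns the keyword defaults of `function`:
    # for a hypothetical `def f(axis=1, shift_range=50)` registered with
    # prefix="roll_conf", the candidate config becomes
    # {"roll_conf": {"axis": 1, "shift_range": 50}};
    # a dotted prefix "a.b" nests one level per component: {"a": {"b": {...}}}.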


================================================
FILE: ba3l/ingredients/__init__.py
================================================


================================================
FILE: ba3l/ingredients/datasets.py
================================================
import inspect
import os
from functools import partial

from ba3l.util.functions import get_default_kwargs_dict
from sacred.config import CMD
from .ingredient import Ingredient

from typing import Sequence, Optional, List

from sacred.utils import PathType, optional_kwargs_decorator
from munch import DefaultFactoryMunch, Munch

def raise_(ex):
    raise ex


class Dataset(Ingredient):
    """
    The class that annotates a Dataset of a Ba3l experiment.
    a Dataset can be


    """

    DATASET_STRING_PREFIX = "get_dataset"
    ITER_STRING_PREFIX = "get_iterator"

    def __init__(
        self,
        name: str,
        ingredients: Sequence[Ingredient] = (),
        interactive: bool = False,
        base_dir: Optional[PathType] = None,
        save_git_info: bool = True,
    ):
        """
        Create a new dataset ingredient with the given name and optional ingredients.

        Parameters
        ----------
        name
            Dotted path of this dataset ingredient (e.g. ``datasets.train``);
            its last component is stored as ``self.name``. Defaults to
            ``"dataset"`` if ``None``.

        ingredients : list[sacred.Ingredient], optional
            A list of ingredients to be used by this dataset.

        interactive
            If set to True, allows the experiment to be run in interactive
            mode (e.g. IPython or Jupyter notebooks).
            However, this mode is discouraged since it doesn't allow storing the
            source code or reliably reproducing the runs.

        base_dir
            Optional full path to the base directory of this ingredient. This
            sets the scope for automatic source file discovery.

        save_git_info:
            Optionally save the git commit hash and the git state
            (clean or dirty) for all source files. This requires the GitPython
            package.
        """

        caller_globals = inspect.stack()[1][0].f_globals
        if name is None:
            name = "dataset"
        self.name = name.rsplit(".", 1)[-1]
        super().__init__(
            path=name,
            ingredients=ingredients,
            interactive=interactive,
            base_dir=base_dir,
            _caller_globals=caller_globals,
            save_git_info=save_git_info,
        )

        self.get_dataset_command = None
        self.get_dataset_iterator_command = None
        self.current_run = None
        self.get_dataset = lambda: raise_(
            NotImplementedError(
                "Use dataset.dataset_name.dataset to annotate the "
                "get_dataset function!"
            )
        )
        self.get_iter = lambda: raise_(
            NotImplementedError(
                "Use dataset.dataset_name.iter to annotate the get_iter function!"
            )
        )

    @optional_kwargs_decorator
    def dataset(
        self, function=None, prefix=None, unobserved=False, static_args={}, **extra_args
    ):
        """
        Decorator to define a new Dataset.

        The decorated function is used to get an instance of the dataset; it is
        registered as this dataset's ``get_dataset`` command.

        Datasets are sacred commands.

        The command can be given a prefix, to restrict its configuration space
        to a subtree. (see ``capture`` for more information)

        A command can be made unobserved (i.e. ignoring all observers) by
        passing the unobserved=True keyword argument.
        :param function: the function that returns the dataset object
        :param prefix: sacred configuration prefix
        :param unobserved: sacred unobserved
        :param static_args: static arguments passed to the function; they need not be
         serializable and are not stored in the config
        :param extra_args: explicit arguments to add to the config; use these to override the
         function's default values. For example, if a value is wrapped with CMD, the parameter
         is filled by executing the command named by the CMD string. CMD strings are resolved
         in a special context.
        :return: the captured function


        """
        self.add_default_args_config(function, prefix, extra_args, static_args=static_args)
        captured_f = self.capture(function, prefix=prefix, static_args=static_args)
        captured_f.unobserved = unobserved
        self.commands[Dataset.DATASET_STRING_PREFIX] = captured_f
        self.get_dataset = captured_f
        self.add_config(dataset=CMD("get_dataset"))
        return captured_f

    @optional_kwargs_decorator
    def iter(
        self, function=None, prefix=None, unobserved=False, static_args={}, **extra_args
    ):
        """
        Decorator to define a new Iterator.

        The decorated function is used to get an iterator over the dataset; it is
        registered as this dataset's ``get_iterator`` command.

        Iterators are sacred commands.

        The command can be given a prefix, to restrict its configuration space
        to a subtree. (see ``capture`` for more information)

        A command can be made unobserved (i.e. ignoring all observers) by
        passing the unobserved=True keyword argument.
        :param function: the function that returns the iterator
        :param prefix: sacred configuration prefix
        :param unobserved: sacred unobserved
        :param static_args: static arguments passed to the function; they need not be
         serializable and are not stored in the config
        :param extra_args: explicit arguments to add to the config; use these to override the
         function's default values. For example, if a value is wrapped with CMD, the parameter
         is filled by executing the command named by the CMD string. CMD strings are resolved
         in a special context.
        :return: the captured function
        """
        self.add_default_args_config(function, prefix, extra_args, static_args=static_args)

        captured_f = self.capture(function, prefix=prefix, static_args=static_args)
        captured_f.unobserved = unobserved
        self.commands[Dataset.ITER_STRING_PREFIX] = captured_f
        self.get_iter = captured_f
        return captured_f

    # def get_dataset(self):
    #     assert self.current_run is not None, "Can only be called during a run."
    #     return self.commands[Datasets.DATASET_STRING_PREFIX + name]()
    #     # return self.current_run.get_command_function(
    #     #     self.path + "." + Datasets.DATASET_STRING_PREFIX + name)()
    #     #

    def __getattr__(self, k):
        if k == "iterator":
            return self.__getattribute__("iter")
        if k == "get_iterator":
            return self.__getattribute__("get_iter")
        return super().__getattribute__(k)
        # @todo maybe run commands from here after running


class Datasets(Ingredient, Munch):
    """
    The class that encapsulates all the datasets in an experiment
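
    :meth:`get_datasets` keeps the datasets whose run configuration matches every
    ``(key, value)`` pair in ``config_conditions``, where each key is looked up
    relative to the dataset's path. A plain-Python sketch of that filtering, using a
    nested dict and the hypothetical helper ``get_path_value`` as stand-ins for the
    run config and ``current_run.get_config_path_value``:

```python
def get_path_value(config, path):
    # Hypothetical stand-in for current_run.get_config_path_value:
    # walk a nested dict along a dotted path such as "datasets.test.split".
    node = config
    for part in path.split('.'):
        node = node[part]
    return node


def filter_datasets(config, dataset_paths, conditions):
    # Keep the dataset paths whose config matches every (key, value) condition.
    return [
        path
        for path in dataset_paths
        if all(get_path_value(config, path + '.' + key) == value
               for key, value in conditions.items())
    ]


# Illustrative config; the dataset names and keys are made up.
config = {
    'datasets': {
        'training': {'split': 'train', 'batch_size': 12},
        'test': {'split': 'eval', 'batch_size': 20},
    }
}
print(filter_datasets(config, ['datasets.training', 'datasets.test'], {'split': 'eval'}))
# -> ['datasets.test']
```

    With empty ``config_conditions``, every dataset matches.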


    """

    __instance = None

    @classmethod
    def get_instance(cls):
        if Datasets.__instance is None:
            Datasets.__instance = Datasets()
        return Datasets.__instance

    def __init__(
        self,
        name: Optional[str] = None,
        ingredients: Sequence[Ingredient] = (),
        interactive: bool = False,
        base_dir: Optional[PathType] = None,
        save_git_info: bool = True,
    ):
        """
        Create the container ingredient that holds all datasets of an experiment.

        Parameters
        ----------
        name
            Optional name (path) of this ingredient; defaults to ``"datasets"``.

        ingredients : list[sacred.Ingredient], optional
            A list of ingredients to be used by this ingredient.

        interactive
            If set to True, allows the experiment to be run in interactive
            mode (e.g. IPython or Jupyter notebooks).
            However, this mode is discouraged since it doesn't allow storing the
            source code or reliably reproducing the runs.

        base_dir
            Optional full path to the base directory of this ingredient. This
            sets the scope for automatic source file discovery.

        save_git_info:
            Optionally save the git commit hash and the git state
            (clean or dirty) for all source files. This requires the GitPython
            package.
        """

        caller_globals = inspect.stack()[1][0].f_globals
        if name is None:
            name = "datasets"

        Ingredient.__init__(
            self,
            path=name,
            ingredients=ingredients,
            interactive=interactive,
            base_dir=base_dir,
            _caller_globals=caller_globals,
            save_git_info=save_git_info,
        )

        self.get_datasets_list_command = None
        self.get_dataset_command = None
        self.get_dataset_iterator_command = None
        self.current_run = None
        self.get_dataset = None

        # self.command(get_dataset_iterator_command, unobserved=True)

    def __getattr__(self, k):
        """Return the attribute if it exists; otherwise return the dataset ``k``, creating it if needed."""
        try:
            return Munch.__getattr__(self, k)
        except AttributeError:
            return self.__getitem__(k)

    def __setattr__(self, k, v):
        try:
            # Throws exception if not in prototype chain
            object.__getattribute__(self, k)
        except AttributeError:
            try:
                self[k] = v
            except Exception:
                raise AttributeError(k)
        else:
            object.__setattr__(self, k, v)

    def __getitem__(self, k):
        """Return the dataset ``k``, creating and registering a new Dataset if it does not exist yet."""
        try:
            return Munch.__getitem__(self, k)
        except KeyError:
            self[k] = Dataset(
                self.path + "." + k,
                base_dir=self.base_dir,
                save_git_info=self.save_git_info,
            )
            assert self
            self.ingredients.append(self[k])
            return self[k]

    def __hash__(self):
        return Ingredient.__hash__(self)

    def get_datasets(self, config_conditions={}, return_datasets_names=False):
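        """Return all the datasets whose configuration matches ``config_conditions``.

        If ``return_datasets_names`` is True, ``(name, dataset)`` tuples are returned instead,
        where ``name`` is the last component of the dataset's path.
        """
        results = []
        for dataset in self.ingredients:
            all_ok = True
            for cond_k, cond_v in config_conditions.items():
                if (
                    self.current_run.get_config_path_value(dataset.path + "." + cond_k)
                    != cond_v
                ):
                    all_ok = False
                    break
            if all_ok:
                if return_datasets_names:
                    results.append((dataset.path.rsplit(".", 1)[-1], dataset))
                else:
                    results.append(dataset)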
        return results


================================================
FILE: ba3l/ingredients/ingredient.py
================================================
import inspect
import os
from functools import partial

from ba3l.util.functions import get_default_kwargs_dict
from sacred import Ingredient as sacred_Ingredient

from typing import Sequence, Optional, List

from sacred.utils import PathType, optional_kwargs_decorator
from munch import DefaultFactoryMunch, Munch


def raise_(ex):
    raise ex


class Ingredient(sacred_Ingredient):
    """
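    The base class of all Ba3l ingredients. It extends sacred's Ingredient with
    automatic default-argument configs (see :meth:`add_default_args_config`)
    and the :meth:`command` decorator.
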
    """

    def __init__(
        self,
        path: str,
        ingredients: Sequence[sacred_Ingredient] = (),
        interactive: bool = False,
        _caller_globals: Optional[dict] = None,
        base_dir: Optional[PathType] = None,
        save_git_info: bool = True,
    ):
        """
        The base ingredient of all Ba3l ingredients.

        Parameters
        ----------
        path
            Path (name) of this ingredient; defaults to ``"Ingredient"`` if ``None``.

        ingredients : list[sacred.Ingredient], optional
            A list of ingredients to be used by this ingredient.

        interactive
            If set to True, allows the experiment to be run in interactive
            mode (e.g. IPython or Jupyter notebooks).
            However, this mode is discouraged since it doesn't allow storing the
            source code or reliably reproducing the runs.

        _caller_globals
            Optional globals of the calling module; if omitted, they are taken
            from the caller's stack frame.

        base_dir
            Optional full path to the base directory of this ingredient. This
            sets the scope for automatic source file discovery.

        save_git_info:
            Optionally save the git commit hash and the git state
            (clean or dirty) for all source files. This requires the GitPython
            package.
        """

        _caller_globals = _caller_globals or inspect.stack()[1][0].f_globals
        if path is None:
            path = "Ingredient"

        super().__init__(
            path=path,
            ingredients=ingredients,
            interactive=interactive,
            base_dir=base_dir,
            _caller_globals=_caller_globals,
            save_git_info=save_git_info,
        )

        self.current_run = None
        self.last_default_configuration_position = 0

    def add_default_args_config(self, function, prefix, extra_args={}, static_args={}):
        """
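        Add the default parameters of ``function`` to the ingredient config at the lowest priority.
        This removes the need to declare every configuration entry manually.
        :param function: the function whose default keyword arguments are added to the config
        :param prefix: optional dotted sacred configuration prefix under which the defaults are nested
        :param extra_args: explicit config values that override the function defaults
        :param static_args: arguments that are passed directly and excluded from the config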
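
        A minimal, self-contained sketch of the config this method builds, assuming
        ``get_default_kwargs_dict`` returns the parameters that declare a default
        value (``default_kwargs`` below is a hypothetical stand-in for it). Unlike the
        method, the sketch returns the nested dict instead of inserting it into
        ``self.configurations``:

```python
import inspect


def default_kwargs(function):
    # Hypothetical stand-in for ba3l.util.functions.get_default_kwargs_dict:
    # collect the parameters of `function` that declare a default value.
    return {
        name: param.default
        for name, param in inspect.signature(function).parameters.items()
        if param.default is not inspect.Parameter.empty
    }


def build_default_config(function, prefix=None, extra_args=None, static_args=()):
    # extra_args override the function defaults
    config = {**default_kwargs(function), **(extra_args or {})}
    # static args are passed at call time and never stored in the config
    for key in static_args:
        config.pop(key, None)
    # nest under the dotted prefix, wrapping from the innermost component outwards
    if prefix is not None:
        for part in prefix.split('.')[::-1]:
            config = {part: config}
    return config


def get_loader(batch_size=32, shuffle=True, transform=None):
    # Hypothetical dataset-loader function used only for illustration.
    return None


print(build_default_config(get_loader, prefix="train.loader",
                           extra_args={"batch_size": 64},
                           static_args={"transform": None}))
# -> {'train': {'loader': {'batch_size': 64, 'shuffle': True}}}
```

        Because the prefix components are wrapped in reverse order, ``prefix="train.loader"``
        places the defaults under ``config["train"]["loader"]``.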
        """
        # @todo get the doc of the params as well
        config_candidate = {**get_default_kwargs_dict(function), **extra_args}
        # remove "static_args" from config
        for k in static_args:
            config_candidate.pop(k, None)
        if prefix is not None:
            for pr in prefix.split('.')[::-1]:
                config_candidate = {pr: config_candidate}
        self.configurations.insert(self.last_default_configuration_position, self._create_config_dict(config_candidate, None))
        self.last_default_configuration_position += 1

    @optional_kwargs_decorator
    def command(
        self, function=None, prefix=None, unobserved=False, add_default_args_config=True, static_args={}, **extra_args
    ):
        """
        Decorator to define a new Command.

        A command is a function whose parameters are filled automatically by sacred.

        The command can be given a prefix, to restrict its configuration space
        to a subtree. (see ``capture`` for more information)

        A command can be made unobserved (i.e. ignoring all observers) by
        passing the unobserved=True keyword argument.
        :param function: the function to register as a command
        :param prefix: sacred configuration prefix
        :param unobserved: sacred unobserved
        :param add_default_args_config: if True, add the function's default arguments to the config
        :param static_args: static arguments passed to the function; they need not be
         serializable and are not stored in the config
        :param extra_args: explicit arguments to add to the config; use these to override the
         function's default values. For example, if a value is wrapped with CMD, the parameter
         is filled by executing the command named by the CMD string. CMD strings are resolved
         in a special context.
        :return: the captured function


        """
        if add_default_args_config:
            self.add_default_args_config(function, prefix, extra_args, static_args=static_args)
        captured_f = self.capture(function, prefix=prefix, static_args=static_args)
        captured_f.unobserved = unobserved
        self.commands[function.__name__] = captured_f
        return captured_f


================================================
FILE: ba3l/ingredients/models.py
================================================
import inspect
import os
from functools import partial
from typing import Sequence, Optional, List

from ba3l.util.functions import get_default_kwargs_dict
from munch import DefaultFactoryMunch, Munch
from sacred.config import CMD
from sacred.utils import PathType, optional_kwargs_decorator

from .ingredient import Ingredient


class Models(Ingredient):
    """
    The class that annotates the models of a Ba3l experiment.


    """

    __instance = None

    @classmethod
    def get_instance(cls):
        if Models.__instance is None:
            Models.__instance = Models()
        return Models.__instance

    def __init__(
        self,
        name: Optional[str] = None,
        ingredients: Sequence[Ingredient] = (),
        interactive: bool = False,
        base_dir: Optional[PathType] = None,
        save_git_info: bool = True,
    ):
        """
        Create the container ingredient that holds the models of an experiment.

        Parameters
        ----------
        name
            Optional name (path) of this ingredient; defaults to ``"models"``.

        ingredients : list[sacred.Ingredient], optional
            A list of ingredients to be used by this ingredient.

        interactive
            If set to True, allows the experiment to be run in interactive
            mode (e.g. IPython or Jupyter notebooks).
            However, this mode is discouraged since it doesn't allow storing the
            source code or reliably reproducing the runs.

        base_dir
            Optional full path to the base directory of this ingredient. This
            sets the scope for automatic source file discovery.

        save_git_info:
            Optionally save the git commit hash and the git state
            (clean or dirty) for all source files. This requires the GitPython
            package.
        """

        caller_globals = inspect.stack()[1][0].f_globals
        if name is None:
            name = "models"

        super().__init__(
            path=name,
            ingredients=ingredients,
            interactive=interactive,
            base_dir=base_dir,
            _caller_globals=caller_globals,
            save_git_info=save_git_info,
        )

        self.get_models_command = None
        self.current_run = None
        # self.command(print_config, unobserved=True)


def raise_(ex):
    raise ex


class Model(Ingredient):
    """
    The ingredient that annotates a single model of a Ba3l experiment.
    A Model registers a ``get_instance`` command (see :meth:`instance`).

    """

    MODEL_STRING_PREFIX = "get_instance"

    def __init__(
        self,
        name: str,
        ingredients: Sequence[Ingredient] = (),
        interactive: bool = False,
        base_dir: Optional[PathType] = None,
        save_git_info: bool = True,
    ):
        """
        Create a new model ingredient with the given name and optional ingredients.

        Parameters
        ----------
        name
            Dotted path of this model ingredient; its last component is stored
            as ``self.name``. Defaults to ``"model"`` if ``None``.

        ingredients : list[sacred.Ingredient], optional
            A list of ingredients to be used by this model.

        interactive
            If set to True, allows the experiment to be run in interactive
            mode (e.g. IPython or Jupyter notebooks).
            However, this mode is discouraged since it doesn't allow storing the
            source code or reliably reproducing the runs.

        base_dir
            Optional full path to the base directory of this ingredient. This
            sets the scope for automatic source file discovery.

        save_git_info:
            Optionally save the git commit hash and the git state
            (clean or dirty) for all source files. This requires the GitPython
            package.
        """

        caller_globals = inspect.stack()[1][0].f_globals
        if name is None:
            name = "model"
        self.name = name.rsplit(".", 1)[-1]
        super().__init__(
            path=name,
            ingredients=ingredients,
            interactive=interactive,
            base_dir=base_dir,
            _caller_globals=caller_globals,
            save_git_info=save_git_info,
        )

        self.get_instance_command = None
        self.current_run = None
        self.get_instance = lambda: raise_(
            NotImplementedError(
                "Use the model's instance decorator to annotate the "
                "get_instance function!"
            )
        )

    @optional_kwargs_decorator
    def instance(
        self, function=None, prefix=None, unobserved=False, static_args={}, **extra_args
    ):
        """
        Decorator to define a new model.

        The decorated function is used to get an instance of the model; it is
        registered as this model's ``get_instance`` command.

        The command can be given a prefix, to restrict its configuration space
        to a subtree. (see ``capture`` for more information)

        A command can be made unobserved (i.e. ignoring all observers) by
        passing the unobserved=True keyword argument.
        :param function: the function that returns the model instance
        :param prefix: sacred configuration prefix
        :param unobserved: sacred unobserved
        :param static_args: static arguments passed to the function; they need not be
         serializable and are not stored in the config
        :param extra_args: explicit arguments to add to the config; use these to override the
         function's default values. For example, if a value is wrapped with CMD, the parameter
         is filled by executing the command named by the CMD string. CMD strings are resolved
         in a special context.
        :return: the captured function


        """
        self.add_default_args_config(function, prefix, extra_args, static_args=static_args)
        captured_f = self.capture(function, prefix=prefix, static_args=static_args)
        captured_f.unobserved = unobserved
        self.commands[Model.MODEL_STRING_PREFIX] = captured_f
        self.get_instance = captured_f
        self.add_config(get_instance=CMD("get_instance"))
        return captured_f

    def __getattr__(self, k):
        if k == "get_instance":
            return self.__getattribute__("get_instance")
        return super().__getattribute__(k)
        # @todo maybe run commands from here after running


================================================
FILE: ba3l/ingredients/trainer.py
================================================
import inspect
import os

from ba3l.util.functions import get_default_kwargs_dict
from .datasets import Datasets, raise_
from .models import Models
from sacred import Ingredient

from typing import Sequence, Optional, List

from sacred.utils import PathType, optional_kwargs_decorator


class Trainer(Ingredient):
    """
    The class that annotates the main Trainer of a Ba3l experiment.


    """

    TRAINER_STRING_PREFIX = "get_trainer"


================================================
FILE: ba3l/module.py
================================================
import pytorch_lightning as pl

import warnings
from abc import ABC

import torch.distributed as dist
from munch import DefaultMunch

try:
    # IterableDataset is available from PyTorch 1.2 on
    from torch.utils.data import IterableDataset
except ImportError:
    # older PyTorch versions (< 1.2) lack IterableDataset
    import torch

    warnings.warn(
        "Your version of PyTorch %s does not support `IterableDataset`,"
        " please upgrade to 1.2+" % torch.__version__,
        ImportWarning,
    )
    EXIST_ITER_DATASET = False
else:
    EXIST_ITER_DATASET = True

try:
    from apex import amp

    APEX_AVAILABLE = True
except ImportError:
    APEX_AVAILABLE = False


class Ba3lModule(pl.LightningModule):
    def __init__(self, experiment):
        super(Ba3lModule, self).__init__()
        self.experiment = experiment
        self.run = experiment.current_run
        self.config = DefaultMunch.fromDict(experiment.current_run.config)
        # instantiate every configured model and attach it as an attribute named after its config key
        for key, model in experiment.current_run.config['models'].items():
            setattr(self, key, experiment.current_run.get_command_function("models." + key + "." + model['instance_cmd'])())
        self.save_hyperparameters(self.config)



================================================
FILE: ba3l/plutils/__init__.py
================================================


================================================
FILE: ba3l/plutils/lr_monitor.py
================================================
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
r"""

Learning Rate Monitor
=====================

Monitors and logs the learning rate of lr schedulers during training.

"""

from typing import Dict, List, Optional

from pytorch_lightning.callbacks.base import Callback
from pytorch_lightning.utilities import rank_zero_warn
from pytorch_lightning.utilities.exceptions import MisconfigurationException


class LearningRateMonitor(Callback):
    r"""
    Automatically monitors and logs the learning rate of learning rate schedulers during training.

    Args:
        logging_interval: set to ``'epoch'`` or ``'step'`` to log ``lr`` of all optimizers
            at the same interval, set to ``None`` to log at individual interval
            according to the ``interval`` key of each scheduler. Defaults to ``None``.
        log_momentum: option to also log the momentum values of the optimizer, if the optimizer
            has the ``momentum`` or ``betas`` attribute. Defaults to ``False``.

    Raises:
        MisconfigurationException:
            If ``logging_interval`` is none of ``"step"``, ``"epoch"``, or ``None``.

    Example::

        >>> from pytorch_lightning import Trainer
        >>> from pytorch_lightning.callbacks import LearningRateMonitor
        >>> lr_monitor = LearningRateMonitor(logging_interval='step')
        >>> trainer = Trainer(callbacks=[lr_monitor])

    Logging names are determined automatically from the optimizer class name,
    prefixed with ``lr-``. Multiple optimizers of the same type are named
    ``lr-Adam``, ``lr-Adam-1``, etc. If an optimizer has multiple parameter groups,
    they are named ``lr-Adam/pg1``, ``lr-Adam/pg2``, etc. To control naming, pass a
    ``name`` key in the learning rate scheduler dict.

    Example::

        def configure_optimizers(self):
            optimizer = torch.optim.Adam(...)
            lr_scheduler = {
                'scheduler': torch.optim.lr_scheduler.LambdaLR(optimizer, ...),
                'name': 'my_logging_name'
            }
            return [optimizer], [lr_scheduler]
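
    A plain-Python sketch of the naming scheme implemented by ``_find_names``. The
    ``(optimizer_class_name, num_param_groups, explicit_name)`` tuples are a
    hypothetical stand-in for Lightning's scheduler dicts, and unlike the method the
    sketch only returns the logging keys (it does not record ``lr_sch_names``):

```python
def find_names(schedulers):
    # Each entry is (optimizer_class_name, num_param_groups, explicit_name_or_None);
    # this tuple input is a hypothetical stand-in for Lightning's scheduler dicts.
    names = []
    for opt_name, n_groups, explicit in schedulers:
        if explicit is not None:
            name = explicit
        else:
            base = 'lr-' + opt_name
            i, name = 1, base
            # de-duplicate repeated optimizer types: lr-Adam, lr-Adam-1, lr-Adam-2, ...
            while name in names:
                i, name = i + 1, f'{base}-{i}'
        # one logging key per parameter group when there is more than one
        if n_groups != 1:
            names.extend(f'{name}/pg{j + 1}' for j in range(n_groups))
        else:
            names.append(name)
    return names


print(find_names([('Adam', 1, None), ('Adam', 1, None), ('SGD', 2, None)]))
# -> ['lr-Adam', 'lr-Adam-1', 'lr-SGD/pg1', 'lr-SGD/pg2']
```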

    """

    def __init__(self, logging_interval: Optional[str] = None, log_momentum: bool = False):
        if logging_interval not in (None, 'step', 'epoch'):
            raise MisconfigurationException('logging_interval should be `step` or `epoch` or `None`.')

        self.logging_interval = logging_interval
        self.log_momentum = log_momentum
        self.lrs = None
        self.lr_sch_names = []

    def on_train_start(self, trainer, *args, **kwargs):
        """
        Called before training, determines unique names for all lr
        schedulers in the case of multiple of the same type or in
        the case of multiple parameter groups

        Raises:
            MisconfigurationException:
                If ``Trainer`` has no ``logger``.
        """
        if not trainer.logger:
            raise MisconfigurationException(
                'Cannot use `LearningRateMonitor` callback with `Trainer` that has no logger.'
            )

        if not trainer.lr_schedulers:
            rank_zero_warn(
                'You are using `LearningRateMonitor` callback with models that'
                ' have no learning rate schedulers. Please see documentation'
                ' for `configure_optimizers` method.', RuntimeWarning
            )

        if self.log_momentum:

            def _check_no_key(key):
                return any(key not in sch['scheduler'].optimizer.defaults for sch in trainer.lr_schedulers)

            if _check_no_key('momentum') and _check_no_key('betas'):
                rank_zero_warn(
                    "You have set log_momentum=True, but some optimizers do not"
                    " have momentum. This will log a value 0 for the momentum.", RuntimeWarning
                )

        # Find names for schedulers
        names = self._find_names(trainer.lr_schedulers)

        # Initialize for storing values
        self.lrs = {name: [] for name in names}
        self.last_momentum_values = {name + "-momentum": None for name in names}

    def on_train_batch_start(self, trainer, *args, **kwargs):
        if not self._should_log(trainer):
            return

        if self.logging_interval != 'epoch':
            interval = 'step' if self.logging_interval is None else 'any'
            latest_stat = self._extract_stats(trainer, interval)

            if latest_stat:
                trainer.logger.log_metrics(latest_stat, step=trainer.global_step)

    def on_train_epoch_start(self, trainer, *args, **kwargs):
        if self.logging_interval != 'step':
            interval = 'epoch' if self.logging_interval is None else 'any'
            latest_stat = self._extract_stats(trainer, interval)

            if latest_stat:
                trainer.logger.log_metrics(latest_stat, step=trainer.current_epoch)

    def _extract_stats(self, trainer, interval: str) -> Dict[str, float]:
        latest_stat = {}

        for name, scheduler in zip(self.lr_sch_names, trainer.lr_schedulers):
            if scheduler['interval'] == interval or interval == 'any':
                opt = scheduler['scheduler'].optimizer
                param_groups = opt.param_groups
                use_betas = 'betas' in opt.defaults

                for i, pg in enumerate(param_groups):
                    suffix = f'/pg{i + 1}' if len(param_groups) > 1 else ''
                    lr = self._extract_lr(param_group=pg, name=f'{name}{suffix}')
                    latest_stat.update(lr)
                    momentum = self._extract_momentum(
                        param_group=pg, name=f'{name}-momentum{suffix}', use_betas=use_betas
                    )
                    latest_stat.update(momentum)

        return latest_stat

    def _extract_lr(self, param_group, name: str) -> Dict[str, float]:
        lr = param_group.get('lr')
        self.lrs[name].append(lr)
        return {name: lr}

    def _extract_momentum(self, param_group, name: str, use_betas: bool) -> Dict[str, float]:
        if not self.log_momentum:
            return {}

        momentum = param_group.get('betas')[0] if use_betas else param_group.get('momentum', 0)
        self.last_momentum_values[name] = momentum
        return {name: momentum}

    def _find_names(self, lr_schedulers) -> List[str]:
        # Create unique names in case we have multiple learning rate
        # schedulers of the same type and/or multiple parameter groups
        names = []
        for scheduler in lr_schedulers:
            sch = scheduler['scheduler']
            if scheduler['name'] is not None:
                name = scheduler['name']
            else:
                opt_name = 'lr-' + sch.optimizer.__class__.__name__
                i, name = 1, opt_name

                # Multiple schedulers of the same type
                while True:
                    if name not in names:
                        break
                    i, name = i + 1, f'{opt_name}-{i}'

            # Multiple param groups for the same scheduler
            param_groups = sch.optimizer.param_groups

            if len(param_groups) != 1:
                for i, pg in enumerate(param_groups):
                    temp = f'{name}/pg{i + 1}'
                    names.append(temp)
            else:
                names.append(name)

            self.lr_sch_names.append(name)

        return names

    @staticmethod
    def _should_log(trainer) -> bool:
        should_log = ((trainer.global_step + 1) % trainer.log_every_n_steps == 0 or trainer.should_stop)

        return should_log
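A minimal pure-Python sketch of the bookkeeping in the callback above: how `_find_names` derives unique metric keys, how `_extract_stats` selects schedulers by interval, and when `_should_log` fires. Schedulers are modeled as plain dicts with hypothetical keys (`opt_cls`, `name`, `interval`, `lrs`) instead of Lightning's scheduler configs and optimizer param groups, so this illustrates the rules rather than replacing the callback.

```python
from typing import Dict, List, Tuple


def find_names(schedulers: List[dict]) -> Tuple[List[str], List[str]]:
    """Return (metric keys, one base name per scheduler), following `_find_names`."""
    keys: List[str] = []   # expanded keys, e.g. "lr-Adam/pg1"
    base: List[str] = []   # plays the role of `self.lr_sch_names`
    for sch in schedulers:
        name = sch["name"]
        if name is None:
            opt_name = f"lr-{sch['opt_cls']}"
            i, name = 1, opt_name
            # Same optimizer type seen before: try lr-SGD-1, lr-SGD-2, ...
            while name in keys:
                i, name = i + 1, f"{opt_name}-{i}"
        n_groups = len(sch["lrs"])
        if n_groups != 1:
            keys.extend(f"{name}/pg{k + 1}" for k in range(n_groups))
        else:
            keys.append(name)
        base.append(name)
    return keys, base


def extract_lrs(schedulers: List[dict], base: List[str], interval: str) -> Dict[str, float]:
    """Learning rates of schedulers whose interval matches (all of them for 'any')."""
    stats: Dict[str, float] = {}
    for name, sch in zip(base, schedulers):
        if sch["interval"] == interval or interval == "any":
            lrs = sch["lrs"]
            for k, lr in enumerate(lrs):
                suffix = f"/pg{k + 1}" if len(lrs) > 1 else ""
                stats[f"{name}{suffix}"] = lr
    return stats


def should_log(global_step: int, log_every_n_steps: int, should_stop: bool = False) -> bool:
    # global_step is 0-based, so with log_every_n_steps=50 this fires at steps 49, 99, ...
    return (global_step + 1) % log_every_n_steps == 0 or should_stop


scheds = [
    {"opt_cls": "SGD", "name": None, "interval": "step", "lrs": [0.1]},
    {"opt_cls": "SGD", "name": None, "interval": "epoch", "lrs": [0.01]},
    {"opt_cls": "Adam", "name": None, "interval": "epoch", "lrs": [1e-3, 1e-4]},
]
keys, base = find_names(scheds)
print(keys)  # ['lr-SGD', 'lr-SGD-1', 'lr-Adam/pg1', 'lr-Adam/pg2']
print(extract_lrs(scheds, base, "step"))  # {'lr-SGD': 0.1}
```

With `logging_interval=None`, `on_train_epoch_start` requests only `'epoch'` schedulers; with `logging_interval='epoch'` it requests `'any'`, so every scheduler is logged once per epoch.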


================================================
FILE: ba3l/plutils/progress_bar.py
================================================

import sys
from typing import Optional, Union

from pytorch_lightning.callbacks import ProgressBarBase, ProgressBar as PlProgressBar
from pytorch_lightning.callbacks.progress import tqdm



class ProgressBar(PlProgressBar):
    r"""
    A variant of Lightning's default progress bar. It prints to `stdout` using the
    :mod:`tqdm` package and shows up to four different bars:

    - **sanity check progress:** the progress during the sanity check run
    - **main progress:** shows training progress only. Unlike Lightning's default bar,
      validation batches are not added to its total, and the bar is closed when a
      validation run starts.
    - **validation progress:** only visible during validation;
      shows total progress over all validation datasets.
    - **test progress:** only active when testing; shows total progress over all test datasets.

    For infinite datasets, the progress bar never ends.

    If you want to customize the default ``tqdm`` progress bars used by Lightning, you can override
    specific methods of the callback class and pass your custom implementation to the
    :class:`~pytorch_lightning.trainer.trainer.Trainer`:

    Example::

        class LitProgressBar(ProgressBar):

            def init_validation_tqdm(self):
                bar = super().init_validation_tqdm()
                bar.set_description('running validation ...')
                return bar

        bar = LitProgressBar()
        trainer = Trainer(callbacks=[bar])

    Args:
        refresh_rate:
            Determines at which rate (in number of batches) the progress bars get updated.
            Set it to ``0`` to disable the display. By default, the
            :class:`~pytorch_lightning.trainer.trainer.Trainer` uses this implementation of the progress
            bar and sets the refresh rate to the value provided to the
            :paramref:`~pytorch_lightning.trainer.trainer.Trainer.progress_bar_refresh_rate` argument in the
            :class:`~pytorch_lightning.trainer.trainer.Trainer`.
        process_position:
            Set this to a value greater than ``0`` to offset the progress bars by this many lines.
            This is useful when you have progress bars defined elsewhere and want to show all of them
            together. This corresponds to
            :paramref:`~pytorch_lightning.trainer.trainer.Trainer.process_position` in the
            :class:`~pytorch_lightning.trainer.trainer.Trainer`.

    """


    def init_validation_tqdm(self) -> tqdm:
        """ Override this to customize the tqdm bar for validation. """
        # The main bar is closed when validation starts (see `on_validation_start`),
        # so the validation bar takes its slot instead of stacking below it.
        has_main_bar = False
        bar = tqdm(
            desc='Validating',
            position=(2 * self.process_position + has_main_bar),
            disable=self.is_disabled,
            leave=False,
            dynamic_ncols=True,
            file=sys.stdout
        )
        return bar



    def on_epoch_start(self, trainer, pl_module):
        ProgressBarBase.on_epoch_start(self, trainer, pl_module)
        self.main_progress_bar = self.init_train_tqdm()
        # Unlike Lightning's default, the main bar counts training batches only;
        # validation progress is shown on its own bar.
        total_batches = self.total_train_batches
        reset(self.main_progress_bar, total_batches)
        self.main_progress_bar.set_description(f'Epoch {trainer.current_epoch}')

    def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx):
        ProgressBarBase.on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx)
        if self._should_update(self.train_batch_idx, self.total_train_batches):
            self._update_bar(self.main_progress_bar)
            self.main_progress_bar.set_postfix(trainer.progress_bar_dict)

    def on_validation_start(self, trainer, pl_module):
        ProgressBarBase.on_validation_start(self, trainer, pl_module)
        if trainer.sanity_checking:
            reset(self.val_progress_bar, sum(trainer.num_sanity_val_batches))
        else:
            if self.main_progress_bar is not None:
                self.main_progress_bar.close()
            self.val_progress_bar = self.init_validation_tqdm()
            reset(self.val_progress_bar, self.total_val_batches)

    def on_validation_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx):
        ProgressBarBase.on_validation_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx)
        if self._should_update(self.val_batch_idx, self.total_val_batches):
            self._update_bar(self.val_progress_bar)
            #self._update_bar(self.main_progress_bar)

    def on_validation_end(self, trainer, pl_module):
        ProgressBarBase.on_validation_end(self, trainer, pl_module)
        if self.main_progress_bar is not None:
            self.main_progress_bar.set_postfix(trainer.progress_bar_dict)
        self.val_progress_bar.close()



def convert_inf(x: Optional[Union[int, float]]) -> Optional[Union[int, float]]:
    """ The tqdm doesn't support inf values. We have to convert it to None. """
    if x == float('inf'):
        return None
    return x


def reset(bar: tqdm, total: Optional[int] = None) -> None:
    """ Resets the tqdm bar to 0 progress with a new total, unless it is disabled. """
    if not bar.disable:
        bar.reset(total=convert_inf(total))
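A small pure-Python sketch contrasting two ways of sizing the main bar: the Lightning-style total, which adds the validation batches of every validation run in the epoch, and this subclass, which counts training batches only. `val_check_batch` mirrors `trainer.val_check_batch`; `main_bar_total` is a hypothetical helper that only reproduces the arithmetic, including the `inf` to `None` conversion tqdm needs.

```python
from typing import Optional, Union

Number = Union[int, float]


def convert_inf(x: Optional[Number]) -> Optional[Number]:
    # tqdm cannot display an infinite total, so inf becomes None (unknown length).
    return None if x == float("inf") else x


def main_bar_total(total_train: Number, total_val: int, val_check_batch: int,
                   include_val: bool) -> Optional[Number]:
    """Total for the per-epoch main bar.

    include_val=True: Lightning-style total, training batches plus the validation
    batches of every validation run in the epoch.
    include_val=False: this subclass, training batches only.
    """
    val = 0
    if include_val and total_train != float("inf"):
        # validation can run several times per epoch
        val_checks_per_epoch = total_train // val_check_batch
        val = total_val * val_checks_per_epoch
    return convert_inf(total_train + val)


print(main_bar_total(1000, 50, 250, include_val=True))   # 1200
print(main_bar_total(1000, 50, 250, include_val=False))  # 1000
```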


================================================
FILE: ba3l/util/__init__.py
================================================


================================================
FILE: ba3l/util/functions.py
================================================
import inspect
from collections import OrderedDict


def get_default_kwargs_dict(f):
    sig = inspect.signature(f)
    return OrderedDict(
        [
            (p.name, p.default)
            for p in sig.parameters.values()
            if p.default is not inspect.Parameter.empty
        ]
    )
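A standalone copy of `get_default_kwargs_dict` with a usage example on a hypothetical function `make_loader`. It compares with `is not inspect.Parameter.empty` (the public name of `inspect._empty`), which avoids invoking `!=` on arbitrary default values such as NumPy arrays.

```python
import inspect
from collections import OrderedDict


def get_default_kwargs_dict(f):
    """Map each parameter that has a default value to that default, in signature order."""
    sig = inspect.signature(f)
    return OrderedDict(
        (p.name, p.default)
        for p in sig.parameters.values()
        if p.default is not inspect.Parameter.empty
    )


def make_loader(dataset, batch_size=12, shuffle=True, *, num_workers=4):
    """Hypothetical function, used only to illustrate the helper."""


print(dict(get_default_kwargs_dict(make_loader)))
# {'batch_size': 12, 'shuffle': True, 'num_workers': 4}
```

Parameters without defaults (`dataset`) are skipped, while keyword-only parameters with defaults (`num_workers`) are included.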


================================================
FILE: ba3l/util/sacred_logger.py
================================================
from pytorch_lightning.utilities import rank_zero_only
try:
    from pytorch_lightning.loggers import Logger as LightningLoggerBase
    from pytorch_lightning.loggers.logger import rank_zero_experiment
except ImportError:
    from pytorch_lightning.loggers import LightningLoggerBase
    from pytorch_lightning.loggers.base import rank_zero_experiment

from warnings import warn


from logging import getLogger

try:
    import sacred
except ImportError:
    raise ImportError("Missing sacred package.  Run `pip install sacred`")


logger = getLogger(__name__)


class SacredLogger(LightningLoggerBase):
    def __init__(self, sacred_experiment):
        """Initialize a sacred logger.
        :param sacred.experiment.Experiment sacred_experiment: Required. Experiment object with desired observers
        already appended.
        source: https://github.com/expectopatronum/pytorch-lightning/blob/9fcb238ec03e3f0b0378fd058119f1563a11650c/pytorch_lightning/logging/sacred.py
        """
        super().__init__()
        self.sacred_experiment = sacred_experiment
        self.experiment_name = sacred_experiment.path
        self._run_id = None
        warn('SacredLogger is deprecated', DeprecationWarning, stacklevel=2)

    @property
    def experiment(self):
        return self.sacred_experiment

    @property
    def run_id(self):
        if self._run_id is not None:
            return self._run_id
        self._run_id = self.sacred_experiment.get_run_identifier()
        return self._run_id

    @rank_zero_only
    def log_hyperparams(self, params):
        # Likely unnecessary, since sacred already records the run configuration.
        pass

    @rank_zero_only
    def log_metrics(self, metrics, step=None):
        for k, v in metrics.items():
            if isinstance(v, str):
                logger.warning(f"Discarding metric with string value {k}={v}")
                continue
            self.experiment.log_scalar(k, v, step)

    @property
    def name(self):
        return self.experiment_name

    @property
    def version(self):
        return self.run_id

    @rank_zero_only
    def save(self):
        # Optional. Any code necessary to save logger data goes here
        # If you implement this, remember to call `super().save()`
        # at the start of the method (important for aggregation of metrics)
        super().save()

    @rank_zero_only
    def finalize(self, status):
        # Optional. Any code that needs to be run after training
        # finishes goes here
        pass
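A self-contained sketch of the filtering rule in `SacredLogger.log_metrics`, using a hypothetical `FakeExperiment` stand-in that only records `log_scalar(name, value, step)` calls, so it runs without sacred or Lightning installed.

```python
import logging
from typing import Dict, List, Optional, Tuple, Union

logger = logging.getLogger(__name__)


class FakeExperiment:
    """Hypothetical stand-in for a sacred experiment: records log_scalar calls."""

    def __init__(self) -> None:
        self.scalars: List[Tuple[str, float, Optional[int]]] = []

    def log_scalar(self, name: str, value: float, step: Optional[int] = None) -> None:
        self.scalars.append((name, value, step))


def log_metrics(experiment, metrics: Dict[str, Union[float, str]], step: Optional[int] = None) -> None:
    # Same rule as SacredLogger.log_metrics: string values are skipped with a
    # warning, everything else is forwarded as a scalar.
    for k, v in metrics.items():
        if isinstance(v, str):
            logger.warning("Discarding metric with string value %s=%s", k, v)
            continue
        experiment.log_scalar(k, v, step)


exp = FakeExperiment()
log_metrics(exp, {"val.loss": 0.42, "phase": "val", "lr-Adam": 1e-4}, step=10)
print(exp.scalars)  # [('val.loss', 0.42, 10), ('lr-Adam', 0.0001, 10)]
```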


================================================
FILE: config_updates.py
================================================
from sacred.config_helpers import DynamicIngredient, CMD


def add_configs(ex):
    '''
    Adds generic named configurations to the experiment, such as mix-up, architectures, etc.
    @param ex: Ba3l Experiment
    @return:
    '''

    @ex.named_config
    def nomixup():
        'Don\'t apply mix-up (spectrogram level).'
        use_mixup = False
        mixup_alpha = 0.3

    @ex.named_config
    def mixup():
        'Apply mix-up (spectrogram level).'
        use_mixup = True
        mixup_alpha = 0.3

    @ex.named_config
    def mini_train():
        'Limit training/validation to 5 batches for debugging.'
        trainer = dict(limit_train_batches=5, limit_val_batches=5)

    @ex.named_config
    def passt():
        'use PaSST model'
        models = {
            "net": DynamicIngredient("models.passt.model_ing")
        }

    @ex.named_config
    def passt_s_20sec():
        'use PaSST model pretrained on Audioset (with SWA) ap=474; time encodings for up to 20 seconds'
        # test with: python ex_audioset.py evaluate_only with passt_s_20sec
        models = {
            "net": DynamicIngredient("models.passt.model_ing", arch="passt_s_f128_20sec_p16_s10_ap474", fstride=10,
                                     tstride=10, input_tdim=2000)
        }
        basedataset = dict(clip_length=20)

    @ex.named_config
    def passt_s_30sec():
        'use PaSST model pretrained on Audioset (with SWA) ap=473; time encodings for up to 30 seconds'
        # test with: python ex_audioset.py evaluate_only with passt_s_30sec
        models = {
            "net": DynamicIngredient("models.passt.model_ing", arch="passt_s_f128_30sec_p16_s10_ap473", fstride=10,
                                     tstride=10, input_tdim=3000)
        }
        basedataset = dict(clip_length=20)

    @ex.named_config
    def passt_s_ap476():
        'use PaSST model pretrained on Audioset (with SWA) ap=476'
        # python ex_audioset.py evaluate_only with passt_s_ap476
        models = {
            "net": DynamicIngredient("models.passt.model_ing", arch="passt_s_swa_p16_128_ap476", fstride=10,
                                     tstride=10)
        }

    @ex.named_config
    def passt_s_ap4763():
        'use PaSST model pretrained on Audioset (with SWA) ap=4763'
        # test with: python ex_audioset.py evaluate_only with passt_s_ap4763
        models = {
            "net": DynamicIngredient("models.passt.model_ing", arch="passt_s_swa_p16_128_ap4763", fstride=10,
                                     tstride=10)
        }

    @ex.named_config
    def passt_s_ap472():
        'use PaSST model pretrained on Audioset (no SWA) ap=472'
        # test with: python ex_audioset.py evaluate_only with passt_s_ap472
        models = {
            "net": DynamicIngredient("models.passt.model_ing", arch="passt_s_p16_128_ap472", fstride=10,
                                     tstride=10)
        }

    @ex.named_config
    def passt_s_p16_s16_128_ap468():
        'use PaSST model pretrained on Audioset (no SWA) ap=468 NO overlap'
        # test with: python ex_audioset.py evaluate_only with passt_s_p16_s16_128_ap468
        models = {
            "net": DynamicIngredient("models.passt.model_ing", arch="passt_s_p16_s16_128_ap468", fstride=16,
                                     tstride=16)
        }

    @ex.named_config
    def passt_s_swa_p16_s16_128_ap473():
        'use PaSST model pretrained on Audioset (SWA) ap=473 NO overlap'
        # test with: python ex_audioset.py evaluate_only with passt_s_swa_p16_s16_128_ap473
        models = {
            "net": DynamicIngredient("models.passt.model_ing", arch="passt_s_swa_p16_s16_128_ap473", fstride=16,
                                     tstride=16)
        }

    @ex.named_config
    def passt_s_swa_p16_s14_128_ap471():
        'use PaSST model pretrained on Audioset stride=14 (SWA) ap=471 '
        # test with: python ex_audioset.py evaluate_only with passt_s_swa_p16_s14_128_ap471
        models = {
            "net": DynamicIngredient("models.passt.model_ing", arch="passt_s_swa_p16_s14_128_ap471", fstride=14,
                                     tstride=14)
        }

    @ex.named_config
    def passt_s_p16_s14_128_ap469():
        'use PaSST model pretrained on Audioset stride=14 (No SWA) ap=469 '
        # test with: python ex_audioset.py evaluate_only with passt_s_p16_s14_128_ap469
        models = {
            "net": DynamicIngredient("models.passt.model_ing", arch="passt_s_p16_s14_128_ap469", fstride=14,
                                     tstride=14)
        }

    @ex.named_config
    def passt_s_swa_p16_s12_128_ap473():
        'use PaSST model pretrained on Audioset stride=12 (SWA) ap=473 '
        # test with: python ex_audioset.py evaluate_only with passt_s_swa_p16_s12_128_ap473
        models = {
            "net": DynamicIngredient("models.passt.model_ing", arch="passt_s_swa_p16_s12_128_ap473", fstride=12,
                                     tstride=12)
        }

    @ex.named_config
    def passt_s_p16_s12_128_ap470():
        'use PaSST model pretrained on Audioset stride=12 (No SWA) ap=470'
        # test with: python ex_audioset.py evaluate_only with passt_s_p16_s12_128_ap470
        models = {
            "net": DynamicIngredient("models.passt.model_ing", arch="passt_s_p16_s12_128_ap470", fstride=12,
                                     tstride=12)
        }

    @ex.named_config
    def ensemble_s10():
        'use ensemble of PaSST models pretrained on Audioset with S10 mAP=.4864'
        # test with: python ex_audioset.py evaluate_only with trainer.precision=16 ensemble_s10
        models = {
            "net": DynamicIngredient("models.passt.model_ing", arch="ensemble_s10", fstride=None,
                                     tstride=None, instance_cmd="get_ensemble_model",
                                     # don't call get_model but rather get_ensemble_model
                                     arch_list=[
                                         ("passt_s_swa_p16_128_ap476", 10, 10),
                                         ("passt_s_swa_p16_128_ap4761", 10, 10),
                                         ("passt_s_p16_128_ap472", 10, 10),
                                     ]
                                     )
        }

    @ex.named_config
    def ensemble_many():
        'use ensemble of PaSST models pretrained on Audioset with different strides mAP=.4956'
        # test with: python ex_audioset.py evaluate_only with trainer.precision=16 ensemble_many
        models = {
            "net": DynamicIngredient("models.passt.model_ing", arch="ensemble_many", fstride=None,
                                     tstride=None, instance_cmd="get_ensemble_model",
                                     # don't call get_model but rather get_ensemble_model
                                     arch_list=[
                                         ("passt_s_swa_p16_128_ap476", 10, 10),
                                         ("passt_s_swa_p16_128_ap4761", 10, 10),
                                         ("passt_s_p16_128_ap472", 10, 10),
                                         ("passt_s_p16_s12_128_ap470", 12, 12),
                                         ("passt_s_swa_p16_s12_128_ap473", 12, 12),
                                         ("passt_s_p16_s14_128_ap469", 14, 14),
                                         ("passt_s_swa_p16_s14_128_ap471", 14, 14),
                                         ("passt_s_swa_p16_s16_128_ap473", 16, 16),
                                         ("passt_s_p16_s16_128_ap468", 16, 16),
                                     ]
                                     )
        }

    @ex.named_config
    def ensemble_4():
        'use ensemble of PaSST models pretrained on Audioset with different strides mAP=.4926'
        # test with: python ex_audioset.py evaluate_only with trainer.precision=16 ensemble_4
        models = {
            "net": DynamicIngredient("models.passt.model_ing", arch="ensemble_many", fstride=None,
                                     tstride=None, instance_cmd="get_ensemble_model",
                                     # don't call get_model but rather get_ensemble_model
                                     arch_list=[
                                         ("passt_s_swa_p16_128_ap476", 10, 10),
                                         ("passt_s_swa_p16_s12_128_ap473", 12, 12),
                                         ("passt_s_swa_p16_s14_128_ap471", 14, 14),
                                         ("passt_s_swa_p16_s16_128_ap473", 16, 16),
                                     ]
                                     )
        }

    @ex.named_config
    def ensemble_5():
        'use ensemble of PaSST models pretrained on Audioset with different strides mAP=.49459'
        # test with: python ex_audioset.py evaluate_only with trainer.precision=16 ensemble_5
        models = {
            "net": DynamicIngredient("models.passt.model_ing", arch="ensemble_many", fstride=None,
                                     tstride=None, instance_cmd="get_ensemble_model",
                                     # don't call get_model but rather get_ensemble_model
                                     arch_list=[
                                         ("passt_s_swa_p16_128_ap476", 10, 10),
                                         ("passt_s_swa_p16_128_ap4761", 10, 10),
                                         ("passt_s_swa_p16_s12_128_ap473", 12, 12),
                                         ("passt_s_swa_p16_s14_128_ap471", 14, 14),
                                         ("passt_s_swa_p16_s16_128_ap473", 16, 16),
                                     ]
                                     )
        }

    @ex.named_config
    def ensemble_s16_14():
        'use ensemble of two PaSST models pretrained on Audioset with stride 16 and 14 mAP=.48579'
        # test with: python ex_audioset.py evaluate_only with trainer.precision=16 ensemble_s16_14
        models = {
            "net": DynamicIngredient("models.passt.model_ing", arch="ensemble_s16", fstride=None,
                                     tstride=None, instance_cmd="get_ensemble_model",
                                     # don't call get_model but rather get_ensemble_model
                                     arch_list=[
                                         ("passt_s_swa_p16_s14_128_ap471", 14, 14),
                                         ("passt_s_swa_p16_s16_128_ap473", 16, 16),
                                     ]
                                     )
        }

    @ex.named_config
    def dynamic_roll():
        'Dynamically roll the spectrograms/waveforms (updates the dataset config).'
        basedataset = dict(roll=True, roll_conf=dict(axis=1, shift_range=10000))

    # extra commands

    @ex.command
    def test_loaders_train_speed():
        # test how fast data is being loaded from the data loaders.
        itr = ex.datasets.training.get_iter()
        import time
        start = time.time()
        print("first pass:")
        for i, b in enumerate(itr):
            if i % 20 == 0:
                print(f"{i}/{len(itr)}", end="\r")
        end = time.time()
        print("total time:", end - start)
        start = time.time()
        print("second pass:")
        for i, b in enumerate(itr):
            if i % 20 == 0:
                print(f"{i}/{len(itr)}", end="\r")
        end = time.time()
        print("total time:", end - start)
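The arch names in these named configs appear to encode the training setup (optional SWA, patch size `pNN`, stride `sNN`, and the reported mAP `apNNN`), while every config repeats the stride in `fstride`/`tstride` by hand. A hypothetical consistency check, assuming that naming pattern (inferred from the names above, not from a documented convention), could catch a copy-paste mismatch between the two:

```python
import re
from typing import Optional

# Assumed naming pattern, inferred from the arch names in this file.
ARCH_RE = re.compile(
    r"^passt_s_(?P<swa>swa_)?"
    r"(?:f(?P<fbins>\d+)_(?P<secs>\d+)sec_)?"
    r"p(?P<patch>\d+)"
    r"(?:_s(?P<stride>\d+))?"
    r"(?:_(?P<mel>\d+))?"
    r"_ap(?P<ap>\d+)$"
)


def parse_arch(arch: str) -> dict:
    m = ARCH_RE.match(arch)
    if m is None:
        raise ValueError(f"unrecognized arch name: {arch}")
    return {
        "swa": m.group("swa") is not None,
        "patch": int(m.group("patch")),
        # Assumption: names without an explicit _sNN use stride 10, which is
        # how the configs above pair them with fstride/tstride.
        "stride": int(m.group("stride") or 10),
        "seconds": int(m.group("secs")) if m.group("secs") else None,
        # "ap476" -> 0.476, "ap4763" -> 0.4763
        "ap": float("0." + m.group("ap")),
    }


def strides_match(arch: str, fstride: Optional[int], tstride: Optional[int]) -> bool:
    """True if the stride encoded in `arch` equals both fstride and tstride."""
    stride = parse_arch(arch)["stride"]
    return fstride == stride and tstride == stride


# arch_list of the ensemble_4 config above
ENSEMBLE_4 = [
    ("passt_s_swa_p16_128_ap476", 10, 10),
    ("passt_s_swa_p16_s12_128_ap473", 12, 12),
    ("passt_s_swa_p16_s14_128_ap471", 14, 14),
    ("passt_s_swa_p16_s16_128_ap473", 16, 16),
]
print(all(strides_match(a, f, t) for a, f, t in ENSEMBLE_4))  # True
```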


================================================
FILE: environment.yml
================================================
name: ba3l
channels:
  - defaults
dependencies:
  - _libgcc_mutex=0.1
  - _openmp_mutex=5.1
  - ca-certificates=2022.10.11
  - certifi=2022.12.7
  - ld_impl_linux-64=2.38
  - libffi=3.4.2
  - libgcc-ng=11.2.0
  - libgomp=11.2.0
  - libstdcxx-ng=11.2.0
  - ncurses=6.3
  - openssl=1.1.1s
  - pip=22.3.1
  - python=3.8.16
  - readline=8.2
  - setuptools=65.6.3
  - sqlite=3.40.1
  - tk=8.6.12
  - wheel=0.37.1
  - xz=5.2.10
  - zlib=1.2.13
  - pip:
    - absl-py==1.4.0
    - aiohttp==3.8.4
    - aiosignal==1.3.1
    - appdirs==1.4.4
    - async-timeout==4.0.2
    - attrs==22.2.0
    - audioread==3.0.0
    - av==10.0.0
    - cachetools==5.3.0
    - cffi==1.15.1
    - charset-normalizer==3.0.1
    - colorama==0.4.6
    - decorator==5.1.1
    - docopt==0.6.2
    - frozenlist==1.3.3
    - fsspec==2023.3.0
    - future==0.18.3
    - gitdb==4.0.10
    - gitpython==3.1.31
    - google-auth==2.17.0
    - google-auth-oauthlib==0.4.6
    - grpcio==1.53.0
    - h5py==3.8.0
    - idna==3.4
    - imageio==2.27.0
    - importlib-metadata==6.1.0
    - joblib==1.2.0
    - jsonpickle==3.0.1
    - kk-sacred==0.8.4
    - lazy-loader==0.2
    - librosa==0.10.0.post2
    - llvmlite==0.39.1
    - markdown==3.4.3
    - markupsafe==2.1.2
    - msgpack==1.0.5
    - multidict==6.0.4
    - munch==2.5.0
    - numba==0.56.4
    - numpy==1.23.5
    - nvidia-cublas-cu11==11.10.3.66
    - nvidia-cuda-nvrtc-cu11==11.7.99
    - nvidia-cuda-runtime-cu11==11.7.99
    - nvidia-cudnn-cu11==8.5.0.96
    - oauthlib==3.2.2
    - packaging==23.0
    - pandas==1.5.3
    - pillow==9.4.0
    - pooch==1.6.0
    - protobuf==4.22.1
    - py-cpuinfo==9.0.0
    - pyasn1==0.4.8
    - pyasn1-modules==0.2.8
    - pycparser==2.21
    - python-dateutil==2.8.2
    - pytorch-lightning==1.3.0.dev0
    - pytz==2023.3
    - pyyaml==6.0
    - requests==2.28.2
    - requests-oauthlib==1.3.1
    - rsa==4.9
    - scikit-learn==1.2.2
    - scipy==1.10.1
    - six==1.16.0
    - smmap==5.0.0
    - soundfile==0.12.1
    - soxr==0.3.4
    - tensorboard==2.12.0
    - tensorboard-data-server==0.7.0
    - tensorboard-plugin-wit==1.8.1
    - test-tube==0.7.5
    - threadpoolctl==3.1.0
    - timm==0.4.12
    - torch==1.13.1
    - torchaudio==0.13.1
    - torchmetrics==0.2.0
    - torchvision==0.14.1
    - tqdm==4.65.0
    - typing-extensions==4.4.0
    - urllib3==1.26.14
    - werkzeug==2.2.3
    - wrapt==1.15.0
    - yarl==1.8.2
    - zipp==3.15.0


================================================
FILE: environment_old_2021.yml
================================================
name: ba3l
channels:
  - conda-forge
  - defaults
dependencies:
  - _libgcc_mutex=0.1
  - _openmp_mutex=4.5
  - _pytorch_select=0.1
  - appdirs=1.4.4
  - audioread=2.1.9
  - blas=1.0
  - brotlipy=0.7.0
  - bzip2=1.0.8
  - c-ares=1.17.1
  - ca-certificates=2020.12.5
  - cached-property=1.5.2
  - cached_property=1.5.2
  - certifi=2020.12.5
  - cffi=1.14.5
  - chardet=4.0.0
  - colorama=0.4.4
  - cryptography=3.4.6
  - cycler=0.10.0
  - decorator=4.4.2
  - docopt=0.6.2
  - ffmpeg=4.3.1
  - freetype=2.10.4
  - gettext=0.19.8.1
  - gitdb=4.0.5
  - gitpython=3.1.14
  - gmp=6.2.1
  - gnutls=3.6.13
  - h5py=3.1.0
  - hdf5=1.10.6
  - idna=2.10
  - importlib-metadata=3.7.3
  - importlib_metadata=3.7.3
  - intel-openmp=2020.2
  - joblib=1.0.1
  - jpeg=9d
  - jsonpickle=1.4.1
  - kiwisolver=1.3.1
  - krb5=1.17.2
  - lame=3.100
  - lcms2=2.12
  - ld_impl_linux-64=2.35.1
  - libblas=3.9.0
  - libcblas=3.9.0
  - libcurl=7.75.0
  - libedit=3.1.20191231
  - libev=4.33
  - libffi=3.3
  - libflac=1.3.3
  - libgcc-ng=9.3.0
  - libgfortran-ng=9.3.0
  - libgfortran5=9.3.0
  - libgomp=9.3.0
  - liblapack=3.9.0
  - libllvm10=10.0.1
  - libnghttp2=1.43.0
  - libogg=1.3.4
  - libopenblas=0.3.12
  - libopus=1.3.1
  - libpng=1.6.37
  - librosa=0.8.0
  - libsndfile=1.0.31
  - libssh2=1.9.0
  - libstdcxx-ng=9.3.0
  - libtiff=4.2.0
  - libvorbis=1.3.7
  - libwebp-base=1.2.0
  - llvm-openmp=11.1.0
  - llvmlite=0.36.0
  - lz4-c=1.9.3
  - matplotlib-base=3.3.4
  - mkl=2020.2
  - mkl-service=2.3.0
  - munch=2.5.0
  - ncurses=6.2
  - nettle=3.6
  - ninja=1.10.2
  - numba=0.53.0
  - numpy=1.20.1
  - olefile=0.46
  - openblas=0.3.12
  - openh264=2.1.1
  - openssl=1.1.1k
  - packaging=20.9
  - pandas=1.2.3
  - pillow=8.1.2
  - pip=21.0.1
  - pooch=1.3.0
  - py-cpuinfo=7.0.0
  - pycparser=2.20
  - pyopenssl=20.0.1
  - pyparsing=2.4.7
  - pysocks=1.7.1
  - pysoundfile=0.10.3.post1
  - python=3.7.10
  - python-dateutil=2.8.1
  - python_abi=3.7
  - pytz=2021.1
  - readline=8.0
  - requests=2.25.1
  - resampy=0.2.2
  - scikit-learn=0.24.1
  - scipy=1.6.1
  - setuptools=49.6.0
  - six=1.15.0
  - smmap=3.0.5
  - sqlite=3.34.0
  - threadpoolctl=2.1.0
  - tk=8.6.10
  - tornado=6.1
  - typing_extensions=3.7.4.3
  - urllib3=1.26.4
  - wrapt=1.12.1
  - x264=1!161.3030
  - xz=5.2.5
  - zipp=3.4.1
  - zlib=1.2.11
  - zstd=1.4.9
  - pip:
    - absl-py==0.12.0
    - aiohttp==3.7.4.post0
    - async-timeout==3.0.1
    - attrs==20.3.0
    - av==8.0.3
    - black==20.8b1
    - cachetools==4.2.1
    - click==7.1.2
    - einops==0.3.0
    - fsspec==0.8.7
    - future==0.18.2
    - google-auth==1.28.0
    - google-auth-oauthlib==0.4.3
    - gpuinfo==1.0.0a7
    - grpcio==1.36.1
    - imageio==2.9.0
    - jedi==0.18.0
    - libtmux==0.8.5
    - markdown==3.3.4
    - multidict==5.1.0
    - mypy-extensions==0.4.3
    - oauthlib==3.1.0
    - parso==0.8.2
    - pathspec==0.8.1
    - prompt-toolkit==3.0.18
    - protobuf==3.15.6
    - ptpython==3.0.17
    - pyasn1==0.4.8
    - pyasn1-modules==0.2.8
    - pydub==0.25.1
    - pygments==2.8.1
    - pymongo==3.11.3
    - pyyaml==5.3.1
    - regex==2021.4.4
    - requests-oauthlib==1.3.0
    - rsa==4.7.2
    - tensorboard==2.4.1
    - tensorboard-plugin-wit==1.8.0
    - test-tube==0.7.5
    - timm==0.4.12
    - toml==0.10.2
    - torch==1.8.1+cu111
    - torchaudio==0.8.1
    - torchmetrics==0.2.0
    - torchvision==0.6.0
    - tqdm==4.59.0
    - typed-ast==1.4.3
    - wcwidth==0.2.5
    - werkzeug==1.0.1
    - wheel==0.36.2
    - yarl==1.6.3



================================================
FILE: esc50/README.md
================================================
# Experiments on ESC-50 Environmental Sound Classification
[ESC-50](https://github.com/karolpiczak/ESC-50) consists of 2,000 five-second recordings, each labeled with one of 50 semantic classes.

## Setting up the fine-tuning experiments
- Download the preprocessed (resampled) dataset [esc50.zip](https://github.com/kkoutini/PaSST/releases/download/v.0.0.6/esc50.zip) from the [releases page](https://github.com/kkoutini/PaSST/releases/tag/v.0.0.6) and unpack the zip file into a directory (the default path is `./audioset_hdf5s/`). The `base_dir` config in the dataset file ([here](https://github.com/kkoutini/PaSST/blob/main/esc50/dataset.py#L35)) should point to the extracted contents of the dataset zip file.
- Run the experiments using the common configurations (similar to Audioset):
```shell
python3 ex_esc50.py with models.net.s_patchout_t=10 models.net.s_patchout_f=5 basedataset.fold=1 -p
```
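The command above trains and evaluates on a single fold (`basedataset.fold=1`). ESC-50 ships with 5 predefined folds, and results are usually reported over all of them. A minimal Python sketch that builds the per-fold commands from the README command; `fold_commands` and `run_all` are hypothetical helpers, and only the fold number changes between runs:

```python
import shlex
import subprocess
from typing import List

# The README command without the fold argument.
BASE_CMD = "python3 ex_esc50.py with models.net.s_patchout_t=10 models.net.s_patchout_f=5"


def fold_commands(n_folds: int = 5) -> List[List[str]]:
    """One command (as an argv list) per ESC-50 fold, folds numbered 1..n_folds."""
    return [shlex.split(f"{BASE_CMD} basedataset.fold={fold} -p") for fold in range(1, n_folds + 1)]


def run_all(dry_run: bool = True) -> None:
    for cmd in fold_commands():
        if dry_run:
            print(shlex.join(cmd))  # only show what would be run
        else:
            subprocess.run(cmd, check=True)  # runs the folds one after another


if __name__ == "__main__":
    run_all(dry_run=True)
```

The runs are independent, so they could also be launched in parallel on separate GPUs; the sketch keeps them sequential for simplicity.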
## Pre-trained models

Pre-trained models on ESC-50 can be found [here](https://github.com/kkoutini/PaSST/releases/tag/v.0.0.6).

To use the pre-trained models for fine-tuning or inference with minimal dependencies, refer to [PaSST-HEAR](https://github.com/kkoutini/passt_hear21). For example, after installing `passt_hear21`:
```python
from hear21passt.base import get_basic_model, get_model_passt
import torch

# model wrapper, includes Melspectrogram and the default transformer
model = get_basic_model(mode="logits")
# replace the transformer with one that outputs 50 classes
model.net = get_model_passt(arch="passt_s_swa_p16_128_ap476", n_classes=50)

# load the pre-trained model state dict
# (adjust the path to where the checkpoint was downloaded)
state_dict = torch.load("esc50-passt-s-n-f128-p16-s10-fold1-acc.967.pt", map_location="cpu")
# load the weights into the transformer
model.net.load_state_dict(state_dict)

# example inference
model.eval()
model = model.cuda()
# audio_wave has the shape [batch, seconds*32000]; the sampling rate is 32 kHz.
# Here: a dummy batch with one silent 5-second clip.
audio_wave = torch.zeros(1, 5 * 32000).cuda()
with torch.no_grad():
    logits = model(audio_wave)
```



================================================
FILE: esc50/dataset.py
================================================
import io
import os
import pathlib
import random

import av
import librosa
import torchaudio
from torch.utils.data import Dataset as TorchDataset, ConcatDataset, DistributedSampler, WeightedRandomSampler

import torch
from ba3l.ingredients.datasets import Dataset
import pandas as pd
from sacred.config import DynamicIngredient, CMD
from scipy.signal import convolve
from sklearn import preprocessing
import numpy as np
import h5py
from helpers.audiodatasets import PreprocessDataset


LMODE = os.environ.get("LMODE", False)

dataset = Dataset('Esc50')


@dataset.config
def default_config():
    name = 'esc50'  # dataset name
    normalize = False  # normalize dataset
    subsample = False  # subsample squares from the dataset
    roll = True  # apply roll augmentation
    fold = 1
    base_dir = "audioset_hdf5s/esc50/"  # base directory of the dataset as downloaded
    if LMODE:
        base_dir = "/system/user/publicdata/CP/audioset/audioset_hdf5s/esc50/"
    meta_csv = base_dir + "meta/esc50.csv"
    audio_path = base_dir + "audio_32k/"
    ir_path = base_dir + "irs/"
    num_of_classes = 50





def decode_mp3(mp3_arr):
    """
    Decodes a uint8 array holding an encoded mp3 file into a float32 waveform
    (raises RuntimeError if the decoded samples are not float32).
    :rtype: np.array
    """
    container = av.open(io.BytesIO(mp3_arr.tobytes()))
    stream = next(s for s in container.streams if s.type == 'audio')
    a = []
    for packet in container.demux(stream):
        for frame in packet.decode():
            a.append(frame.to_ndarray().reshape(-1))
    waveform = np.concatenate(a)
    if waveform.dtype != 'float32':
        raise RuntimeError("Unexpected wave type")
    return waveform


def pad_or_truncate(x, audio_length):
    """Zero-pad or truncate a 1-D waveform to exactly audio_length samples."""
    if len(x) <= audio_length:
        return np.concatenate((x, np.zeros(audio_length - len(x), dtype=np.float32)), axis=0)
    else:
        return x[0: audio_length]


irs_arr = None


@dataset.command
def get_ir_sample(ir_path, _run, ir_augment, cut_irs_offset=None):
    if not ir_augment:
        return
    global irs_arr
    if irs_arr is None:
        all_paths = [path for path in pathlib.Path(os.path.expanduser(ir_path)).rglob('*.wav')]
        all_paths = sorted(all_paths)
        if cut_irs_offset is not None:
            all_paths = all_paths[cut_irs_offset:cut_irs_offset + 10]
        all_paths_name = [str(p).rsplit("/", 1)[-1] for p in all_paths]
        print("will use these IRs:")
        for i in range(len(all_paths_name)):
            print(i, ": ", all_paths_name[i])
        _run.info["ir_devices"] = all_paths_name
        irs_arr = [librosa.load(p, sr=32000)[0] for p in all_paths]
    return irs_arr[int(np.random.randint(0, len(irs_arr)))]


@dataset.command
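def pydub_augment(waveform, gain_augment=7, ir_augment=0):
    """Randomly augment a waveform: with probability ir_augment, convolve it with a random impulse
    response (this lengthens it), then apply a random integer gain in [-gain_augment, gain_augment - 1] dB."""
    if ir_augment and torch.rand(1) < ir_augment: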
        ir = get_ir_sample()
        waveform = convolve(waveform, ir, 'full')
    if gain_augment:
        gain = torch.randint(gain_augment * 2, (1,)).item() - gain_augment
        amp = 10 ** (gain / 20)
        waveform = waveform * amp
    return waveform
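

# Worked example for the gain augmentation in pydub_augment (hypothetical helper, not
# called anywhere in the pipeline): torch.randint(2 * g, (1,)).item() - g draws an integer
# gain in dB from [-g, g - 1], which is turned into a linear amplitude factor with
# 10 ** (dB / 20). For the default g = 7, the waveform is scaled by roughly 0.45x to 2.0x.
def _sketch_gain_amplitude_range(gain_augment=7):
    """Return the smallest and largest linear amplitude factors pydub_augment can apply."""
    lowest_db, highest_db = -gain_augment, gain_augment - 1
    return 10 ** (lowest_db / 20), 10 ** (highest_db / 20)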


class MixupDataset(TorchDataset):
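    """Mixes pairs of waveforms with a Beta(beta, beta) coefficient, applied with probability rate.

    Targets are mixed with the same weight, so this assumes one-hot or multi-hot target vectors.
    The ESC-50 dataset below returns integer class indices, so keep wavmix disabled for it.
    """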

    def __init__(self, dataset, beta=2, rate=0.5):
        self.beta = beta
        self.rate = rate
        self.dataset = dataset
        print(f"Mixing up waveforms from dataset of len {len(dataset)}")

    def __getitem__(self, index):
        if torch.rand(1) < self.rate:
            x1, f1, y1 = self.dataset[index]
            idx2 = torch.randint(len(self.dataset), (1,)).item()
            x2, f2, y2 = self.dataset[idx2]
            l = np.random.beta(self.beta, self.beta)
            l = max(l, 1. - l)
            x1 = x1 - x1.mean()
            x2 = x2 - x2.mean()
            x = (x1 * l + x2 * (1. - l))
            x = x - x.mean()
            return x, f1, (y1 * l + y2 * (1. - l))
        return self.dataset[index]

    def __len__(self):
        return len(self.dataset)
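

# Sketch of the per-pair mixing in MixupDataset.__getitem__ (hypothetical helper, not used
# by the pipeline). The class draws the coefficient lam from Beta(beta, beta). Here it is
# folded with max(lam, 1 - lam) so the first clip always dominates. Both clips are
# mean-centred before mixing, the result is centred again, and the targets get the same weights.
def _sketch_mix_pair(x1, y1, x2, y2, lam):
    import numpy as np
    lam = max(lam, 1. - lam)
    x1 = np.asarray(x1, dtype=np.float32)
    x2 = np.asarray(x2, dtype=np.float32)
    x = (x1 - x1.mean()) * lam + (x2 - x2.mean()) * (1. - lam)
    x = x - x.mean()  # re-centre the mixture
    y = np.asarray(y1, dtype=np.float32) * lam + np.asarray(y2, dtype=np.float32) * (1. - lam)
    return x, y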




class AudioSetDataset(TorchDataset):
    def __init__(self, meta_csv, audiopath, fold, train=False, sample_rate=32000, classes_num=50,
                 clip_length=5, augment=False):
        """
        Loads ESC-50 clips from audiopath with librosa (mono, resampled to sample_rate) and returns
        fixed-length waveforms. With train=True all folds except `fold` are used; otherwise only `fold`.
        """
        self.sample_rate = sample_rate
        self.meta_csv = meta_csv
        self.df = pd.read_csv(meta_csv)
        if train:  # training all except this
            print(f"Dataset training fold {fold} selection out of {len(self.df)}")
            self.df = self.df[self.df.fold != fold]
            print(f" for training remains {len(self.df)}")
        else:
            print(f"Dataset testing fold {fold} selection out of {len(self.df)}")
            self.df = self.df[self.df.fold == fold]
            print(f" for testing remains {len(self.df)}")

        self.clip_length = clip_length * sample_rate
        self.sr = sample_rate
        self.classes_num = classes_num
        self.augment = augment
        self.audiopath=audiopath
        if augment:
            print(f"Will augment data from {meta_csv}")

    def __len__(self):
        return len(self.df)

    def __getitem__(self, index):
        """Load the waveform and target of one audio clip.

        Args:
          index: row index into the (fold-filtered) metadata frame.
        Returns:
          (waveform, filename, target): the waveform with shape (1, clip_samples),
          the clip's filename, and its integer class index.
        """
        row = self.df.iloc[index]

        #waveform = decode_mp3(np.fromfile(self.audiopath + row.filename, dtype='uint8'))
        # librosa already resamples to self.sr, so calling self.resample afterwards would downsample twice
        waveform, _ = librosa.load(self.audiopath + row.filename, sr=self.sr, mono=True)
        if self.augment:
            waveform = pydub_augment(waveform)
        waveform = pad_or_truncate(waveform, self.clip_length)
        target = row.target
        return waveform.reshape(1, -1), row.filename, target

    def resample(self, waveform):
        """Downsample a 32 kHz waveform by plain decimation (no anti-aliasing filter).
        Args:
          waveform: (clip_samples,)
        Returns:
          (resampled_clip_samples,)
        """
        if self.sample_rate == 32000:
            return waveform
        elif self.sample_rate == 16000:
            return waveform[0:: 2]
        elif self.sample_rate == 8000:
            return waveform[0:: 4]
        else:
            raise Exception('Incorrect sample rate!')
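

# resample() above downsamples by plain decimation (keeping every 2nd or 4th sample). It does
# not low-pass filter first, so any content above the new Nyquist frequency aliases. The helper
# below is a hypothetical anti-aliased alternative built on scipy.signal.decimate; it is not
# used by the pipeline.
def _sketch_antialiased_downsample(waveform, source_sr=32000, target_sr=16000):
    import numpy as np
    from scipy.signal import decimate
    if source_sr % target_sr != 0:
        raise ValueError("target_sr must divide source_sr")
    factor = source_sr // target_sr
    waveform = np.asarray(waveform, dtype=np.float64)
    if factor == 1:
        return waveform
    # decimate low-pass filters (IIR by default, applied forwards and backwards)
    # and then keeps every factor-th sample
    return decimate(waveform, factor)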



@dataset.command
def get_base_training_set(meta_csv, audio_path, fold=1):
    ds = AudioSetDataset(meta_csv, audio_path, fold,  train=True, augment=True)
    return ds


@dataset.command
def get_ft_weighted_sampler(samples_weights=CMD(".get_ft_cls_balanced_sample_weights"),
                            epoch_len=100000, sampler_replace=False):
    num_nodes = int(os.environ.get('num_nodes', 1))
    ddp = int(os.environ.get('DDP', 1))
    num_nodes = max(ddp, num_nodes)
    print("num_nodes= ", num_nodes)
    rank = int(os.environ.get('NODE_RANK', 0))
    return DistributedSamplerWrapper(sampler=WeightedRandomSampler(samples_weights,
                                                                   num_samples=epoch_len, replacement=sampler_replace),
                                     dataset=range(epoch_len),
                                     num_replicas=num_nodes,
                                     rank=rank,
                                     )


@dataset.command
def get_base_test_set(meta_csv, audio_path, fold=1):
    ds = AudioSetDataset(meta_csv, audio_path, fold,  train=False)
    return ds



@dataset.command(prefix='roll_conf')
def get_roll_func(axis=1, shift=None, shift_range=50):
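    print("rolling...")

    def roll_func(b):
        # roll the input along `axis` by a fixed `shift`, or, when shift is None,
        # by a random shift drawn uniformly from [-shift_range, shift_range]
        x, i, y = b
        x = torch.as_tensor(x)
        sf = shift
        if shift is None:
            # np.random.random_integers is deprecated; randint's upper bound is exclusive
            sf = int(np.random.randint(-shift_range, shift_range + 1))
        return x.roll(sf, axis), i, y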

    return roll_func


@dataset.command
def get_training_set(normalize, roll, wavmix=False):
    ds = get_base_training_set()
    get_ir_sample()
    if normalize:
        # fill_norms and norm_func are not defined in this module
        raise NotImplementedError("normalize=True is not supported for ESC-50")
    if roll:
        ds = PreprocessDataset(ds, get_roll_func())
    if wavmix:
        ds = MixupDataset(ds)

    return ds


@dataset.command
def get_test_set(normalize):
    ds = get_base_test_set()
    if normalize:
        # fill_norms and norm_func are not defined in this module
        raise NotImplementedError("normalize=True is not supported for ESC-50")
    return ds


@dataset.command
def print_conf(_config):
    print("Config of ", dataset.path, id(dataset))
    print(_config)
    print()


class DistributedSamplerWrapper(DistributedSampler):
    def __init__(
            self, sampler, dataset,
            num_replicas=None,
            rank=None,
            shuffle: bool = True):
        super(DistributedSamplerWrapper, self).__init__(
            dataset, num_replicas, rank, shuffle)
        # source: @awaelchli https://github.com/PyTorchLightning/pytorch-lightning/issues/3238
        self.sampler = sampler

    def __iter__(self):
        if self.sampler.generator is None:
            self.sampler.generator = torch.Generator()
        self.sampler.generator.manual_seed(self.seed + self.epoch)
        indices = list(self.sampler)
        if self.epoch == 0:
            print(f"\n DistributedSamplerWrapper :  {indices[:10]} \n\n")
        indices = indices[self.rank:self.total_size:self.num_replicas]
        return iter(indices)
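

# Sketch of how DistributedSamplerWrapper splits work across processes (hypothetical helper,
# not used by the pipeline). Every rank seeds the wrapped sampler with the same value
# (seed + epoch), so all ranks draw the same index list. Rank r then keeps every
# num_replicas-th index starting at r. Here a seeded uniform shuffle stands in for the
# WeightedRandomSampler used in get_ft_weighted_sampler.
def _sketch_shard_indices(num_samples, rank, num_replicas, seed=0, epoch=0):
    import random
    rng = random.Random(seed + epoch)  # identical on every rank -> identical index list
    indices = list(range(num_samples))
    rng.shuffle(indices)
    return indices[rank::num_replicas]  # disjoint, strided share for this rank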


if __name__ == "__main__":
    from sacred import Experiment

    ex = Experiment("test_dataset", ingredients=[dataset])


    @ex.automain
    def default_command():
        ex.current_run.get_command_function("print_config")()
        get_base_training_set()
        ds = get_test_set()
        print(ds[0])
        ds = get_training_set()
        print(ds[0])
        print("get_base_training_set", len(get_base_training_set()))
        print("get_base_test_set", len(get_base_test_set()))
        print("get_training_set", len(get_training_set()))
        print("get_test_set", len(get_test_set()))


================================================
FILE: ex_audioset.py
================================================
import os
import sys
import PIL.Image  # `import PIL` alone does not load the PIL.Image submodule used below
import pytorch_lightning
import torch
from pytorch_lightning.callbacks import ModelCheckpoint
from sacred.config_helpers import DynamicIngredient, CMD
from torch.nn import functional as F
import numpy as np
import wandb

from ba3l.experiment import Experiment
from ba3l.module import Ba3lModule

from torch.utils.data import DataLoader

from config_updates import add_configs
from helpers.mixup import my_mixup
from helpers.models_size import count_non_zero_params
from helpers.ramp import exp_warmup_linear_down, cosine_cycle
from helpers.workersinit import worker_init_fn
from sklearn import metrics
from pytorch_lightning import Trainer as plTrainer
from pytorch_lightning.loggers import WandbLogger
from pytorch_lightning.callbacks import LearningRateMonitor

ex = Experiment("audioset")

# Example call with all the default config:
# python ex_audioset.py with trainer.precision=16 models.net.arch=passt_deit_bd_p16_384 -p -m mongodb_server:27000:audioset21_balanced -c "PaSST base"
# with 2 gpus:
# DDP=2 python ex_audioset.py with trainer.precision=16  models.net.arch=passt_deit_bd_p16_384 -p -m mongodb_server:27000:audioset21_balanced -c "PaSST base 2 GPU"

# capture the config of the trainer with the prefix "trainer", this allows to use sacred to update PL trainer config
# now you can use in the cmd trainer.precision=16 for example
get_trainer = ex.command(plTrainer, prefix="trainer")
# capture the WandbLogger and prefix it with "wandb", this allows to use sacred to update WandbLogger config from the command line
get_logger = ex.command(WandbLogger, prefix="wandb")
wandb_logger = None

# define datasets and loaders
get_train_loader = ex.datasets.training.iter(DataLoader, static_args=dict(worker_init_fn=worker_init_fn), train=True, batch_size=12,
                                             num_workers=16, shuffle=None, dataset=CMD("/basedataset.get_full_training_set"),
                                             sampler=CMD("/basedataset.get_ft_weighted_sampler"))

get_validate_loader = ex.datasets.test.iter(DataLoader, static_args=dict(worker_init_fn=worker_init_fn),
                                            validate=True, batch_size=20, num_workers=16,
                                            dataset=CMD("/basedataset.get_test_set"))


@ex.config
def default_conf():
    cmd = " ".join(sys.argv)  # command line arguments
    saque_cmd = os.environ.get("SAQUE_CMD", "").strip()
    saque_id = os.environ.get("SAQUE_ID", "").strip()
    slurm_job_id = os.environ.get("SLURM_JOB_ID", "").strip()
    if os.environ.get("SLURM_ARRAY_JOB_ID", False):
        slurm_job_id = os.environ.get("SLURM_ARRAY_JOB_ID", "").strip() + "_" + os.environ.get("SLURM_ARRAY_TASK_ID",
                                                                                               "").strip()
    process_id = os.getpid()
    models = {
        "net": DynamicIngredient("models.passt.model_ing", arch="passt_deit_bd_p16_384", n_classes=527, s_patchout_t=40,
                                 s_patchout_f=4),  # network config
        "mel": DynamicIngredient("models.preprocess.model_ing",
                                 instance_cmd="AugmentMelSTFT",
                                 n_mels=128, sr=32000, win_length=800, hopsize=320, n_fft=1024, freqm=48,
                                 timem=192,
                                 htk=False, fmin=0.0, fmax=None, norm=1, fmin_aug_range=10,
                                 fmax_aug_range=2000)
    }
    basedataset = DynamicIngredient("audioset.dataset.dataset", wavmix=1)
    wandb = dict(project="passt_audioset2", log_model=True)
    watch_model = True
    trainer = dict(max_epochs=130, gpus=1, weights_summary='full', benchmark=True, num_sanity_val_steps=0, precision=16,
                   reload_dataloaders_every_epoch=True)
    lr = 0.00002  # learning rate
    use_mixup = True
    mixup_alpha = 0.3
    compile = True # compile the model, requires pytorch >= 2.0


# register extra possible configs
add_configs(ex)


@ex.command
def get_scheduler_lambda(warm_up_len=5, ramp_down_start=50, ramp_down_len=50, last_lr_value=0.01,
                         schedule_mode="exp_lin"):
    if schedule_mode == "exp_lin":
        return exp_warmup_linear_down(warm_up_len, ramp_down_len, ramp_down_start, last_lr_value)
    if schedule_mode == "cos_cyc":
        return cosine_cycle(warm_up_len, ramp_down_start, last_lr_value)
    raise RuntimeError(
        f"schedule_mode={schedule_mode} Unknown for a lambda function.")


@ex.command
def get_lr_scheduler(optimizer, schedule_mode):
    if schedule_mode in {"exp_lin", "cos_cyc"}:
        return torch.optim.lr_scheduler.LambdaLR(optimizer, get_scheduler_lambda())
    raise RuntimeError(f"schedule_mode={schedule_mode} Unknown.")


@ex.command
def get_optimizer(params, lr, adamw=True, weight_decay=0.0001):
    if adamw:
        print(f"\nUsing adamw weight_decay={weight_decay}!\n")
        return torch.optim.AdamW(params, lr=lr, weight_decay=weight_decay)
    return torch.optim.Adam(params, lr=lr)
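

# Shape bookkeeping of M.mel_forward (hypothetical helper, not used by the pipeline). The
# loaders yield waveforms shaped [batch, channels, samples]. Channels are folded into the
# batch axis because the mel front end takes [N, samples] and returns [N, n_mels, frames].
# The result is then unfolded to [batch, channels, n_mels, frames]. `transform` stands in
# for the AugmentMelSTFT instance.
def _sketch_fold_channels(x, transform):
    import numpy as np
    x = np.asarray(x)
    batch, channels, samples = x.shape
    flat = x.reshape(-1, samples)  # [batch * channels, samples]
    spec = transform(flat)  # [batch * channels, n_mels, frames]
    return spec.reshape(batch, channels, spec.shape[1], spec.shape[2])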


class M(Ba3lModule):
    def __init__(self, experiment):
        self.mel = None
        self.da_net = None
        super(M, self).__init__(experiment)

        self.use_mixup = self.config.use_mixup or False
        self.mixup_alpha = self.config.mixup_alpha

        desc, sum_params, sum_non_zero = count_non_zero_params(self.net)
        self.experiment.info["start_sum_params"] = sum_params
        self.experiment.info["start_sum_params_non_zero"] = sum_non_zero

        # in case we need embeddings for the DA
        self.net.return_embed = True
        self.dyn_norm = self.config.dyn_norm
        self.do_swa = False

        self.distributed_mode = self.config.trainer.num_nodes > 1
        
        if self.config.compile:
            # compile the network with torch.compile (requires PyTorch >= 2.0)
            print("\n\nCompiling the model with PyTorch 2... \n\n")
            self.net = torch.compile(self.net)
            # only the net is compiled, not the mel front end
            #self.mel = torch.compile(self.mel)

    def forward(self, x):
        return self.net(x)

    def mel_forward(self, x):
        old_shape = x.size()
        x = x.reshape(-1, old_shape[2])
        x = self.mel(x)
        x = x.reshape(old_shape[0], old_shape[1], x.shape[1], x.shape[2])
        if self.dyn_norm:
            if not hasattr(self, "tr_m") or not hasattr(self, "tr_std"):
                tr_m, tr_std = get_dynamic_norm(self)
                self.register_buffer('tr_m', tr_m)
                self.register_buffer('tr_std', tr_std)
            x = (x - self.tr_m) / self.tr_std
        return x

    def training_step(self, batch, batch_idx):
        # REQUIRED
        x, f, y = batch
        if self.mel:
            x = self.mel_forward(x)
        
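        if self.global_step < 5:
            # log the first spectrograms to wandb; each is min-max scaled to [0, 255]
            # because log-mel values are not confined to [0, 1]
            images = [wandb.Image(
                PIL.Image.fromarray(
                    (255 * (i - i.min()) / (i.max() - i.min() + 1e-8)).astype(np.uint8)).convert("L"),
                caption="spectrograms",
            ) for i in x[:, 0, :, :].detach().float().cpu().numpy()]
            wandb.log({"spectrograms": images})
            # wandb_logger.log_image(key="spectrograms", images=[i for i in x[:,0,:,:].cpu().numpy()])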

        orig_x = x
        batch_size = len(y)

        rn_indices, lam = None, None
        if self.use_mixup:
            rn_indices, lam = my_mixup(batch_size, self.mixup_alpha)
            lam = lam.to(x.device)
            x = x * lam.reshape(batch_size, 1, 1, 1) + \
                x[rn_indices] * (1. - lam.reshape(batch_size, 1, 1, 1))

        y_hat, embed = self.forward(x)

        if self.use_mixup:
            y_mix = y * lam.reshape(batch_size, 1) + \
                y[rn_indices] * (1. - lam.reshape(batch_size, 1))
            samples_loss = F.binary_cross_entropy_with_logits(
                y_hat, y_mix, reduction="none")
            loss = samples_loss.mean()
            samples_loss = samples_loss.detach()
        else:
            samples_loss = F.binary_cross_entropy_with_logits(
                y_hat, y, reduction="none")
            loss = samples_loss.mean()
            samples_loss = samples_loss.detach()

        results = {"loss": loss, }
        self.log('trainer/lr', self.trainer.optimizers[0].param_groups[0]['lr'])
        self.log('epoch', self.current_epoch)
        self.log("training.loss", loss.detach())
        return results

    def training_epoch_end(self, outputs):
        avg_loss = torch.stack([x['loss'] for x in outputs]).mean()

        logs = {'train.loss': avg_loss, # 'step': self.current_epoch
                }

        self.log_dict(logs, sync_dist=True)

    def predict(self, batch, batch_idx: int, dataloader_idx: int = None):
        x, f, y = batch
        if self.mel:
            x = self.mel_forward(x)

        y_hat, _ = self.forward(x)
        return f, y_hat

    def validation_step(self, batch, batch_idx):
        x, f, y = batch
        if self.mel:
            x = self.mel_forward(x)

        if self.global_step < 5:
            # log the first validation spectrograms to wandb, min-max scaled to [0, 255]
            images = [wandb.Image(
                PIL.Image.fromarray(
                    (255 * (i - i.min()) / (i.max() - i.min() + 1e-8)).astype(np.uint8)).convert("L"),
                caption="validation_spectrograms",
            ) for i in x[:, 0, :, :].detach().float().cpu().numpy()]
            wandb.log({"validation_spectrograms": images})
            # wandb_logger.log_image(key="validation_spectrograms", images=[
            #                        i for i in x[:, 0, :, :].cpu().numpy()])

        results = {}
        model_name = [("", self.net)]
        if self.do_swa:
            model_name = model_name + [("swa_", self.net_swa)]
        for net_name, net in model_name:
            y_hat, _ = net(x)
            samples_loss = F.binary_cross_entropy_with_logits(y_hat, y)
            loss = samples_loss.mean()
            out = torch.sigmoid(y_hat.detach())
            # self.log("validation.loss", loss, prog_bar=True, on_epoch=True, on_step=False)
            results = {**results, net_name + "val_loss": loss,
                       net_name + "out": out, net_name + "target": y.detach()}
        results = {k: v.cpu() for k, v in results.items()}
        return results

    def validation_epoch_end(self, outputs):
        model_name = [("", self.net)]
        if self.do_swa:
            model_name = model_name + [("swa_", self.net_swa)]
        for net_name, net in model_name:
            avg_loss = torch.stack([x[net_name + 'val_loss']
                                   for x in outputs]).mean()
            out = torch.cat([x[net_name + 'out'] for x in outputs], dim=0)
            target = torch.cat([x[net_name + 'target']
                               for x in outputs], dim=0)
            try:
                average_precision = metrics.average_precision_score(
                    target.float().numpy(), out.float().numpy(), average=None)
            except ValueError:
                average_precision = np.array([np.nan] * 527)
            try:
                roc = metrics.roc_auc_score(
                    target.numpy(), out.numpy(), average=None)
            except ValueError:
                roc = np.array([np.nan] * 527)
            logs = {net_name + 'val.loss': torch.as_tensor(avg_loss).cuda(),
                    net_name + 'ap': torch.as_tensor(average_precision.mean()).cuda(),
                    net_name + 'roc': torch.as_tensor(roc.mean()).cuda(),}
                    #'step': torch.as_tensor(self.current_epoch).cuda()}
            
            # torch.save(average_precision,
            #            f"ap_perclass_{average_precision.mean()}.pt")
            # print(average_precision)
            self.log_dict(logs, sync_dist=True)
            if self.distributed_mode:
                allout = self.all_gather(out)
                alltarget = self.all_gather(target)

                average_precision = metrics.average_precision_score(
                    alltarget.reshape(-1, alltarget.shape[-1]).cpu().numpy(),
                    allout.reshape(-1, allout.shape[-1]).cpu().numpy(), average=None)
                if self.trainer.is_global_zero:
                    logs = {net_name + "allap": torch.as_tensor(average_precision.mean()).cuda(),
                           # 'step': torch.as_tensor(self.current_epoch).cuda()
                            }
                    self.log_dict(logs, sync_dist=False)
            else:
                self.log_dict(
                    {net_name + "allap": logs[net_name + 'ap'], 
                     #'step': logs['step']
                      }
                    , sync_dist=True)

    def configure_optimizers(self):
        # REQUIRED: return the optimizer and its learning-rate scheduler
        # (both are built from the sacred config via get_optimizer / get_lr_scheduler)
        optimizer = get_optimizer(self.parameters())
        return {
            'optimizer': optimizer,
            'lr_scheduler': get_lr_scheduler(optimizer)
        }

    def configure_callbacks(self):
        return get_extra_checkpoint_callback() + get_extra_swa_callback() + [LearningRateMonitor(logging_interval='epoch')]
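

# Batch-level mixup as done in M.training_step, sketched with NumPy (hypothetical helper, not
# used by the pipeline). helpers.mixup.my_mixup is not shown here, so this sketch assumes it returns
# a random permutation of the batch plus one coefficient per sample drawn from Beta(alpha, alpha).
# Spectrograms [batch, 1, n_mels, frames] and multi-hot targets [batch, n_classes] are then mixed
# with the same per-sample weights, using the same formulas as the training step.
def _sketch_batch_mixup(x, y, alpha=0.3, seed=None):
    import numpy as np
    rng = np.random.default_rng(seed)
    x = np.asarray(x, dtype=np.float32)
    y = np.asarray(y, dtype=np.float32)
    batch_size = len(y)
    rn_indices = rng.permutation(batch_size)  # partner sample for each position
    lam = rng.beta(alpha, alpha, size=batch_size).astype(np.float32)
    x_mix = x * lam.reshape(batch_size, 1, 1, 1) + x[rn_indices] * (1. - lam.reshape(batch_size, 1, 1, 1))
    y_mix = y * lam.reshape(batch_size, 1) + y[rn_indices] * (1. - lam.reshape(batch_size, 1))
    return x_mix, y_mix, rn_indices, lam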


@ex.command
def get_dynamic_norm(model, dyn_norm=False):
    if not dyn_norm:
        return None, None
    raise RuntimeError('no dynamic norm supported yet.')


@ex.command
def get_extra_checkpoint_callback(save_last_n=None):
    if save_last_n is None:
        return []
    return [ModelCheckpoint(monitor="step", verbose=True, save_top_k=save_last_n, mode='max')]


@ex.command
def get_extra_swa_callback(swa=True, swa_epoch_start=50,
                           swa_freq=5):
    if not swa:
        return []
    print("\n Using swa!\n")
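    from packaging import version  # compare versions numerically, not as strings
    if version.parse(pytorch_lightning.__version__) < version.parse("1.5"):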
        from helpers.swa_legacy import StochasticWeightAveraging
    else:
        from helpers.swa_callback import StochasticWeightAveraging
    return [StochasticWeightAveraging(swa_epoch_start=swa_epoch_start, swa_freq=swa_freq)]


@ex.command
def main(_run, _config, _log, _rnd, _seed, watch_model=True):
    global wandb_logger 
    wandb_logger = get_logger()
    trainer = get_trainer(logger=wandb_logger)
    train_loader = get_train_loader()
    val_loader = get_validate_loader()

    modul = M(ex)
    if watch_model:
        # change log frequency of gradients and parameters (100 steps by default)
        wandb_logger.watch(modul, log_freq=1000, log="all" )

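    from packaging import version  # compare versions numerically, not as strings
    if version.parse(pytorch_lightning.__version__) < version.parse("1.5"):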
        trainer.fit(
            modul,
            train_dataloader=train_loader,
            val_dataloaders=val_loader,
        )
    else:
        trainer.fit(
            modul,
            train_dataloaders=train_loader,
            val_dataloaders=val_loader,
        )

    return {"done": True}


@ex.command
def model_speed_test(_run, _config, _log, _rnd, _seed, speed_test_batch_size=12):
    '''
    Test training speed of a model
    @param _run:
    @param _config:
    @param _log:
    @param _rnd:
    @param _seed:
    @param speed_test_batch_size: the batch size during the test
    @return:
    '''

    modul = M(ex)
    modul = modul.cuda()
    batch_size = speed_test_batch_size
    print(f"\nBATCH SIZE : {batch_size}\n")
    test_length = 100
    print(f"\ntest_length : {test_length}\n")

    x = torch.ones([batch_size, 1, 128, 998]).cuda()
    target = torch.ones([batch_size, 527]).cuda()
    # benchmark the network alone (without the mel front end) on fixed-size spectrograms
    net = modul.net
    scaler = torch.cuda.amp.GradScaler()
    torch.backends.cudnn.benchmark = True
    # net = torch.jit.trace(net,(x,))
    net = torch.compile(net)  # requires PyTorch >= 2.0
    optimizer = torch.optim.SGD(net.parameters(), lr=0.001)

    print("warmup")
    import time
    torch.cuda.synchronize()
    t1 = time.time()
    for i in range(10):
        optimizer.zero_grad(set_to_none=True)  # do not accumulate gradients across iterations
        with torch.cuda.amp.autocast():
            y_hat, embed = net(x)
            loss = F.binary_cross_entropy_with_logits(
                y_hat, target, reduction="none").mean()
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()
    torch.cuda.synchronize()
    t2 = time.time()
    print('warmup done:', (t2 - t1))
    torch.cuda.synchronize()
    t1 = time.time()
    print("testing speed")

    for i in range(test_length):
        optimizer.zero_grad(set_to_none=True)  # do not accumulate gradients across iterations
        with torch.cuda.amp.autocast():
            y_hat, embed = net(x)
            loss = F.binary_cross_entropy_with_logits(
                y_hat, target, reduction="none").mean()
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()
    torch.cuda.synchronize()
    t2 = time.time()
    print('test done:', (t2 - t1))
    print("average speed: ", (test_length * batch_size) /
          (t2 - t1), " specs/second")


@ex.command
def evaluate_only(_run, _config, _log, _rnd, _seed):
    # force overriding the config, not logged = not recommended
    trainer = get_trainer(logger=get_logger())
    val_loader = get_validate_loader()

    modul = M(ex)
    modul.val_dataloader = None
    #trainer.val_dataloaders = None
    print(f"\n\nValidation len={len(val_loader)}\n")
    res = trainer.validate(modul, dataloaders=val_loader)
    print("\n\n Validation:")
    print(res)


@ex.command
def test_loaders():
    '''
    get one sample from each loader for debugging
    @return:
    '''
    for i, b in enumerate(ex.datasets.training.get_iter()):
        print(b)
        break

    for i, b in enumerate(ex.datasets.test.get_iter()):
        print(b)
        break


def set_default_json_pickle(obj):
    if isinstance(obj, set):
        return list(obj)
    raise TypeError


@ex.command
def preload_mp3(all_y=CMD("/basedataset.preload_mp3")):
    '''
    read the dataset sequentially, useful if you have a network cache
    @param all_y: the dataset preload command
    @return:
    '''
    print(all_y.shape)


def multiprocessing_run(rank, word_size):
    print("rank ", rank, os.getpid())
    print("word_size ", word_size)
    os.environ['NODE_RANK'] = str(rank)
    os.environ['CUDA_VISIBLE_DEVICES'] = os.environ['CUDA_VISIBLE_DEVICES'].split(",")[
        rank]
    argv = sys.argv
    if rank != 0:
        print(f"Unobserved {os.getpid()} with rank {rank}")
        argv = argv + ["-u"]  # only rank 0 is observed
    if "with" not in argv:
        argv = argv + ["with"]

    argv = argv + \
        [f"trainer.num_nodes={word_size}", f"trainer.accelerator=ddp"]
    print(argv)

    @ex.main
    def default_command():
        return main()

    ex.run_commandline(argv)


if __name__ == '__main__':
    # Setting DDP=2 forks two processes to run on two GPUs;
    # the environment variable "DDP" defines the number of processes to fork.
    # With 2x 2080 Ti GPUs you can train the full model to .47 mAP in around 24 hours.
    # You may need to set NCCL_P2P_DISABLE=1.
    word_size = os.environ.get("DDP", None)
    if word_size:
        import random

        word_size = int(word_size)
        print(f"\n\nDDP TRAINING WITH WORD_SIZE={word_size}\n\n")
        os.environ['MASTER_ADDR'] = '127.0.0.1'
        # pick a random port to make collisions with other runs unlikely
        os.environ['MASTER_PORT'] = f"{9999 + random.randint(0, 9999)}"
        os.environ['PL_IN_DDP_SUBPROCESS'] = '1'

        for rank in range(word_size):
            pid = os.fork()
            if pid == 0:
                print("Child Forked ")
                multiprocessing_run(rank, word_size)
                exit(0)

        # wait for every forked child, not just the first one to exit
        for _ in range(word_size):
            pid, exit_code = os.wait()
            print(pid, exit_code)
        exit(0)

print("__main__ is running pid", os.getpid(), "in module main: ", __name__)


@ex.automain
def default_command():
    return main()


================================================
FILE: ex_esc50.py
================================================
import os
import sys

import torch
from pytorch_lightning.callbacks import ModelCheckpoint
from sacred.config_helpers import DynamicIngredient, CMD
from torch.nn import functional as F
import numpy as np

from ba3l.experiment import Experiment
from ba3l.module import Ba3lModule

from torch.utils.data import DataLoader

from config_updates import add_configs
from helpers.mixup import my_mixup
from helpers.models_size import count_non_zero_params
from helpers.ramp import exp_warmup_linear_down, cosine_cycle
from helpers.workersinit import worker_init_fn
from sklearn import metrics
from pytorch_lightning import Trainer as plTrainer
from pytorch_lightning.loggers import WandbLogger



ex = Experiment("passt_esc50")

# Example call with all the default config:
# python ex_esc50.py with  trainer.precision=16  -p -m mongodb_server:27000:audioset21_balanced -c "ESC50 PaSST base"
# with 2 gpus:
# DDP=2 python ex_esc50.py with  trainer.precision=16  -p -m mongodb_server:27000:audioset21_balanced -c "ESC50 PaSST base"

# capture the config of the trainer with the prefix "trainer", this allows to use sacred to update PL trainer config
get_trainer = ex.command(plTrainer, prefix="trainer")
# capture the WandbLogger and prefix it with "wandb", this allows to use sacred to update WandbLogger config from the command line
get_logger = ex.command(WandbLogger, prefix="wandb")


# define datasets and loaders
get_train_loader = ex.datasets.training.iter(DataLoader, static_args=dict(worker_init_fn=worker_init_fn), train=True, batch_size=12,
                                             num_workers=16, shuffle=None, dataset=CMD("/basedataset.get_training_set"),
                                             )

get_validate_loader = ex.datasets.test.iter(DataLoader, static_args=dict(worker_init_fn=worker_init_fn),
                                            validate=True, batch_size=20, num_workers=16,
                                            dataset=CMD("/basedataset.get_test_set"))


@ex.config
def default_conf():
    cmd = " ".join(sys.argv)
    saque_cmd = os.environ.get("SAQUE_CMD", "").strip()
    saque_id = os.environ.get("SAQUE_ID", "").strip()
    slurm_job_id = os.environ.get("SLURM_JOB_ID", "").strip()
    if os.environ.get("SLURM_ARRAY_JOB_ID", False):
        slurm_job_id = os.environ.get("SLURM_ARRAY_JOB_ID", "").strip() + "_" + os.environ.get("SLURM_ARRAY_TASK_ID",
                                                                                               "").strip()
    process_id = os.getpid()
    models = {
        "net": DynamicIngredient("models.passt.model_ing", n_classes=50, s_patchout_t=10, s_patchout_f=3),
        "mel": DynamicIngredient("models.preprocess.model_ing",
                                 instance_cmd="AugmentMelSTFT",
                                 n_mels=128, sr=32000, win_length=800, hopsize=320, n_fft=1024, freqm=48,
                                 timem=80,
                                 htk=False, fmin=0.0, fmax=None, norm=1, fmin_aug_range=10,
                                 fmax_aug_range=2000)
    }
    basedataset = DynamicIngredient("esc50.dataset.dataset")
    trainer = dict(max_epochs=10, gpus=1, weights_summary='full', benchmark=True, num_sanity_val_steps=0,
                   reload_dataloaders_every_epoch=True)
    wandb = dict(project="passt_esc50", log_model=True)
    lr = 0.00001
    use_mixup = True
    mixup_alpha = 0.3


# register extra possible configs
add_configs(ex)


@ex.command
def get_scheduler_lambda(warm_up_len=5, ramp_down_start=50, ramp_down_len=50, last_lr_value=0.01,
                         schedule_mode="exp_lin"):
    if schedule_mode == "exp_lin":
        return exp_warmup_linear_down(warm_up_len, ramp_down_len, ramp_down_start, last_lr_value)
    if schedule_mode == "cos_cyc":
        return cosine_cycle(warm_up_len, ramp_down_start, last_lr_value)
    raise RuntimeError(f"schedule_mode={schedule_mode} Unknown for a lambda function.")


@ex.command
def get_lr_scheduler(optimizer, schedule_mode):
    if schedule_mode in {"exp_lin", "cos_cyc"}:
        return torch.optim.lr_scheduler.LambdaLR(optimizer, get_scheduler_lambda())
    raise RuntimeError(f"schedule_mode={schedule_mode} Unknown.")
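

# Illustrative sketch (hypothetical helper; NOT the implementation in helpers/ramp.py):
# LambdaLR multiplies the base lr by a function of the scheduler step count (epochs here,
# assuming Lightning's default per-epoch scheduler stepping). This assumes "exp_lin" means
# "exponential warm-up, plateau, then linear decay to last_lr_value". Under that assumption,
# such a multiplier could look like the function below. The warm-up curve
# exp(-5 * (1 - t) ** 2) is also an assumption; check helpers/ramp.py for the real shape.
def sketch_warmup_linear_down(warm_up_len, ramp_down_len, ramp_down_start, last_lr_value):
    import math

    def multiplier(epoch):
        if epoch < warm_up_len:
            # smooth ramp from exp(-5) at epoch 0 up towards 1.0
            t = epoch / warm_up_len
            return math.exp(-5.0 * (1.0 - t) ** 2)
        if epoch < ramp_down_start:
            return 1.0
        if epoch < ramp_down_start + ramp_down_len:
            # linear interpolation from 1.0 at ramp_down_start down to last_lr_value
            t = (epoch - ramp_down_start) / ramp_down_len
            return 1.0 - t * (1.0 - last_lr_value)
        return last_lr_value

    return multiplier

# usage (hypothetical):
# torch.optim.lr_scheduler.LambdaLR(optimizer, sketch_warmup_linear_down(5, 50, 50, 0.01))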


@ex.command
def get_optimizer(params, lr, adamw=True, weight_decay=0.0001):
    if adamw:
        print(f"\nUsing adamw weight_decay={weight_decay}!\n")
        return torch.optim.AdamW(params, lr=lr, weight_decay=weight_decay)
    return torch.optim.Adam(params, lr=lr)


class M(Ba3lModule):
    def __init__(self, experiment):
        self.mel = None
        self.da_net = None
        super(M, self).__init__(experiment)

        self.use_mixup = self.config.use_mixup or False
        self.mixup_alpha = self.config.mixup_alpha

        desc, sum_params, sum_non_zero = count_non_zero_params(self.net)
        self.experiment.info["start_sum_params"] = sum_params
        self.experiment.info["start_sum_params_non_zero"] = sum_non_zero

        # in case we need embeddings for the DA
        self.net.return_embed = True
        self.dyn_norm = self.config.dyn_norm
        self.do_swa = False

        self.distributed_mode = self.config.trainer.num_nodes > 1

    def forward(self, x):
        return self.net(x)

    def mel_forward(self, x):
        old_shape = x.size()
        x = x.reshape(-1, old_shape[2])
        x = self.mel(x)
        x = x.reshape(old_shape[0], old_shape[1], x.shape[1], x.shape[2])
        if self.dyn_norm:
            if not hasattr(self, "tr_m") or not hasattr(self, "tr_std"):
                tr_m, tr_std = get_dynamic_norm(self)
                self.register_buffer('tr_m', tr_m)
                self.register_buffer('tr_std', tr_std)
            x = (x - self.tr_m) / self.tr_std
        return x

    def training_step(self, batch, batch_idx):
        # REQUIRED
        x, f, y = batch
        if self.mel:
            x = self.mel_forward(x)

        orig_x = x
        batch_size = len(y)

        rn_indices, lam = None, None
        if self.use_mixup:
            rn_indices, lam = my_mixup(batch_size, self.mixup_alpha)
            lam = lam.to(x.device)
            x = x * lam.reshape(batch_size, 1, 1, 1) + x[rn_indices] * (1. - lam.reshape(batch_size, 1, 1, 1))

        y_hat, embed = self.forward(x)

        if self.use_mixup:
            # y_mix = y * lam.reshape(batch_size, 1) + y[rn_indices] * (1. - lam.reshape(batch_size, 1))
            samples_loss = (F.cross_entropy(y_hat, y, reduction="none") * lam.reshape(batch_size) +
                            F.cross_entropy(y_hat, y[rn_indices], reduction="none") * (1. - lam.reshape(batch_size)))
            loss = samples_loss.mean()
            samples_loss = samples_loss.detach()
        else:
            samples_loss = F.cross_entropy(y_hat, y, reduction="none")
            loss = samples_loss.mean()
            samples_loss = samples_loss.detach()

        results = {"loss": loss, }

        return results

    def training_epoch_end(self, outputs):
        avg_loss = torch.stack([x['loss'] for x in outputs]).mean()

        logs = {'train.loss': avg_loss, 'step': self.current_epoch}

        self.log_dict(logs, sync_dist=True)

    def predict(self, batch, batch_idx: int, dataloader_idx: int = None):
        x, f, y = batch
        if self.mel:
            x = self.mel_forward(x)

        y_hat, _ = self.forward(x)
        return f, y_hat

    def validation_step(self, batch, batch_idx):
        x, f, y = batch
        if self.mel:
            x = self.mel_forward(x)

        results = {}
        model_name = [("", self.net)]
        if self.do_swa:
            model_name = model_name + [("swa_", self.net_swa)]
        for net_name, net in model_name:
            y_hat, _ = net(x)
            samples_loss = F.cross_entropy(y_hat, y, reduction="none")
            loss = samples_loss.mean()
            _, preds = torch.max(y_hat, dim=1)
            n_correct_pred_per_sample = (preds == y)
            n_correct_pred = n_correct_pred_per_sample.sum()
            # self.log("validation.loss", loss, prog_bar=True, on_epoch=True, on_step=False)
            results = {**results, net_name + "val_loss": loss,
                       net_name + "n_correct_pred": torch.as_tensor(n_correct_pred),
                       net_name + "n_pred": torch.as_tensor(len(y))}
        results = {k: v.cpu() for k, v in results.items()}
        return results

    def validation_epoch_end(self, outputs):
        model_name = [("", self.net)]
        if self.do_swa:
            model_name = model_name + [("swa_", self.net_swa)]
        for net_name, net in model_name:
            avg_loss = torch.stack([x[net_name + 'val_loss'] for x in outputs]).mean()
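            # use the per-model keys so the SWA accuracy is not computed from the base model's counts
            val_acc = (sum(x[net_name + 'n_correct_pred'] for x in outputs) * 1.0 /
                       sum(x[net_name + 'n_pred'] for x in outputs))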
            logs = {net_name + 'val.loss': torch.as_tensor(avg_loss).cuda(),
                    net_name + 'acc': torch.as_tensor(val_acc).cuda(),
                    'step': torch.as_tensor(self.current_epoch).cuda()}
            self.log_dict(logs, sync_dist=True)

    def configure_optimizers(self):
        # REQUIRED
        # can return multiple optimizers and learning-rate schedulers
        # (LBFGS is supported automatically; no closure function is needed)
        optimizer = get_optimizer(self.parameters())
        # torch.optim.Adam(self.parameters(), lr=self.config.lr)
        return {
            'optimizer': optimizer,
            'lr_scheduler': get_lr_scheduler(optimizer)
        }

    def configure_callbacks(self):
        return get_extra_checkpoint_callback() + get_extra_swa_callback()
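

# Illustrative sketch (pure Python, hypothetical helper names): for single-label
# classification, the mixup loss in M.training_step,
#     lam * CE(y_hat, y) + (1 - lam) * CE(y_hat, y[rn_indices]),
# equals the cross-entropy against the mixed one-hot target
#     lam * onehot(y) + (1 - lam) * onehot(y[rn_indices]),
# because cross-entropy -sum(q * log p) is linear in the target q. Mixing the two
# hard-label losses therefore gives the same value as building a soft y_mix target.
def sketch_log_softmax(logits):
    import math
    # subtract the max logit for numerical stability
    m = max(logits)
    log_z = m + math.log(sum(math.exp(v - m) for v in logits))
    return [v - log_z for v in logits]


def sketch_mixup_ce(logits, label_a, label_b, lam):
    # weighted sum of the two hard-label cross-entropies, as in training_step
    log_p = sketch_log_softmax(logits)
    return lam * -log_p[label_a] + (1.0 - lam) * -log_p[label_b]


def sketch_soft_target_ce(logits, target):
    # cross-entropy against an arbitrary target distribution
    log_p = sketch_log_softmax(logits)
    return -sum(q * lp for q, lp in zip(target, log_p))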


@ex.command
def get_dynamic_norm(model, dyn_norm=False):
    if not dyn_norm:
        return None, None
    raise RuntimeError('no dynamic norm supported yet.')


@ex.command
def get_extra_checkpoint_callback(save_last_n=None):
    if save_last_n is None:
        return []
    return [ModelCheckpoint(monitor="step", verbose=True, save_top_k=save_last_n, mode='max')]


@ex.command
def get_extra_swa_callback(swa=True, swa_epoch_start=2,
                           swa_freq=1):
    if not swa:
        return []
    print("\n Using swa!\n")
    from helpers.swa_callback import StochasticWeightAveraging
    return [StochasticWeightAveraging(swa_epoch_start=swa_epoch_start, swa_freq=swa_freq)]


@ex.command
def main(_run, _config, _log, _rnd, _seed):
    trainer = get_trainer()
    train_loader = get_train_loader()
    val_loader = get_validate_loader()

    modul = M(ex)

    trainer.fit(
        modul,
        train_dataloaders=train_loader,
        val_dataloaders=val_loader,
    )

    return {"done": True}


@ex.command
def model_speed_test(_run, _config, _log, _rnd, _seed, speed_test_batch_size=100):
    '''
    Test training speed of a model
    @param _run:
    @param _config:
    @param _log:
    @param _rnd:
    @param _seed:
    @param speed_test_batch_size: the batch size during the test
    @return:
    '''

    modul = M(ex)
    modul = modul.cuda()
    batch_size = speed_test_batch_size
    print(f"\nBATCH SIZE : {batch_size}\n")
    test_length = 100
    print(f"\ntest_length : {test_length}\n")

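    x = torch.ones([batch_size, 1, 128, 998]).cuda()
    net = modul.net
    # probe the output size with one sample, so the dummy target matches the configured
    # number of classes (a hard-coded 527 AudioSet classes would not match n_classes=50)
    with torch.no_grad():
        y_probe, _ = net(x[:1])
    target = torch.ones([batch_size, y_probe.shape[-1]]).cuda()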
    scaler = torch.cuda.amp.GradScaler()
    torch.backends.cudnn.benchmark = True
    # net = torch.jit.trace(net,(x,))
    optimizer = torch.optim.SGD(net.parameters(), lr=0.001)

    print("warmup")
    import time
    torch.cuda.synchronize()
    t1 = time.time()
    for i in range(10):
        with torch.cuda.amp.autocast():
            y_hat, embed = net(x)
            loss = F.binary_cross_entropy_with_logits(y_hat, target, reduction="none").mean()
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()
    torch.cuda.synchronize()
    t2 = time.time()
    print('warmup done:', (t2 - t1))
    torch.cuda.synchronize()
    t1 = time.time()
    print("testing speed")

    for i in range(test_length):
        with torch.cuda.amp.autocast():
            y_hat, embed = net(x)
            loss = F.binary_cross_entropy_with_logits(y_hat, target, reduction="none").mean()
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()
    torch.cuda.synchronize()
    t2 = time.time()
    print('test done:', (t2 - t1))
    print("average speed: ", (test_length * batch_size) / (t2 - t1), " specs/second")


@ex.command
def evaluate_only(_run, _config, _log, _rnd, _seed):
    # force overriding the config, not logged = not recommended
    trainer = get_trainer()
    train_loader = get_train_loader()
    val_loader = get_validate_loader()

    modul = M(ex)
    modul.val_dataloader = None
    #trainer.val_dataloaders = None
    print(f"\n\nValidation len={len(val_loader)}\n")
    res = trainer.validate(modul, dataloaders=val_loader)
    print("\n\n Validation:")
    print(res)


@ex.command
def test_loaders():
    '''
    print one batch from each loader for debugging
    @return:
    '''
    for i, b in enumerate(ex.datasets.training.get_iter()):
        print(b)
        break

    for i, b in enumerate(ex.datasets.test.get_iter()):
        print(b)
        break


def set_default_json_pickle(obj):
    if isinstance(obj, set):
        return list(obj)
    raise TypeError



def multiprocessing_run(rank, world_size):
    print("rank ", rank, os.getpid())
    print("world_size ", world_size)
    os.environ['NODE_RANK'] = str(rank)
    # each rank sees only its own GPU; CUDA_VISIBLE_DEVICES must list at least world_size devices
    os.environ['CUDA_VISIBLE_DEVICES'] = os.environ['CUDA_VISIBLE_DEVICES'].split(",")[rank]
    argv = sys.argv
    if rank != 0:
        print(f"Unobserved {os.getpid()} with rank {rank}")
        argv = argv + ["-u"]  # only rank 0 is observed
    if "with" not in argv:
        argv = argv + ["with"]

    argv = argv + [f"trainer.num_nodes={world_size}", "trainer.accelerator=ddp"]
    print(argv)

    @ex.main
    def default_command():
        return main()

    ex.run_commandline(argv)
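

# Illustrative sketch (hypothetical helper, not called anywhere in this script): the
# per-rank launch settings that multiprocessing_run derives, written as a pure function.
# Each forked process r keeps only the r-th GPU listed in CUDA_VISIBLE_DEVICES. Only rank 0
# is observed by sacred (other ranks get "-u"). Lightning is told there are world_size
# "nodes" of one process each, using the ddp accelerator.
def sketch_rank_launch_args(rank, world_size, argv, visible_devices):
    env = {"NODE_RANK": str(rank), "CUDA_VISIBLE_DEVICES": visible_devices.split(",")[rank]}
    argv = list(argv)  # copy, so the caller's list is not modified
    if rank != 0:
        argv.append("-u")
    if "with" not in argv:
        argv.append("with")
    argv += [f"trainer.num_nodes={world_size}", "trainer.accelerator=ddp"]
    return env, argv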


if __name__ == '__main__':
    # Setting DDP=2 forks two processes to run on two GPUs:
    # the environment variable "DDP" defines the number of processes to fork.
    # With 2x 2080 Ti GPUs you can train the full model to .47 in around 24 hours.
    # You may need to set NCCL_P2P_DISABLE=1.
    world_size = os.environ.get("DDP", None)
    if world_size:
        import random

        world_size = int(world_size)
        print(f"\n\nDDP TRAINING WITH WORLD_SIZE={world_size}\n\n")
        os.environ['MASTER_ADDR'] = '127.0.0.1'
        os.environ['MASTER_PORT'] = f"{9999 + random.randint(0, 9999)}"  # random port to make collisions unlikely
        os.environ['PL_IN_DDP_SUBPROCESS'] = '1'

        for rank in range(world_size):
            pid = os.fork()
            if pid == 0:
                print("Child Forked ")
                multiprocessing_run(rank, world_size)
                exit(0)

        # wait for every forked child, not only the first one to finish
        for _ in range(world_size):
            pid, exit_code = os.wait()
            print(pid, exit_code)
        exit(0)

print("__main__ is running pid", os.getpid(), "in module main: ", __name__)


@ex.automain
def default_command():
    return main()


================================================
FILE: ex_fsd50k.py
================================================
import os
import sys

import torch
from pytorch_lightning.callbacks import ModelCheckpoint
from sacred.config_helpers import DynamicIngredient, CMD
from torch.nn import functional as F
import numpy as np

from ba3l.experiment import Experiment
from ba3l.module import Ba3lModule

from torch.utils.data import DataLoader

from config_updates import add_configs
from helpers.mixup import my_mixup
from helpers.models_size import count_non_zero_params
from helpers.ramp import exp_warmup_linear_down, cosine_cycle
from helpers.workersinit import worker_init_fn
from sklearn import metrics
from pytorch_lightning import Trainer as plTrainer
from pytorch_lightning.loggers import WandbLogger


ex = Experiment("fsd50k")

# capture the config of the trainer with the prefix "trainer", this allows to use sacred to update PL trainer config
get_trainer = ex.command(plTrainer, prefix="trainer")
# capture the WandbLogger and prefix it with "wandb", this allows to use sacred to update WandbLogger config from the command line
get_logger = ex.command(WandbLogger, prefix="wandb")



# Example call with all the default config:
# python ex_fsd50k.py with trainer.precision=16 models.net.arch=passt_deit_bd_p16_384 -p -m mongodb_server:27000:audioset21_balanced -c "PaSST base"
# with 2 gpus:
# DDP=2 python ex_fsd50k.py with trainer.precision=16 models.net.arch=passt_deit_bd_p16_384 -p -m mongodb_server:27000:audioset21_balanced -c "PaSST base 2 GPU"

# define datasets and loaders
get_train_loader = ex.datasets.training.iter(DataLoader, static_args=dict(worker_init_fn=worker_init_fn),
                                             train=True, batch_size=12, num_workers=16, shuffle=True,
                                             dataset=CMD("/basedataset.get_training_set"))

get_validate_loader = ex.datasets.valid.iter(DataLoader, static_args=dict(worker_init_fn=worker_init_fn),
                                             validate=True, batch_size=10, num_workers=16,
                                             dataset=CMD("/basedataset.get_valid_set"))

get_eval_loader = ex.datasets.eval.iter(DataLoader, static_args=dict(worker_init_fn=worker_init_fn),
                                        validate=True, batch_size=10, num_workers=16,
                                        dataset=CMD("/basedataset.get_eval_set"))


@ex.named_config
def variable_eval():
    basedataset = dict(variable_eval=True)
    datasets = dict(valid=dict(batch_size=1), eval=dict(batch_size=1))


@ex.config
def default_conf():
    cmd = " ".join(sys.argv)  # command line arguments
    saque_cmd = os.environ.get("SAQUE_CMD", "").strip()
    saque_id = os.environ.get("SAQUE_ID", "").strip()
    slurm_job_id = os.environ.get("SLURM_JOB_ID", "").strip()
    if os.environ.get("SLURM_ARRAY_JOB_ID", False):
        slurm_job_id = os.environ.get("SLURM_ARRAY_JOB_ID", "").strip() + "_" + os.environ.get("SLURM_ARRAY_TASK_ID",
                                                                                               "").strip()
    process_id = os.getpid()
    models = {
        "net": DynamicIngredient("models.passt.model_ing",
                                 n_classes=200, s_patchout_t=10, s_patchout_f=4),  # network config
        "mel": DynamicIngredient("models.preprocess.model_ing",
                                 instance_cmd="AugmentMelSTFT",
                                 n_mels=128, sr=32000, win_length=800, hopsize=320, n_fft=1024, freqm=0,
                                 timem=0,
                                 htk=False, fmin=0.0, fmax=None, norm=1, fmin_aug_range=10,
                                 fmax_aug_range=2000)
    }
    # set the default name for wandb logger
    wandb = dict(project="passt_fsd50k", log_model=True)
    basedataset = DynamicIngredient("fsd50k.dataset.dataset", wavmix=1)
    trainer = dict(max_epochs=50, gpus=1, weights_summary='full', benchmark=True, num_sanity_val_steps=0,
                   reload_dataloaders_every_epoch=True)
    lr = 0.00001  # learning rate
    use_mixup = True
    mixup_alpha = 0.3


# register extra possible configs
add_configs(ex)


@ex.command
def get_scheduler_lambda(warm_up_len=5, ramp_down_start=10, ramp_down_len=10, last_lr_value=0.01,
                         schedule_mode="exp_lin"):
    if schedule_mode == "exp_lin":
        return exp_warmup_linear_down(warm_up_len, ramp_down_len, ramp_down_start, last_lr_value)
    if schedule_mode == "cos_cyc":
        return cosine_cycle(warm_up_len, ramp_down_start, last_lr_value)
    raise RuntimeError(f"schedule_mode={schedule_mode} Unknown for a lambda function.")


@ex.command
def get_lr_scheduler(optimizer, schedule_mode):
    if schedule_mode in {"exp_lin", "cos_cyc"}:
        return torch.optim.lr_scheduler.LambdaLR(optimizer, get_scheduler_lambda())
    raise RuntimeError(f"schedule_mode={schedule_mode} Unknown.")


@ex.command
def get_optimizer(params, lr, adamw=True, weight_decay=0.0001):
    if adamw:
        print(f"\nUsing adamw weight_decay={weight_decay}!\n")
        return torch.optim.AdamW(params, lr=lr, weight_decay=weight_decay)
    return torch.optim.Adam(params, lr=lr)


class M(Ba3lModule):
    def __init__(self, experiment):
        self.mel = None
        self.da_net = None
        super(M, self).__init__(experiment)

        self.use_mixup = self.config.use_mixup or False
        self.mixup_alpha = self.config.mixup_alpha

        desc, sum_params, sum_non_zero = count_non_zero_params(self.net)
        self.experiment.info["start_sum_params"] = sum_params
        self.experiment.info["start_sum_params_non_zero"] = sum_non_zero

        # in case we need embeddings for the DA
        self.net.return_embed = True
        self.dyn_norm = self.config.dyn_norm
        self.do_swa = False
        self.valid_names = ["valid", "eval"]
        self.distributed_mode = self.config.trainer.num_nodes > 1

    def forward(self, x):
        return self.net(x)

    def mel_forward(self, x):
        old_shape = x.size()
        x = x.reshape(-1, old_shape[2])
        x = self.mel(x)
        x = x.reshape(old_shape[0], old_shape[1], x.shape[1], x.shape[2])
        if self.dyn_norm:
            if not hasattr(self, "tr_m") or not hasattr(self, "tr_std"):
                tr_m, tr_std = get_dynamic_norm(self)
                self.register_buffer('tr_m', tr_m)
                self.register_buffer('tr_std', tr_std)
            x = (x - self.tr_m) / self.tr_std
        return x

    def training_step(self, batch, batch_idx):
        # REQUIRED
        x, f, y = batch
        if self.mel:
            x = self.mel_forward(x)

        orig_x = x
        batch_size = len(y)

        rn_indices, lam = None, None
        if self.use_mixup:
            rn_indices, lam = my_mixup(batch_size, self.mixup_alpha)
            lam = lam.to(x.device)
            x = x * lam.reshape(batch_size, 1, 1, 1) + x[rn_indices] * (1. - lam.reshape(batch_size, 1, 1, 1))

        y_hat, embed = self.forward(x)

        if self.use_mixup:
            y_mix = y * lam.reshape(batch_size, 1) + y[rn_indices] * (1. - lam.reshape(batch_size, 1))
            samples_loss = F.binary_cross_entropy_with_logits(
                y_hat, y_mix, reduction="none")
            loss = samples_loss.mean()
            samples_loss = samples_loss.detach()
        else:
            samples_loss = F.binary_cross_entropy_with_logits(y_hat, y, reduction="none")
            loss = samples_loss.mean()
            samples_loss = samples_loss.detach()

        results = {"loss": loss, }

        return results

    def training_epoch_end(self, outputs):
        avg_loss = torch.stack([x['loss'] for x in outputs]).mean()

        logs = {'train.loss': avg_loss, 'step': self.current_epoch}

        self.log_dict(logs, sync_dist=True)

    def predict(self, batch, batch_idx: int, dataloader_idx: int = None):
        x, f, y = batch
        if self.mel:
            x = self.mel_forward(x)

        y_hat, _ = self.forward(x)
        return f, y_hat

    def validation_step(self, batch, batch_idx, dataloader_idx):
        x, f, y = batch
        if self.mel:
            x = self.mel_forward(x)

        results = {}
        model_name = [("", self.net)]
        if self.do_swa:
            model_name = model_name + [("swa_", self.net_swa)]
        for net_name, net in model_name:
            y_hat, _ = net(x)
            samples_loss = F.binary_cross_entropy_with_logits(y_hat, y)
            loss = samples_loss.mean()
            out = torch.sigmoid(y_hat.detach())
            # self.log("validation.loss", loss, prog_bar=True, on_epoch=True, on_step=False)
            results = {**results, net_name + "val_loss": loss, net_name + "out": out, net_name + "target": y.detach()}
        results = {k: v.cpu() for k, v in results.items()}
        return results

    def validation_epoch_end(self, outputs):
        for idx, one_outputs in enumerate(outputs):
            set_name = self.valid_names[idx] + "_"
            model_name = [("", self.net)]
            if self.do_swa:
                model_name = model_name + [("swa_", self.net_swa)]
            for net_name, net in model_name:
                avg_loss = torch.stack([x[net_name + 'val_loss'] for x in one_outputs]).mean()
                out = torch.cat([x[net_name + 'out'] for x in one_outputs], dim=0)
                target = torch.cat([x[net_name + 'target'] for x in one_outputs], dim=0)
                try:
                    average_precision = metrics.average_precision_score(
                        target.float().numpy(), out.float().numpy(), average=None)
                except ValueError:
                    average_precision = np.array([np.nan] * 200)
                try:
                    roc = metrics.roc_auc_score(target.numpy(), out.numpy(), average=None)
                except ValueError:
                    roc = np.array([np.nan] * 200)
                logs = {set_name + net_name + 'val.loss': torch.as_tensor(avg_loss).cuda(),
                        set_name + net_name + 'ap': torch.as_tensor(average_precision.mean()).cuda(),
                        set_name + net_name + 'roc': torch.as_tensor(roc.mean()).cuda(),
                        'step': torch.as_tensor(self.current_epoch).cuda()}
                # torch.save(average_precision, f"ap_perclass_{average_precision.mean()}.pt")
                # print(average_precision)
                self.log_dict(logs, sync_dist=True)
                if self.distributed_mode:
                    allout = self.all_gather(out)
                    alltarget = self.all_gather(target)

                    average_precision = metrics.average_precision_score(
                        alltarget.reshape(-1, alltarget.shape[-1]).cpu().numpy(),
                        allout.reshape(-1, allout.shape[-1]).cpu().numpy(), average=None)
                    if self.trainer.is_global_zero:
                        logs = {set_name + net_name + "allap": torch.as_tensor(average_precision.mean()).cuda(),
                                'step': torch.as_tensor(self.current_epoch).cuda()}
                        self.log_dict(logs, sync_dist=False)
                else:
                    self.log_dict(
                        {set_name + net_name + "allap": logs[set_name + net_name + 'ap'], 'step': logs['step']},
                        sync_dist=True)

    def configure_optimizers(self):
        # REQUIRED
        # can return multiple optimizers and learning-rate schedulers
        # (LBFGS is supported automatically; no closure function is needed)
        optimizer = get_optimizer(self.parameters())
        # torch.optim.Adam(self.parameters(), lr=self.config.lr)
        return {
            'optimizer': optimizer,
            'lr_scheduler': get_lr_scheduler(optimizer)
        }

    def configure_callbacks(self):
        return get_extra_checkpoint_callback() + get_extra_swa_callback()
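

# Illustrative sketch (pure Python, hypothetical helper names) of the metric logged as
# "ap" in validation_epoch_end. sklearn's average_precision_score(average=None) computes,
# per class, AP = sum_n (R_n - R_{n-1}) * P_n over decreasing score thresholds, and the
# logged value is the mean over classes. This sketch assumes distinct scores (sklearn
# groups tied scores into one threshold) and returns NaN for classes without positives.
def sketch_average_precision(labels, scores):
    order = sorted(range(len(scores)), key=lambda i: scores[i], reverse=True)
    n_pos = sum(labels)
    if n_pos == 0:
        return float("nan")  # AP is undefined without positive samples
    hits, ap = 0, 0.0
    for rank, i in enumerate(order, start=1):
        if labels[i]:
            hits += 1
            # recall grows by 1 / n_pos at each positive; weight it by precision at this rank
            ap += (hits / rank) / n_pos
    return ap


def sketch_macro_average_precision(targets, outputs):
    # targets/outputs: lists of rows (samples), each row holding one value per class
    n_classes = len(targets[0])
    per_class = [sketch_average_precision([row[c] for row in targets], [row[c] for row in outputs])
                 for c in range(n_classes)]
    return sum(per_class) / n_classes, per_class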


@ex.command
def get_dynamic_norm(model, dyn_norm=False):
    if not dyn_norm:
        return None, None
    raise RuntimeError('no dynamic norm supported yet.')


model_checkpoint_callback = None


@ex.command
def get_extra_checkpoint_callback(save_best=None):
    if save_best is None:
        return []
    global model_checkpoint_callback
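    # validation_epoch_end logs this metric as "<set>_allap" (e.g. "valid_allap"), so monitor the validation set
    model_checkpoint_callback = ModelCheckpoint(monitor="valid_allap", verbose=True, save_top_k=save_best, mode='max',
                                                every_n_val_epochs=1, every_n_train_steps=0)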
    return [model_checkpoint_callback]


@ex.command
def get_extra_swa_callback(swa=True, swa_epoch_start=10,
                           swa_freq=3):
    if not swa:
        return []
    print("\n Using swa!\n")
    from helpers.swa_callback import StochasticWeightAveraging
    return [StochasticWeightAveraging(swa_epoch_start=swa_epoch_start, swa_freq=swa_freq)]


@ex.command
def main(_run, _config, _log, _rnd, _seed):
    trainer = get_trainer(logger=get_logger())
    train_loader = get_train_loader()
    val_loader = get_validate_loader()
    eval_loader = get_eval_loader()

    modul = M(ex)

    trainer.fit(
        modul,
        train_dataloader=train_loader,
        val_dataloaders=[val_loader, eval_loader],
    )
    # note: the eval set is already scored every epoch as the second validation loader
    #trainer.val_dataloaders = None
    modul.val_dataloaders = None

    return {"done": True}


@ex.command
def model_speed_test(_run, _config, _log, _rnd, _seed, speed_test_batch_size=100):
    '''
    Test training speed of a model
    @param _run:
    @param _config:
    @param _log:
    @param _rnd:
    @param _seed:
    @param speed_test_batch_size: the batch size during the test
    @return:
    '''

    modul = M(ex)
    modul = modul.cuda()
    batch_size = speed_test_batch_size
    print(f"\nBATCH SIZE : {batch_size}\n")
    test_length = 100
    print(f"\ntest_length : {test_length}\n")

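    x = torch.ones([batch_size, 1, 128, 998]).cuda()
    net = modul.net
    # probe the output size with one sample, so the dummy target matches the configured
    # number of classes (a hard-coded 527 AudioSet classes would not match n_classes=200)
    with torch.no_grad():
        y_probe, _ = net(x[:1])
    target = torch.ones([batch_size, y_probe.shape[-1]]).cuda()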
    scaler = torch.cuda.amp.GradScaler()
    torch.backends.cudnn.benchmark = True
    # net = torch.jit.trace(net,(x,))
    optimizer = torch.optim.SGD(net.parameters(), lr=0.001)

    print("warmup")
    import time
    torch.cuda.synchronize()
    t1 = time.time()
    for i in range(10):
        with torch.cuda.amp.autocast():
            y_hat, embed = net(x)
            loss = F.binary_cross_entropy_with_logits(y_hat, target, reduction="none").mean()
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()
    torch.cuda.synchronize()
    t2 = time.time()
    print('warmup done:', (t2 - t1))
    torch.cuda.synchronize()
    t1 = time.time()
    print("testing speed")

    for i in range(test_length):
        with torch.cuda.amp.autocast():
            y_hat, embed = net(x)
            loss = F.binary_cross_entropy_with_logits(y_hat, target, reduction="none").mean()
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()
    torch.cuda.synchronize()
    t2 = time.time()
    print('test done:', (t2 - t1))
    print("average speed: ", (test_length * batch_size) / (t2 - t1), " specs/second")


@ex.command
def evaluate_only(_run, _config, _log, _rnd, _seed):
    # force overriding the config, not logged = not recommended
    trainer = get_trainer()
    train_loader = get_train_loader()
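    val_loader = get_validate_loader()
    # M.validation_step/validation_epoch_end expect two loaders (valid, eval), as in main()
    eval_loader = get_eval_loader()
    modul = M(ex)
    modul.val_dataloader = None
    #trainer.val_dataloaders = None
    print(f"\n\nValidation len={len(val_loader)}, eval len={len(eval_loader)}\n")
    res = trainer.validate(modul, dataloaders=[val_loader, eval_loader])
    print("\n\n Validation:")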
    print(res)


@ex.command
def test_loaders():
    '''
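    print one batch from each loader for debugging
    @return:
    '''
    for i, b in enumerate(ex.datasets.training.get_iter()):
        print(b)
        break

    # this experiment defines "valid" and "eval" loaders (there is no "test" loader)
    for i, b in enumerate(ex.datasets.valid.get_iter()):
        print(b)
        break

    for i, b in enumerate(ex.datasets.eval.get_iter()):
        print(b)
        break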


def set_default_json_pickle(obj):
    if isinstance(obj, set):
        return list(obj)
    raise TypeError


@ex.command
def preload_mp3(all_y=CMD("/basedataset.preload_mp3")):
    '''
    read the dataset sequentially, useful if you have a network cache
    @param all_y: the dataset preload command
    @return:
    '''
    print(all_y.shape)


def multiprocessing_run(rank, world_size):
    print("rank ", rank, os.getpid())
    print("world_size ", world_size)
    os.environ['NODE_RANK'] = str(rank)
    # each rank sees only its own GPU; CUDA_VISIBLE_DEVICES must list at least world_size devices
    os.environ['CUDA_VISIBLE_DEVICES'] = os.environ['CUDA_VISIBLE_DEVICES'].split(",")[rank]
    argv = sys.argv
    if rank != 0:
        print(f"Unobserved {os.getpid()} with rank {rank}")
        argv = argv + ["-u"]  # only rank 0 is observed
    if "with" not in argv:
        argv = argv + ["with"]

    argv = argv + [f"trainer.num_nodes={world_size}", "trainer.accelerator=ddp"]
    print(argv)

    @ex.main
    def default_command():
        return main()

    ex.run_commandline(argv)


if __name__ == '__main__':
    # Setting DDP=2 forks two processes to run on two GPUs:
    # the environment variable "DDP" defines the number of processes to fork.
    # With 2x 2080 Ti GPUs you can train the full model to .47 in around 24 hours.
    # You may need to set NCCL_P2P_DISABLE=1.
    world_size = os.environ.get("DDP", None)
    if world_size:
        import random

        world_size = int(world_size)
        print(f"\n\nDDP TRAINING WITH WORLD_SIZE={world_size}\n\n")
        os.environ['MASTER_ADDR'] = '127.0.0.1'
        os.environ['MASTER_PORT'] = f"{9999 + random.randint(0, 9999)}"  # random port to make collisions unlikely
        os.environ['PL_IN_DDP_SUBPROCESS'] = '1'

        for rank in range(world_size):
            pid = os.fork()
            if pid == 0:
                print("Child Forked ")
                multiprocessing_run(rank, world_size)
                exit(0)

        # wait for every forked child, not only the first one to finish
        for _ in range(world_size):
            pid, exit_code = os.wait()
            print(pid, exit_code)
        exit(0)

print("__main__ is running pid", os.getpid(), "in module main: ", __name__)


@ex.automain
def default_command():
    return main()


================================================
FILE: ex_openmic.py
================================================
import os
import sys

import torch
from pytorch_lightning.callbacks import ModelCheckpoint
from sacred.config_helpers import DynamicIngredient, CMD
from torch.nn import functional as F
import numpy as np

from ba3l.experiment import Experiment
from ba3l.module import Ba3lModule

from torch.utils.data import DataLoader

from config_updates import add_configs
from helpers.mixup import my_mixup
from helpers.models_size import count_non_zero_params
from helpers.ramp import exp_warmup_linear_down, cosine_cycle
from helpers.workersinit import worker_init_fn
from sklearn import metrics
from pytorch_lightning import Trainer as plTrainer
from pytorch_lightning.loggers import WandbLogger


ex = Experiment("openmic")

# capture the config of the trainer with the prefix "trainer", this allows to use sacred to update PL trainer config
get_trainer = ex.command(plTrainer, prefix="trainer")
# capture the WandbLogger and prefix it with "wandb", this allows to use sacred to update WandbLogger config from the command line
get_logger = ex.command(WandbLogger, prefix="wandb")


# Example call with all the default config:
# python ex_openmic.py with trainer.precision=16 -p -m mongodb_server:27000:audioset21_balanced -c "OpenMIC PaSST base"
# with 2 gpus:
# DDP=2 python ex_openmic.py with trainer.precision=16 -p -m mongodb_server:27000:audioset21_balanced -c "OpenMIC PaSST base"

# define datasets and loaders
get_train_loader = ex.datasets.training.iter(DataLoader, static_args=dict(worker_init_fn=worker_init_fn),
                                             train=True, batch_size=6, num_workers=16, shuffle=None,
                                             dataset=CMD("/basedataset.get_training_set"))

get_validate_loader = ex.datasets.test.iter(DataLoader, static_args=dict(worker_init_fn=worker_init_fn),
                                            validate=True, batch_size=20, num_workers=16,
                                            dataset=CMD("/basedataset.get_test_set"))


@ex.config
def default_conf():
    cmd = " ".join(sys.argv)
    saque_cmd = os.environ.get("SAQUE_CMD", "").strip()
    saque_id = os.environ.get("SAQUE_ID", "").strip()
    slurm_job_id = os.environ.get("SLURM_JOB_ID", "").strip()
    if os.environ.get("SLURM_ARRAY_JOB_ID", False):
        slurm_job_id = os.environ.get("SLURM_ARRAY_JOB_ID", "").strip() + "_" + os.environ.get("SLURM_ARRAY_TASK_ID",
                                                                                               "").strip()
    process_id = os.getpid()
    models = {
        "net": DynamicIngredient("models.passt.model_ing", n_classes=20, s_patchout_t=40, s_patchout_f=4),
        "mel": DynamicIngredient("models.preprocess.model_ing",
                                 instance_cmd="AugmentMelSTFT",
                                 n_mels=128, sr=32000, win_length=800, hopsize=320, n_fft=1024, freqm=48,
                                 timem=192,
                                 htk=False, fmin=0.0, fmax=None, norm=1, fmin_aug_range=10,
                                 fmax_aug_range=2000)
    }
    wandb = dict(project="passt_openmic", log_model=True)
    basedataset = DynamicIngredient("openmic.dataset.dataset", wavmix=1)
    # set the default for the trainer
    trainer = dict(max_epochs=10, gpus=1, weights_summary='full', benchmark=True, num_sanity_val_steps=0,
                   reload_dataloaders_every_epoch=True)
    lr = 0.00001
    use_mixup = True
    mixup_alpha = 0.3

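# Sketch (hypothetical helper, not part of the original PaSST code): structured Patchout.
# The config above passes s_patchout_t=40 and s_patchout_f=4 to the PaSST model.
# Assuming the patch embeddings form a grid of n_freq x n_time positions, the idea is to
# drop s_f whole frequency rows and s_t whole time columns at random during training and
# keep the remaining (n_freq - s_f) * (n_time - s_t) patches. models/passt.py operates on
# tensors and may differ in detail; _sketch_structured_patchout is an illustrative name.
import random


def _sketch_structured_patchout(n_freq, n_time, s_f, s_t, rng=None):
    """Return sorted lists of the frequency-row and time-column indices that are kept."""
    rng = rng or random.Random()
    # choose which rows / columns to drop, without replacement
    dropped_f = set(rng.sample(range(n_freq), s_f))
    dropped_t = set(rng.sample(range(n_time), s_t))
    kept_f = [i for i in range(n_freq) if i not in dropped_f]
    kept_t = [j for j in range(n_time) if j not in dropped_t]
    return kept_f, kept_t
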

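# Sketch (hypothetical helper): what the "mel" STFT settings above mean in time.
# With sr=32000, win_length=800 and hopsize=320, each analysis window spans
# 800 / 32000 s = 25 ms and consecutive frames are 320 / 32000 s = 10 ms apart,
# i.e. roughly 100 frames per second. The frame count below assumes a centred STFT
# (1 + n_samples // hop); whether models/preprocess.py pads or trims the waveform first
# is not shown here, so treat the count as approximate. The helper name is illustrative.


def _sketch_stft_timing(sr=32000, win_length=800, hopsize=320, n_samples=None):
    """Return (window_ms, hop_ms, n_frames or None) for the given STFT settings."""
    window_ms = 1000.0 * win_length / sr
    hop_ms = 1000.0 * hopsize / sr
    n_frames = None if n_samples is None else 1 + n_samples // hopsize
    return window_ms, hop_ms, n_frames
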

# register extra possible configs
add_configs(ex)


@ex.command
def get_scheduler_lambda(warm_up_len=5, ramp_down_start=50, ramp_down_len=50, last_lr_value=0.01,
                         schedule_mode="exp_lin"):
    if schedule_mode == "exp_lin":
        return exp_warmup_linear_down(warm_up_len, ramp_down_len, ramp_down_start, last_lr_value)
    if schedule_mode == "cos_cyc":
        return cosine_cycle(warm_up_len, ramp_down_start, last_lr_value)
    raise RuntimeError(f"schedule_mode={schedule_mode} Unknown for a lambda function.")

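# Sketch (hypothetical, not helpers/ramp.py): one possible shape for the "exp_lin"
# learning-rate multiplier returned by get_scheduler_lambda. It assumes an exponential
# (Gaussian-shaped) warm-up over warm_up_len epochs, multiplied by a linear ramp-down that
# starts at ramp_down_start, lasts ramp_down_len epochs and ends at last_value. The
# argument order mirrors the exp_warmup_linear_down call above, but the repo's own helper
# is not shown and may differ. LambdaLR multiplies the base lr by this value each epoch.
import math


def _sketch_exp_warmup_linear_down(warm_up_len, ramp_down_len, ramp_down_start, last_value):
    def warmup(epoch):
        if epoch >= warm_up_len:
            return 1.0
        # exp(-5 * phase^2) rises smoothly towards 1 as phase goes to 0
        phase = 1.0 - max(epoch, 0.5) / warm_up_len
        return math.exp(-5.0 * phase * phase)

    def rampdown(epoch):
        if epoch <= ramp_down_start:
            return 1.0
        if epoch - ramp_down_start < ramp_down_len:
            frac = (ramp_down_len - (epoch - ramp_down_start)) / ramp_down_len
            return last_value + (1.0 - last_value) * frac
        return last_value

    return lambda epoch: warmup(epoch) * rampdown(epoch)
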

@ex.command
def get_lr_scheduler(optimizer, schedule_mode):
    if schedule_mode in {"exp_lin", "cos_cyc"}:
        return torch.optim.lr_scheduler.LambdaLR(optimizer, get_scheduler_lambda())
    raise RuntimeError(f"schedule_mode={schedule_mode} Unknown.")


@ex.command
def get_optimizer(params, lr, adamw=True, weight_decay=0.0001):
    if adamw:
        print(f"\nUsing adamw weight_decay={weight_decay}!\n")
        return torch.optim.AdamW(params, lr=lr, weight_decay=weight_decay)
    return torch.optim.Adam(params, lr=lr)

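# Sketch (hypothetical helper): why get_optimizer distinguishes AdamW from Adam.
# AdamW shrinks the parameter directly by lr * weight_decay ("decoupled" weight decay),
# separately from the adaptive gradient step, whereas Adam with weight_decay would add
# weight_decay * p to the gradient instead. This single-scalar step follows the update rule
# documented for torch.optim.AdamW (without amsgrad); _sketch_adamw_step is illustrative.
import math


def _sketch_adamw_step(p, grad, m, v, t, lr=1e-5, weight_decay=1e-4,
                       betas=(0.9, 0.999), eps=1e-8):
    """Return updated (p, m, v) after AdamW step number t (t starts at 1)."""
    b1, b2 = betas
    p = p * (1.0 - lr * weight_decay)          # decoupled weight decay
    m = b1 * m + (1.0 - b1) * grad             # first-moment estimate
    v = b2 * v + (1.0 - b2) * grad * grad      # second-moment estimate
    m_hat = m / (1.0 - b1 ** t)                # bias corrections
    v_hat = v / (1.0 - b2 ** t)
    p = p - lr * m_hat / (math.sqrt(v_hat) + eps)
    return p, m, v
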

class M(Ba3lModule):
    def __init__(self, experiment):
        self.mel = None
        self.da_net = None
        super(M, self).__init__(experiment)

        self.use_mixup = self.config.use_mixup or False
        self.mixup_alpha = self.config.mixup_alpha

        desc, sum_params, sum_non_zero = count_non_zero_params(self.net)
        self.experiment.info["start_sum_params"] = sum_params
        self.experiment.info["start_sum_params_non_zero"] = sum_non_zero

        # in case we need embeddings for the DA
        self.net.return_embed = True
        self.dyn_norm = self.config.dyn_norm
        self.do_swa = False

        self.distributed_mode = self.config.trainer.num_nodes > 1

    def forward(self, x):
        return self.net(x)

    def mel_forward(self, x):
        old_shape = x.size()
        x = x.reshape(-1, old_shape[2])
        x = self.mel(x)
        x = x.reshape(old_shape[0], old_shape[1], x.shape[1], x.shape[2])
        if self.dyn_norm:
            if not hasattr(self, "tr_m") or not hasattr(self, "tr_std"):
                tr_m, tr_std = get_dynamic_norm(self)
                self.register_buffer('tr_m', tr_m)
                self.register_buffer('tr_std', tr_std)
            x = (x - self.tr_m) / self.tr_std
        return x

    def training_step(self, batch, batch_idx):
        # REQUIRED
        x, f, y = batch
        if self.mel:
            x = self.mel_forward(x)

        y_mask = y[:, 20:]
        y = y[:, :20] > 0.5
        y = y.float()

        orig_x = x
        batch_size = len(y)

        rn_indices, lam = None, None
        if self.use_mixup:
            rn_indices, lam = my_mixup(batch_size, self.mixup_alpha)
            lam = lam.to(x.device)
            x = x * lam.reshape(batch_size, 1, 1, 1) + x[rn_indices] * (1. - lam.reshape(batch_size, 1, 1, 1))

        y_hat, embed = self.forward(x)

        if self.use_mixup:
            y_mix = y * lam.reshape(batch_size, 1) + y[rn_indices] * (1. - lam.reshape(batch_size, 1))
            samples_loss = F.binary_cross_entropy_with_logits(
                y_hat, y_mix, reduction="none")
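            # y_mix_mask marks classes labelled in either mixed sample; note that the loss
            # below is masked with y_mask, i.e. only the classes labelled for the original clip
            y_mix_mask = ((y_mask > 0.5) | (y_mask[rn_indices] > 0.5)).float()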
            samples_loss = y_mask.float() * samples_loss
            loss = samples_loss.mean()
            samples_loss = samples_loss.detach()
        else:
            samples_loss = F.binary_cross_entropy_with_logits(y_hat, y, reduction="none")
            samples_loss = y_mask.float() * samples_loss
            loss = samples_loss.mean()
            samples_loss = samples_loss.detach()

        results = {"loss": loss, }

        return results

    def training_epoch_end(self, outputs):
        avg_loss = torch.stack([x['loss'] for x in outputs]).mean()

        logs = {'train.loss': avg_loss, 'step': self.current_epoch}

        self.log_dict(logs, sync_dist=True)

    def predict(self, batch, batch_idx: int, dataloader_idx: int = None):
        x, f, y = batch
        if self.mel:
            x = self.mel_forward(x)

        y_hat, _ = self.forward(x)
        return f, y_hat

    def validation_step(self, batch, batch_idx):
        x, f, y = batch
        if self.mel:
            x = self.mel_forward(x)
        y_mask = y[:, 20:]
        y = y[:, :20] > 0.5
        y = y.float()
        results = {}
        model_name = [("", self.net)]
        if self.do_swa:
            model_name = model_name + [("swa_", self.net_swa)]
        for net_name, net in model_name:
            y_hat, _ = net(x)
            # per-element loss, so the label mask can be applied before averaging
            samples_loss = F.binary_cross_entropy_with_logits(y_hat, y, reduction="none")
            samples_loss = y_mask.float() * samples_loss
            loss = samples_loss.mean()
            out = torch.sigmoid(y_hat.detach())
            # self.log("validation.loss", loss, prog_bar=True, on_epoch=True, on_step=False)
            results = {**results, net_name + "val_loss": loss, net_name + "out": out, net_name + "target": y.detach(),
                       net_name + "mask": y_mask.detach()}
        results = {k: v.cpu() for k, v in results.items()}
        return results

    def validation_epoch_end(self, outputs):
        model_name = [("", self.net)]
        if self.do_swa:
            model_name = model_name + [("swa_", self.net_swa)]
        for net_name, net in model_name:
            avg_loss = torch.stack([x[net_name + 'val_loss'] for x in outputs]).mean()
            out = torch.cat([x[net_name + 'out'] for x in outputs], dim=0)
            target = torch.cat([x[net_name + 'target'] for x in outputs], dim=0)
            mask = torch.cat([x[net_name + 'mask'] for x in outputs], dim=0)
            try:
                y_true = target.float().numpy()
                y_pred = out.float().numpy()
                y_mask = mask.float().numpy()
                average_precision = np.array([metrics.average_precision_score(
                    y_true[:, i], y_pred[:, i], sample_weight=y_mask[:, i]) for i in range(y_true.shape[1])])
            except ValueError:
                average_precision = np.array([np.nan] * y_true.shape[1])
            #torch.save(average_precision, f"ap_openmic_perclass_{average_precision.mean()}.pt")
            try:
                roc = np.array([metrics.roc_auc_score(
                    y_true[:, i], y_pred[:, i], sample_weight=y_mask[:, i]) for i in range(y_true.shape[1])])
            except ValueError:
                roc = np.array([np.nan] * y_true.shape[1])
            logs = {net_name + 'val.loss': torch.as_tensor(avg_loss, device=self.device),
                    net_name + 'ap': torch.as_tensor(average_precision.mean(), device=self.device),
                    net_name + 'roc': torch.as_tensor(roc.mean(), device=self.device),
                    'step': torch.as_tensor(self.current_epoch, device=self.device)}
            self.log_dict(logs, sync_dist=True)
            if self.distributed_mode:
                allout = self.all_gather(out)
                alltarget = self.all_gather(target)
                all_mask = self.all_gather(mask)
                y_true = alltarget.reshape(-1, alltarget.shape[-1]).cpu().float().numpy()
                y_pred = allout.reshape(-1, alltarget.shape[-1]).cpu().float().numpy()
                y_mask = all_mask.reshape(-1, alltarget.shape[-1]).cpu().float().numpy()
                average_precision = np.array([metrics.average_precision_score(
                    y_true[:, i], y_pred[:, i], sample_weight=y_mask[:, i]) for i in range(y_true.shape[1])])
                if self.trainer.is_global_zero:
                    logs = {net_name + "allap": torch.as_tensor(average_precision.mean(), device=self.device),
                            'step': torch.as_tensor(self.current_epoch, device=self.device)}
                    self.log_dict(logs, sync_dist=False)
            else:
                self.log_dict({net_name + "allap": logs[net_name + 'ap'], 'step': logs['step']}, sync_dist=True)

    def configure_optimizers(self):
        # REQUIRED
        # can return multiple optimizers and learning_rate schedulers
        # (LBFGS is supported automatically; no closure function is needed)
        optimizer = get_optimizer(self.parameters())
        # torch.optim.Adam(self.parameters(), lr=self.config.lr)
        return {
            'optimizer': optimizer,
            'lr_scheduler': get_lr_scheduler(optimizer)
        }

    def configure_callbacks(self):
        return get_extra_checkpoint_callback() + get_extra_swa_callback()
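

# Sketch (hypothetical helper): the masked binary cross-entropy used in training_step and
# validation_step. OpenMIC clips are labelled for only some of the 20 instruments, so the
# first 20 columns of y hold the labels and the remaining columns a 0/1 mask. Per-element
# losses are multiplied by the mask so unlabelled entries contribute zero, and the mean is
# then taken over all entries, as samples_loss.mean() does above. Pure Python and the name
# _sketch_masked_bce_with_logits are for illustration only.
import math


def _sketch_masked_bce_with_logits(logits, targets, mask):
    """Mean of mask * BCE(sigmoid(logit), target) over all elements of 2-D lists."""
    total, count = 0.0, 0
    for row_z, row_y, row_m in zip(logits, targets, mask):
        for z, y, m in zip(row_z, row_y, row_m):
            # numerically stable form of -[y*log(sigmoid(z)) + (1-y)*log(1-sigmoid(z))]
            loss = max(z, 0.0) - z * y + math.log1p(math.exp(-abs(z)))
            total += m * loss
            count += 1
    return total / count
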

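# Sketch (hypothetical, not helpers/mixup.py): the mixup step in training_step.
# It assumes my_mixup returns a random permutation of the batch indices and one mixing
# weight lam per sample drawn from Beta(alpha, alpha); each input and its target are then
# blended with the permuted partner as lam * a + (1 - lam) * b. The repo's my_mixup may
# post-process lam differently; _sketch_mixup works on scalar inputs for brevity.
import random


def _sketch_mixup(xs, ys, alpha=0.3, rng=None):
    """Mix scalar inputs xs and label vectors ys; return (xs_mix, ys_mix, lams, perm)."""
    rng = rng or random.Random()
    n = len(xs)
    perm = list(range(n))
    rng.shuffle(perm)
    lams = [rng.betavariate(alpha, alpha) for _ in range(n)]
    xs_mix = [lam * xs[i] + (1.0 - lam) * xs[perm[i]] for i, lam in enumerate(lams)]
    ys_mix = [[lam * a + (1.0 - lam) * b for a, b in zip(ys[i], ys[perm[i]])]
              for i, lam in enumerate(lams)]
    return xs_mix, ys_mix, lams, perm
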

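# Sketch (hypothetical helper): what sample_weight=y_mask does in the
# average_precision_score calls in validation_epoch_end. With a 0/1 mask, a clip whose
# label for a class is unknown should simply drop out of that class's AP. Here AP is
# computed as the mean precision at the rank of each positive, assuming distinct scores
# (scikit-learn additionally groups tied scores). The helper name is illustrative.


def _sketch_masked_average_precision(y_true, y_score, mask):
    """AP over the entries whose mask is non-zero; 0/1 labels and mask assumed."""
    kept = [(s, y) for y, s, m in zip(y_true, y_score, mask) if m > 0.5]
    kept.sort(key=lambda pair: pair[0], reverse=True)  # highest score first
    hits, precisions = 0, []
    for rank, (_, y) in enumerate(kept, start=1):
        if y > 0.5:
            hits += 1
            precisions.append(hits / rank)
    # AP is undefined without positives; return nan in that case
    return sum(precisions) / len(precisions) if precisions else float("nan")

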
@ex.command
def get_dynamic_norm(model, dyn_norm=False):
    if not dyn
SYMBOL INDEX (461 symbols across 34 files)

FILE: audioset/dataset.py
  function default_config (line 26) | def default_config():
  function LMODE_default_config (line 48) | def LMODE_default_config():
  function decode_mp3 (line 55) | def decode_mp3(mp3_arr):
  function pad_or_truncate (line 73) | def pad_or_truncate(x, audio_length):
  function get_ir_sample (line 85) | def get_ir_sample(ir_path, _run, ir_augment, cut_irs_offset=None):
  function pydub_augment (line 104) | def pydub_augment(waveform, gain_augment=7, ir_augment=0):
  class MixupDataset (line 115) | class MixupDataset(TorchDataset):
    method __init__ (line 119) | def __init__(self, dataset, beta=2, rate=0.5):
    method __getitem__ (line 125) | def __getitem__(self, index):
    method __len__ (line 139) | def __len__(self):
  class AudioSetDataset (line 143) | class AudioSetDataset(TorchDataset):
    method __init__ (line 144) | def __init__(self, hdf5_file, sample_rate=32000, classes_num=527, clip...
    method open_hdf5 (line 164) | def open_hdf5(self):
    method __len__ (line 167) | def __len__(self):
    method __del__ (line 170) | def __del__(self):
    method __getitem__ (line 175) | def __getitem__(self, index):
    method resample (line 202) | def resample(self, waveform):
  function get_base_training_set (line 220) | def get_base_training_set(balanced_train_hdf5):
  function get_unbalanced_training_set (line 226) | def get_unbalanced_training_set(unbalanced_train_hdf5):
  function get_norms_dataset (line 233) | def get_norms_dataset(unbalanced_train_hdf5, balanced_train_hdf5):
  function get_base_full_training_set (line 240) | def get_base_full_training_set():
  function preload_mp3 (line 247) | def preload_mp3(balanced_train_hdf5, unbalanced_train_hdf5, num_of_class...
  function get_ft_cls_balanced_sample_weights (line 258) | def get_ft_cls_balanced_sample_weights(balanced_train_hdf5, unbalanced_t...
  function get_ft_weighted_sampler (line 294) | def get_ft_weighted_sampler(samples_weights=CMD(".get_ft_cls_balanced_sa...
  function get_base_test_set (line 310) | def get_base_test_set(eval_hdf5):
  function get_roll_func (line 316) | def get_roll_func(axis=1, shift=None, shift_range=50):
  function get_training_set (line 333) | def get_training_set(normalize, roll, wavmix=False):
  function get_full_training_set (line 349) | def get_full_training_set(normalize, roll, wavmix=False):
  function get_test_set (line 365) | def get_test_set(normalize):
  function print_conf (line 375) | def print_conf(_config):
  class DistributedSamplerWrapper (line 381) | class DistributedSamplerWrapper(DistributedSampler):
    method __init__ (line 382) | def __init__(
    method __iter__ (line 392) | def __iter__(self):
  function default_command (line 410) | def default_command():

FILE: audioset/prepare_scripts/convert_to_mp3.py
  function process_folder (line 15) | def process_folder(fol="balanced_train_segments"):
  function process_one (line 28) | def process_one(i, f1, f2):

FILE: audioset/prepare_scripts/create_h5pymp3_dataset.py
  function read_metadata (line 18) | def read_metadata(csv_path, classes_num, id_to_ix):
  function check_available (line 75) | def check_available(balanced_csv,balanced_audio_path,prefix=None):

FILE: ba3l/experiment.py
  function ingredients_recursive_apply (line 19) | def ingredients_recursive_apply(ing, fn):
  function config_recursive_apply (line 24) | def config_recursive_apply(conf, fn):
  function get_loggers (line 32) | def get_loggers(use_tensorboard_logger=False, use_sacred_logger=False):
  class Experiment (line 43) | class Experiment(Sacred_Experiment):
    method __init__ (line 48) | def __init__(
    method get_run_identifier (line 116) | def get_run_identifier(self):
    method get_dataloaders (line 121) | def get_dataloaders(self, filter={}):
    method get_train_dataloaders (line 130) | def get_train_dataloaders(self):
    method get_val_dataloaders (line 133) | def get_val_dataloaders(self):
    method _create_run (line 136) | def _create_run(
    method command (line 169) | def command(
    method add_default_args_config (line 205) | def add_default_args_config(self, function, prefix, extra_args={}, sta...

FILE: ba3l/ingredients/datasets.py
  function raise_ (line 14) | def raise_(ex):
  class Dataset (line 18) | class Dataset(Ingredient):
    method __init__ (line 29) | def __init__(
    method dataset (line 99) | def dataset(
    method iter (line 135) | def iter(
    method __getattr__ (line 174) | def __getattr__(self, k):
  class Datasets (line 183) | class Datasets(Ingredient, Munch):
    method get_instance (line 193) | def get_instance(cls):
    method __init__ (line 198) | def __init__(
    method __getattr__ (line 261) | def __getattr__(self, k):
    method __setattr__ (line 268) | def __setattr__(self, k, v):
    method __getitem__ (line 280) | def __getitem__(self, k):
    method __hash__ (line 294) | def __hash__(self):
    method get_datasets (line 297) | def get_datasets(self, config_conditions={}, return_datasets_names=Fal...

FILE: ba3l/ingredients/ingredient.py
  function raise_ (line 14) | def raise_(ex):
  class Ingredient (line 18) | class Ingredient(sacred_Ingredient):
    method __init__ (line 26) | def __init__(
    method add_default_args_config (line 84) | def add_default_args_config(self, function, prefix, extra_args={}, sta...
    method command (line 102) | def command(

FILE: ba3l/ingredients/models.py
  class Models (line 24) | class Models(Ingredient):
    method get_instance (line 34) | def get_instance(cls):
    method __init__ (line 39) | def __init__(
  function raise_ (line 98) | def raise_(ex):
  class Model (line 102) | class Model(Ingredient):
    method __init__ (line 112) | def __init__(
    method instance (line 176) | def instance(
    method __getattr__ (line 210) | def __getattr__(self, k):

FILE: ba3l/ingredients/trainer.py
  class Trainer (line 14) | class Trainer(Ingredient):

FILE: ba3l/module.py
  class Ba3lModule (line 33) | class Ba3lModule(pl.LightningModule):
    method __init__ (line 34) | def __init__(self, experiment):

FILE: ba3l/plutils/lr_monitor.py
  class LearningRateMonitor (line 30) | class LearningRateMonitor(Callback):
    method __init__ (line 70) | def __init__(self, logging_interval: Optional[str] = None, log_momentu...
    method on_train_start (line 79) | def on_train_start(self, trainer, *args, **kwargs):
    method on_train_batch_start (line 119) | def on_train_batch_start(self, trainer, *args, **kwargs):
    method on_train_epoch_start (line 130) | def on_train_epoch_start(self, trainer, *args, **kwargs):
    method _extract_stats (line 138) | def _extract_stats(self, trainer, interval: str) -> Dict[str, float]:
    method _extract_lr (line 158) | def _extract_lr(self, param_group, name: str) -> Dict[str, float]:
    method _extract_momentum (line 163) | def _extract_momentum(self, param_group, name: str, use_betas: bool) -...
    method _find_names (line 171) | def _find_names(self, lr_schedulers) -> List[str]:
    method _should_log (line 204) | def _should_log(trainer) -> bool:

FILE: ba3l/plutils/progress_bar.py
  class ProgressBar (line 17) | class ProgressBar(PlProgressBar):
    method init_validation_tqdm (line 66) | def init_validation_tqdm(self) -> tqdm:
    method on_epoch_start (line 83) | def on_epoch_start(self, trainer, pl_module):
    method on_train_batch_end (line 97) | def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch...
    method on_validation_start (line 103) | def on_validation_start(self, trainer, pl_module):
    method on_validation_batch_end (line 113) | def on_validation_batch_end(self, trainer, pl_module, outputs, batch, ...
    method on_validation_end (line 119) | def on_validation_end(self, trainer, pl_module):
  function convert_inf (line 127) | def convert_inf(x: Optional[Union[int, float]]) -> Optional[Union[int, f...
  function reset (line 134) | def reset(bar: tqdm, total: Optional[int] = None) -> None:

FILE: ba3l/util/functions.py
  function get_default_kwargs_dict (line 5) | def get_default_kwargs_dict(f):

FILE: ba3l/util/sacred_logger.py
  class SacredLogger (line 23) | class SacredLogger(LightningLoggerBase):
    method __init__ (line 24) | def __init__(self, sacred_experiment):
    method experiment (line 37) | def experiment(self):
    method run_id (line 41) | def run_id(self):
    method log_hyperparams (line 48) | def log_hyperparams(self, params):
    method log_metrics (line 53) | def log_metrics(self, metrics, step=None):
    method name (line 61) | def name(self):
    method version (line 65) | def version(self):
    method save (line 69) | def save(self):
    method finalize (line 76) | def finalize(self, status):

FILE: config_updates.py
  function add_configs (line 4) | def add_configs(ex):

FILE: esc50/dataset.py
  function default_config (line 29) | def default_config():
  function decode_mp3 (line 47) | def decode_mp3(mp3_arr):
  function pad_or_truncate (line 65) | def pad_or_truncate(x, audio_length):
  function get_ir_sample (line 77) | def get_ir_sample(ir_path, _run, ir_augment, cut_irs_offset=None):
  function pydub_augment (line 96) | def pydub_augment(waveform, gain_augment=7, ir_augment=0):
  class MixupDataset (line 107) | class MixupDataset(TorchDataset):
    method __init__ (line 111) | def __init__(self, dataset, beta=2, rate=0.5):
    method __getitem__ (line 117) | def __getitem__(self, index):
    method __len__ (line 131) | def __len__(self):
  class AudioSetDataset (line 137) | class AudioSetDataset(TorchDataset):
    method __init__ (line 138) | def __init__(self, meta_csv,  audiopath, fold, train=False, sample_rat...
    method __len__ (line 163) | def __len__(self):
    method __getitem__ (line 166) | def __getitem__(self, index):
    method resample (line 190) | def resample(self, waveform):
  function get_base_training_set (line 209) | def get_base_training_set(meta_csv, audio_path, fold=1):
  function get_ft_weighted_sampler (line 215) | def get_ft_weighted_sampler(samples_weights=CMD(".get_ft_cls_balanced_sa...
  function get_base_test_set (line 231) | def get_base_test_set(meta_csv, audio_path, fold=1):
  function get_roll_func (line 238) | def get_roll_func(axis=1, shift=None, shift_range=50):
  function get_training_set (line 255) | def get_training_set(normalize, roll, wavmix=False):
  function get_test_set (line 271) | def get_test_set(normalize):
  function print_conf (line 281) | def print_conf(_config):
  class DistributedSamplerWrapper (line 287) | class DistributedSamplerWrapper(DistributedSampler):
    method __init__ (line 288) | def __init__(
    method __iter__ (line 298) | def __iter__(self):
  function default_command (line 316) | def default_command():

FILE: ex_audioset.py
  function default_conf (line 52) | def default_conf():
  function get_scheduler_lambda (line 87) | def get_scheduler_lambda(warm_up_len=5, ramp_down_start=50, ramp_down_le...
  function get_lr_scheduler (line 98) | def get_lr_scheduler(optimizer, schedule_mode):
  function get_optimizer (line 105) | def get_optimizer(params, lr, adamw=True, weight_decay=0.0001):
  class M (line 112) | class M(Ba3lModule):
    method __init__ (line 113) | def __init__(self, experiment):
    method forward (line 139) | def forward(self, x):
    method mel_forward (line 142) | def mel_forward(self, x):
    method training_step (line 155) | def training_step(self, batch, batch_idx):
    method training_epoch_end (line 200) | def training_epoch_end(self, outputs):
    method predict (line 208) | def predict(self, batch, batch_idx: int, dataloader_idx: int = None):
    method validation_step (line 216) | def validation_step(self, batch, batch_idx):
    method validation_epoch_end (line 245) | def validation_epoch_end(self, outputs):
    method configure_optimizers (line 293) | def configure_optimizers(self):
    method configure_callbacks (line 304) | def configure_callbacks(self):
  function get_dynamic_norm (line 309) | def get_dynamic_norm(model, dyn_norm=False):
  function get_extra_checkpoint_callback (line 316) | def get_extra_checkpoint_callback(save_last_n=None):
  function get_extra_swa_callback (line 323) | def get_extra_swa_callback(swa=True, swa_epoch_start=50,
  function main (line 336) | def main(_run, _config, _log, _rnd, _seed, watch_model=True):
  function model_speed_test (line 365) | def model_speed_test(_run, _config, _log, _rnd, _seed, speed_test_batch_...
  function evaluate_only (line 430) | def evaluate_only(_run, _config, _log, _rnd, _seed):
  function test_loaders (line 445) | def test_loaders():
  function set_default_json_pickle (line 459) | def set_default_json_pickle(obj):
  function preload_mp3 (line 466) | def preload_mp3(all_y=CMD("/basedataset.preload_mp3")):
  function multiprocessing_run (line 475) | def multiprocessing_run(rank, word_size):
  function default_command (line 530) | def default_command():

FILE: ex_esc50.py
  function default_conf (line 50) | def default_conf():
  function get_scheduler_lambda (line 82) | def get_scheduler_lambda(warm_up_len=5, ramp_down_start=50, ramp_down_le...
  function get_lr_scheduler (line 92) | def get_lr_scheduler(optimizer, schedule_mode):
  function get_optimizer (line 99) | def get_optimizer(params, lr, adamw=True, weight_decay=0.0001):
  class M (line 106) | class M(Ba3lModule):
    method __init__ (line 107) | def __init__(self, experiment):
    method forward (line 126) | def forward(self, x):
    method mel_forward (line 129) | def mel_forward(self, x):
    method training_step (line 142) | def training_step(self, batch, batch_idx):
    method training_epoch_end (line 175) | def training_epoch_end(self, outputs):
    method predict (line 182) | def predict(self, batch, batch_idx: int, dataloader_idx: int = None):
    method validation_step (line 190) | def validation_step(self, batch, batch_idx):
    method validation_epoch_end (line 212) | def validation_epoch_end(self, outputs):
    method configure_optimizers (line 224) | def configure_optimizers(self):
    method configure_callbacks (line 235) | def configure_callbacks(self):
  function get_dynamic_norm (line 240) | def get_dynamic_norm(model, dyn_norm=False):
  function get_extra_checkpoint_callback (line 247) | def get_extra_checkpoint_callback(save_last_n=None):
  function get_extra_swa_callback (line 254) | def get_extra_swa_callback(swa=True, swa_epoch_start=2,
  function main (line 264) | def main(_run, _config, _log, _rnd, _seed):
  function model_speed_test (line 281) | def model_speed_test(_run, _config, _log, _rnd, _seed, speed_test_batch_...
  function evaluate_only (line 342) | def evaluate_only(_run, _config, _log, _rnd, _seed):
  function test_loaders (line 358) | def test_loaders():
  function set_default_json_pickle (line 372) | def set_default_json_pickle(obj):
  function multiprocessing_run (line 379) | def multiprocessing_run(rank, word_size):
  function default_command (line 431) | def default_command():

FILE: ex_fsd50k.py
  function variable_eval (line 54) | def variable_eval():
  function default_conf (line 60) | def default_conf():
  function get_scheduler_lambda (line 94) | def get_scheduler_lambda(warm_up_len=5, ramp_down_start=10, ramp_down_le...
  function get_lr_scheduler (line 104) | def get_lr_scheduler(optimizer, schedule_mode):
  function get_optimizer (line 111) | def get_optimizer(params, lr, adamw=True, weight_decay=0.0001):
  class M (line 118) | class M(Ba3lModule):
    method __init__ (line 119) | def __init__(self, experiment):
    method forward (line 138) | def forward(self, x):
    method mel_forward (line 141) | def mel_forward(self, x):
    method training_step (line 154) | def training_step(self, batch, batch_idx):
    method training_epoch_end (line 186) | def training_epoch_end(self, outputs):
    method predict (line 193) | def predict(self, batch, batch_idx: int, dataloader_idx: int = None):
    method validation_step (line 201) | def validation_step(self, batch, batch_idx, dataloader_idx):
    method validation_epoch_end (line 220) | def validation_epoch_end(self, outputs):
    method configure_optimizers (line 262) | def configure_optimizers(self):
    method configure_callbacks (line 273) | def configure_callbacks(self):
  function get_dynamic_norm (line 278) | def get_dynamic_norm(model, dyn_norm=False):
  function get_extra_checkpoint_callback (line 288) | def get_extra_checkpoint_callback(save_best=None):
  function get_extra_swa_callback (line 298) | def get_extra_swa_callback(swa=True, swa_epoch_start=10,
  function main (line 308) | def main(_run, _config, _log, _rnd, _seed):
  function model_speed_test (line 331) | def model_speed_test(_run, _config, _log, _rnd, _seed, speed_test_batch_...
  function evaluate_only (line 392) | def evaluate_only(_run, _config, _log, _rnd, _seed):
  function test_loaders (line 407) | def test_loaders():
  function set_default_json_pickle (line 421) | def set_default_json_pickle(obj):
  function preload_mp3 (line 428) | def preload_mp3(all_y=CMD("/basedataset.preload_mp3")):
  function multiprocessing_run (line 437) | def multiprocessing_run(rank, word_size):
  function default_command (line 489) | def default_command():

FILE: ex_openmic.py
  function default_conf (line 54) | def default_conf():
  function get_scheduler_lambda (line 88) | def get_scheduler_lambda(warm_up_len=5, ramp_down_start=50, ramp_down_le...
  function get_lr_scheduler (line 98) | def get_lr_scheduler(optimizer, schedule_mode):
  function get_optimizer (line 105) | def get_optimizer(params, lr, adamw=True, weight_decay=0.0001):
  class M (line 112) | class M(Ba3lModule):
    method __init__ (line 113) | def __init__(self, experiment):
    method forward (line 135) | def forward(self, x):
    method mel_forward (line 138) | def mel_forward(self, x):
    method training_step (line 151) | def training_step(self, batch, batch_idx):
    method training_epoch_end (line 190) | def training_epoch_end(self, outputs):
    method predict (line 197) | def predict(self, batch, batch_idx: int, dataloader_idx: int = None):
    method validation_step (line 205) | def validation_step(self, batch, batch_idx):
    method validation_epoch_end (line 228) | def validation_epoch_end(self, outputs):
    method configure_optimizers (line 272) | def configure_optimizers(self):
    method configure_callbacks (line 283) | def configure_callbacks(self):
  function get_dynamic_norm (line 288) | def get_dynamic_norm(model, dyn_norm=False):
  function get_extra_checkpoint_callback (line 295) | def get_extra_checkpoint_callback(save_last_n=None):
  function get_extra_swa_callback (line 302) | def get_extra_swa_callback(swa=True, swa_epoch_start=2,
  function main (line 312) | def main(_run, _config, _log, _rnd, _seed):
  function model_speed_test (line 330) | def model_speed_test(_run, _config, _log, _rnd, _seed, speed_test_batch_...
  function evaluate_only (line 391) | def evaluate_only(_run, _config, _log, _rnd, _seed):
  function test_loaders (line 406) | def test_loaders():
  function set_default_json_pickle (line 420) | def set_default_json_pickle(obj):
  function preload_mp3 (line 427) | def preload_mp3(all_y=CMD("/basedataset.preload_mp3")):
  function multiprocessing_run (line 436) | def multiprocessing_run(rank, word_size):
  function default_command (line 488) | def default_command():

FILE: fsd50k/dataset.py
  function default_config (line 25) | def default_config():
  function LMODE_default_config (line 48) | def LMODE_default_config():
  function decode_mp3 (line 52) | def decode_mp3(mp3_arr):
  function pad_or_truncate (line 70) | def pad_or_truncate(x, audio_length):
  function get_ir_sample (line 86) | def get_ir_sample(ir_path, _run, ir_augment, cut_irs_offset=None):
  function pydub_augment (line 105) | def pydub_augment(waveform, gain_augment=7, ir_augment=0):
  class MixupDataset (line 116) | class MixupDataset(TorchDataset):
    method __init__ (line 120) | def __init__(self, dataset, beta=2, rate=0.5):
    method __getitem__ (line 126) | def __getitem__(self, index):
    method __len__ (line 140) | def __len__(self):
  class AudioSetDataset (line 144) | class AudioSetDataset(TorchDataset):
    method __init__ (line 145) | def __init__(self, hdf5_file, sample_rate=32000, classes_num=200, clip...
    method open_hdf5 (line 167) | def open_hdf5(self):
    method __len__ (line 170) | def __len__(self):
    method __del__ (line 173) | def __del__(self):
    method __getitem__ (line 178) | def __getitem__(self, index):
    method resample (line 205) | def resample(self, waveform):
  function get_base_training_set (line 223) | def get_base_training_set(balanced_train_hdf5, clip_length=10):
  function preload_mp3 (line 229) | def preload_mp3(balanced_train_hdf5, unbalanced_train_hdf5, num_of_class...
  function get_ft_weighted_sampler (line 242) | def get_ft_weighted_sampler(samples_weights=CMD(".get_ft_cls_balanced_sa...
  function get_base_eval_set (line 258) | def get_base_eval_set(eval_hdf5, variable_eval=None):
  function get_base_valid_set (line 268) | def get_base_valid_set(valid_hdf5, variable_eval=None):
  function get_roll_func (line 278) | def get_roll_func(axis=1, shift=None, shift_range=50):
  function get_training_set (line 295) | def get_training_set(normalize, roll, wavmix=False):
  function get_valid_set (line 311) | def get_valid_set(normalize):
  function get_eval_set (line 321) | def get_eval_set(normalize):
  function print_conf (line 331) | def print_conf(_config):
  class DistributedSamplerWrapper (line 337) | class DistributedSamplerWrapper(DistributedSampler):
    method __init__ (line 338) | def __init__(
    method __iter__ (line 348) | def __iter__(self):
  function default_command (line 366) | def default_command():
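The index above lists `pad_or_truncate(x, audio_length)` and `get_roll_func(axis=1, shift=None, shift_range=50)` in `fsd50k/dataset.py` but does not show their bodies. Below is a minimal NumPy sketch of what the names and defaults suggest. It makes three assumptions: a 1-D waveform is zero-padded at the end, samples are passed around as `(x, name, target)` tuples, and when `shift` is `None` the circular shift is drawn uniformly from `[-shift_range, shift_range]`. The `rng` argument is a hypothetical addition for reproducibility. This is not the repository's code.

```python
import numpy as np


def pad_or_truncate(x, audio_length):
    """Zero-pad a 1-D waveform at the end, or cut it, to exactly audio_length samples."""
    x = np.asarray(x, dtype=np.float32)
    if len(x) <= audio_length:
        pad = np.zeros(audio_length - len(x), dtype=np.float32)
        return np.concatenate((x, pad), axis=0)
    return x[:audio_length]


def get_roll_func(axis=1, shift=None, shift_range=50, rng=None):
    """Return a transform that circularly shifts x inside an (x, name, target) sample.

    Assumption: a fixed `shift` is used if given; otherwise a random integer
    shift in [-shift_range, shift_range] is drawn on every call.
    `rng` is a hypothetical parameter, not part of the listed signature.
    """
    rng = rng if rng is not None else np.random.default_rng()

    def roll_func(sample):
        x, name, target = sample
        sf = shift if shift is not None else int(rng.integers(-shift_range, shift_range + 1))
        return np.roll(x, sf, axis=axis), name, target

    return roll_func
```

Rolling is circular, so no audio is lost: samples pushed off one end reappear at the other.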

FILE: fsd50k/prepare_scripts/convert_to_mp3.py
  function process_folder (line 19) | def process_folder(fol="balanced_train_segments"):
  function process_one (line 32) | def process_one(i, f1, f2):

FILE: fsd50k/prepare_scripts/create_h5pymp3_dataset.py
  function get_labels (line 49) | def get_labels(df):

FILE: helpers/audiodatasets.py
  function h6 (line 12) | def h6(w):
  class AudioPreprocessDataset (line 15) | class AudioPreprocessDataset(Dataset):
    method __init__ (line 23) | def __init__(self, files, labels, label_encoder, base_dir, preprocesso...
    method __getitem__ (line 36) | def __getitem__(self, index):
    method get_ordered_ids (line 42) | def get_ordered_ids(self):
    method get_ordered_labels (line 45) | def get_ordered_labels(self):
    method __len__ (line 48) | def __len__(self):
  class ObjectCacher (line 51) | class ObjectCacher:
    method __init__ (line 52) | def __init__(self, get_obj_func, dataset_name, obj_name="",
    method get_obj_cache_path (line 91) | def get_obj_cache_path(self):
    method get (line 94) | def get(self):
  class PreprocessDataset (line 99) | class PreprocessDataset(Dataset):
    method __init__ (line 106) | def __init__(self, dataset, preprocessor):
    method __getitem__ (line 112) | def __getitem__(self, index):
    method __len__ (line 114) | def __len__(self):
  class FilesCachedDataset (line 119) | class FilesCachedDataset(Dataset):
    method __init__ (line 120) | def __init__(self, get_dataset_func, dataset_name, x_name="",
    method __getitem__ (line 148) | def __getitem__(self, index):
    method get_ordered_labels (line 157) | def get_ordered_labels(self):
    method get_ordered_ids (line 160) | def get_ordered_ids(self):
    method get_xcache_path (line 163) | def get_xcache_path(self):
    method get_ycache_path (line 166) | def get_ycache_path(self):
    method get_sidcache_path (line 169) | def get_sidcache_path(self):
    method __len__ (line 172) | def __len__(self):
  class SelectionDataset (line 177) | class SelectionDataset(Dataset):
    method __init__ (line 184) | def __init__(self, dataset, sample_ids):
    method reselect (line 190) | def reselect(self, sample_ids):
    method get_ordered_ids (line 194) | def get_ordered_ids(self):
    method get_ordered_labels (line 197) | def get_ordered_labels(self):
    method __getitem__ (line 202) | def __getitem__(self, index):
    method __len__ (line 205) | def __len__(self):
  class SimpleSelectionDataset (line 208) | class SimpleSelectionDataset(Dataset):
    method __init__ (line 215) | def __init__(self, dataset, available_indexes ):
    method __getitem__ (line 219) | def __getitem__(self, index):
    method __len__ (line 222) | def __len__(self):
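`SimpleSelectionDataset(dataset, available_indexes)` in `helpers/audiodatasets.py` appears to be a view that exposes only some indices of another dataset, a common way to build cross-validation folds without copying data. The sketch below assumes that behaviour. The repository class subclasses `torch.utils.data.Dataset`. This version uses plain `__getitem__`/`__len__` duck typing instead, so it runs without PyTorch.

```python
class SimpleSelectionDataset:
    """Expose only `available_indexes` of an underlying indexable dataset (sketch)."""

    def __init__(self, dataset, available_indexes):
        self.dataset = dataset
        self.available_indexes = list(available_indexes)

    def __getitem__(self, index):
        # Map a position in the subset to the corresponding index in the full dataset.
        return self.dataset[self.available_indexes[index]]

    def __len__(self):
        return len(self.available_indexes)
```

Because the view only stores indices, several folds can share one expensive underlying dataset, for example one that decodes audio from an HDF5 file.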

FILE: helpers/mixup.py
  function my_mixup (line 5) | def my_mixup(size, alpha):
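The preview of `helpers/mixup.py` shows `my_mixup(size, alpha)` starting with `rn_indices = torch.randperm(size)` followed by a NumPy random draw for `lambd`. Below is a NumPy-only sketch of batch-level mixup along those lines. It assumes one `lambd ~ Beta(alpha, alpha)` per sample, folded to `max(lambd, 1 - lambd)` so that each original sample keeps the larger weight. `apply_mixup` and the `rng` argument are hypothetical helpers, not part of the repository.

```python
import numpy as np


def my_mixup(size, alpha, rng=None):
    """Return a batch permutation and per-sample mixing weights (sketch).

    Assumption: lambda ~ Beta(alpha, alpha), folded so that lambda >= 0.5.
    """
    rng = rng if rng is not None else np.random.default_rng()
    rn_indices = rng.permutation(size)
    lambd = rng.beta(alpha, alpha, size).astype(np.float32)
    lambd = np.maximum(lambd, 1.0 - lambd)
    return rn_indices, lambd


def apply_mixup(x, y, rn_indices, lambd):
    """Mix a batch x of shape (batch, ...) and targets y of shape (batch, classes)."""
    lam_x = lambd.reshape((-1,) + (1,) * (x.ndim - 1))  # broadcast over non-batch dims
    x_mix = x * lam_x + x[rn_indices] * (1.0 - lam_x)
    y_mix = y * lambd[:, None] + y[rn_indices] * (1.0 - lambd[:, None])
    return x_mix, y_mix
```

The same permutation and weights are applied to inputs and targets, so the mixed labels stay consistent with the mixed spectrograms.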

FILE: helpers/models_size.py
  function count_non_zero_params (line 7) | def count_non_zero_params(model):

FILE: helpers/ramp.py
  function pseudo_rampup (line 8) | def pseudo_rampup(T1, T2):
  function exp_rampup (line 21) | def exp_rampup(rampup_length):
  function linear_rampup (line 33) | def linear_rampup(rampup_length):
  function linear_rampdown (line 45) | def linear_rampdown(rampdown_length, start=0, last_value=0):
  function exp_rampdown (line 57) | def exp_rampdown(rampdown_length, num_epochs):
  function cosine_rampdown (line 70) | def cosine_rampdown(rampdown_length, num_epochs):
  function exp_warmup (line 83) | def exp_warmup(rampup_length, rampdown_length, num_epochs):
  function exp_warmup_linear_down (line 93) | def exp_warmup_linear_down(warmup, rampdown_length, start_rampdown, last...
  function test_warmup (line 101) | def test_warmup():
  function test_warmupl (line 107) | def test_warmupl():
  function cosine_cycle (line 113) | def cosine_cycle(cycle_len=20,ramp_down_start=100,last_lr_value=0.01):
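According to its preview, `helpers/ramp.py` credits iBelieveCJM's Tricks-of-Semi-supervised repository. It exposes schedule factories such as `linear_rampup(rampup_length)`, `linear_rampdown(rampdown_length, start=0, last_value=0)` and `exp_warmup_linear_down(warmup, rampdown_length, start_rampdown, ...)`. Below is a sketch of how such factories can be written. Each returns a function of the epoch that yields a multiplicative factor, for example for a learning-rate scheduler. Several details here are assumptions: the exponential shape `exp(-5 * (1 - t)**2)`, which is commonly used in semi-supervised learning code, and the boundary conventions. The fourth parameter of `exp_warmup_linear_down` is truncated in the index, so it is named `last_value` here only for illustration.

```python
import numpy as np


def linear_rampup(rampup_length):
    """Factor rising linearly from 0 to 1 over `rampup_length` epochs."""
    def wrapper(epoch):
        if epoch < rampup_length:
            return epoch / rampup_length
        return 1.0
    return wrapper


def exp_rampup(rampup_length):
    """Sigmoid-shaped ramp-up exp(-5 * (1 - t)^2) with t = epoch / rampup_length."""
    def wrapper(epoch):
        if epoch < rampup_length:
            t = np.clip(epoch, 0.0, rampup_length) / rampup_length
            return float(np.exp(-5.0 * (1.0 - t) ** 2))
        return 1.0
    return wrapper


def linear_rampdown(rampdown_length, start=0, last_value=0):
    """Hold 1 until `start`, then fall linearly to `last_value` over `rampdown_length` epochs."""
    def wrapper(epoch):
        if epoch <= start:
            return 1.0
        if epoch - start < rampdown_length:
            remaining = (rampdown_length - epoch + start) / rampdown_length
            return last_value + (1.0 - last_value) * remaining
        return last_value
    return wrapper


def exp_warmup_linear_down(warmup, rampdown_length, start_rampdown, last_value):
    """Product of an exponential warm-up and a linear ramp-down (hypothetical composition)."""
    rampup = exp_rampup(warmup)
    rampdown = linear_rampdown(rampdown_length, start_rampdown, last_value)

    def wrapper(epoch):
        return rampup(epoch) * rampdown(epoch)
    return wrapper
```

Returning closures keeps each schedule a plain `epoch -> factor` function, which is the shape `torch.optim.lr_scheduler.LambdaLR` expects for its `lr_lambda` argument.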

FILE: helpers/swa_callback.py
  class StochasticWeightAveraging (line 35) | class StochasticWeightAveraging(Callback):
    method __init__ (line 37) | def __init__(
    method swa_start (line 127) | def swa_start(self) -> int:
    method swa_end (line 131) | def swa_end(self) -> int:
    method pl_module_contains_batch_norm (line 135) | def pl_module_contains_batch_norm(pl_module: 'pl.LightningModule'):
    method setup (line 138) | def setup(self, trainer: 'pl.Trainer', pl_module: 'pl.LightningModule'...
    method on_fit_start (line 142) | def on_fit_start(self, trainer: 'pl.Trainer', pl_module: 'pl.Lightning...
    method on_train_epoch_start (line 161) | def on_train_epoch_start(self, trainer: 'pl.Trainer', pl_module: 'pl.L...
    method on_train_epoch_end (line 200) | def on_train_epoch_end(self, trainer: 'pl.Trainer',pl_module: 'pl.Ligh...
    method on_train_end (line 204) | def on_train_end(self, trainer: 'pl.Trainer', pl_module: 'pl.Lightning...
    method on_validation_epoch_start (line 208) | def on_validation_epoch_start(self, trainer, pl_module) -> None:
    method transfer_weights (line 217) | def transfer_weights(src_pl_module: 'pl.LightningModule', dst_pl_modul...
    method reset_batch_norm_and_save_state (line 221) | def reset_batch_norm_and_save_state(self, pl_module):
    method reset_momenta (line 239) | def reset_momenta(self):
    method update_parameters (line 247) | def update_parameters(
    method avg_fn (line 262) | def avg_fn(

FILE: helpers/swa_legacy.py
  class StochasticWeightAveraging (line 37) | class StochasticWeightAveraging(Callback):
    method __init__ (line 39) | def __init__(
    method swa_start (line 132) | def swa_start(self) -> int:
    method swa_end (line 136) | def swa_end(self) -> int:
    method pl_module_contains_batch_norm (line 140) | def pl_module_contains_batch_norm(pl_module: 'pl.LightningModule'):
    method on_before_accelerator_backend_setup (line 143) | def on_before_accelerator_backend_setup(self, trainer: 'pl.Trainer', p...
    method on_fit_start (line 147) | def on_fit_start(self, trainer: 'pl.Trainer', pl_module: 'pl.Lightning...
    method on_train_epoch_start (line 172) | def on_train_epoch_start(self, trainer: 'pl.Trainer', pl_module: 'pl.L...
    method on_train_epoch_end (line 211) | def on_train_epoch_end(self, trainer: 'pl.Trainer', pl_module: 'pl.Lig...
    method on_train_end (line 214) | def on_train_end(self, trainer: 'pl.Trainer', pl_module: 'pl.Lightning...
    method on_validation_epoch_start (line 217) | def on_validation_epoch_start(self, trainer, pl_module) -> None:
    method transfer_weights (line 225) | def transfer_weights(src_pl_module: 'pl.LightningModule', dst_pl_modul...
    method reset_batch_norm_and_save_state (line 229) | def reset_batch_norm_and_save_state(self, pl_module):
    method reset_momenta (line 247) | def reset_momenta(self):
    method update_parameters (line 255) | def update_parameters(
    method avg_fn (line 271) | def avg_fn(

FILE: helpers/workersinit.py
  function worker_init_fn (line 6) | def worker_init_fn(x):
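The preview of `helpers/workersinit.py` shows `worker_init_fn(x)` computing `seed = (torch.initial_seed() + x * 1000) % 2 ...` (truncated). Such a hook is needed because, with the `fork` start method, DataLoader workers inherit the same NumPy global random state and would otherwise repeat the same augmentations. The sketch below takes the base seed as an argument instead of calling `torch.initial_seed()`. It also assumes the modulus is `2**32`, the range that NumPy's legacy `np.random.seed` accepts. `worker_seed` and `worker_init_fn_sketch` are hypothetical names.

```python
import random

import numpy as np


def worker_seed(base_seed, worker_id):
    """Derive a distinct 32-bit seed per DataLoader worker (hypothetical helper)."""
    return (base_seed + worker_id * 1000) % 2**32


def worker_init_fn_sketch(worker_id, base_seed=0):
    # The repository reads the base seed from torch.initial_seed(); it is passed
    # in here so the sketch runs without PyTorch.
    seed = worker_seed(base_seed, worker_id)
    np.random.seed(seed)  # legacy global NumPy RNG requires 0 <= seed < 2**32
    random.seed(seed)
    return seed
```

The modulo keeps the seed in NumPy's valid range even when the base seed is close to `2**32`.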

FILE: models/helpers/vit_helpers.py
  function adapt_input_conv (line 27) | def adapt_input_conv(in_chans, conv_weight):
  function load_pretrained (line 54) | def load_pretrained(
  function overlay_external_default_cfg (line 144) | def overlay_external_default_cfg(default_cfg, kwargs):
  function filter_kwargs (line 153) | def filter_kwargs(kwargs, names):
  function set_default_kwargs (line 160) | def set_default_kwargs(kwargs, names, default_cfg):
  function update_default_cfg_and_kwargs (line 180) | def update_default_cfg_and_kwargs(default_cfg, kwargs, kwargs_filter):
  function drop_path (line 203) | def drop_path(x, drop_prob: float = 0.0, training: bool = False):
  class DropPath (line 225) | class DropPath(nn.Module):
    method __init__ (line 228) | def __init__(self, drop_prob=None):
    method forward (line 232) | def forward(self, x):
  function _no_grad_trunc_normal_ (line 239) | def _no_grad_trunc_normal_(tensor, mean, std, a, b):
  function trunc_normal_ (line 277) | def trunc_normal_(tensor, mean=0.0, std=1.0, a=-2.0, b=2.0):
  function variance_scaling_ (line 297) | def variance_scaling_(tensor, scale=1.0, mode="fan_in", distribution="no...
  function lecun_normal_ (line 320) | def lecun_normal_(tensor):
  function build_model_with_cfg (line 324) | def build_model_with_cfg(
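According to the file header, `drop_path(x, drop_prob: float = 0.0, training: bool = False)` and `DropPath` in `models/helpers/vit_helpers.py` are adapted from timm. They implement stochastic depth: during training, each sample's residual branch is zeroed with probability `drop_prob`, and the surviving samples are scaled by `1 / (1 - drop_prob)` so the expected value is unchanged. Below is a NumPy sketch of that computation. The `rng` argument is a hypothetical addition, and the repository version operates on PyTorch tensors.

```python
import numpy as np


def drop_path(x, drop_prob=0.0, training=False, rng=None):
    """Stochastic depth: zero whole samples of a residual branch and rescale the rest."""
    if drop_prob == 0.0 or not training:
        return x
    rng = rng if rng is not None else np.random.default_rng()
    keep_prob = 1.0 - drop_prob
    # One Bernoulli(keep_prob) draw per sample, broadcast over the remaining dims:
    # floor(keep_prob + U[0, 1)) is 1 with probability keep_prob, else 0.
    shape = (x.shape[0],) + (1,) * (x.ndim - 1)
    mask = np.floor(keep_prob + rng.random(shape))
    return x / keep_prob * mask
```

Drawing one value per sample, rather than per element, drops the entire branch for a given example. This is what distinguishes stochastic depth from ordinary dropout.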

FILE: models/passt.py
  function _ntuple (line 26) | def _ntuple(n):
  function _cfg (line 42) | def _cfg(url='', **kwargs):
  function adapt_input_conv (line 246) | def adapt_input_conv(in_chans, conv_weight):
  class Mlp (line 271) | class Mlp(nn.Module):
    method __init__ (line 275) | def __init__(self, in_features, hidden_features=None, out_features=Non...
    method forward (line 284) | def forward(self, x):
  class PatchEmbed (line 298) | class PatchEmbed(nn.Module):
    method __init__ (line 302) | def __init__(self, img_size=224, patch_size=16, stride=16, in_chans=3,...
    method forward (line 318) | def forward(self, x):
  class Attention (line 331) | class Attention(nn.Module):
    method __init__ (line 332) | def __init__(self, dim, num_heads=8, qkv_bias=False, attn_drop=0., pro...
    method forward (line 343) | def forward(self, x):
  class Block (line 364) | class Block(nn.Module):
    method __init__ (line 366) | def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, drop=...
    method forward (line 377) | def forward(self, x):
  class PaSST (line 383) | class PaSST(nn.Module):
    method __init__ (line 391) | def __init__(self, u_patchout=0, s_patchout_t=0, s_patchout_f=0, img_s...
    method init_weights (line 471) | def init_weights(self, mode=''):
    method _init_weights (line 486) | def _init_weights(self, m):
    method no_weight_decay (line 491) | def no_weight_decay(self):
    method get_classifier (line 494) | def get_classifier(self):
    method reset_classifier (line 500) | def reset_classifier(self, num_classes, global_pool=''):
    method forward_features (line 506) | def forward_features(self, x):
    method forward (line 576) | def forward(self, x):
  function _init_vit_weights (line 598) | def _init_vit_weights(module: nn.Module, name: str = '', head_bias: floa...
  function resize_pos_embed (line 633) | def resize_pos_embed(posemb, posemb_new, num_tokens=1, gs_new=(), mode='...
  function adapt_image_pos_embed_to_passt (line 656) | def adapt_image_pos_embed_to_passt(posemb, num_tokens=1, gs_new=(), mode...
  function checkpoint_filter_fn (line 679) | def checkpoint_filter_fn(state_dict, model):
  function _create_vision_transformer (line 709) | def _create_vision_transformer(variant, pretrained=False, default_cfg=No...
  function vit_huge_patch14_224_in21k (line 734) | def vit_huge_patch14_224_in21k(pretrained=False, **kwargs):
  function deit_base_distilled_patch16_384 (line 745) | def deit_base_distilled_patch16_384(pretrained=False, **kwargs):
  function passt_s_swa_p16_128_ap476 (line 756) | def passt_s_swa_p16_128_ap476(pretrained=False, **kwargs):
  function passt_s_kd_p16_128_ap486 (line 769) | def passt_s_kd_p16_128_ap486(pretrained=False, **kwargs):
  function passt_l_kd_p16_128_ap47 (line 782) | def passt_l_kd_p16_128_ap47(pretrained=False, **kwargs):
  function passt_s_swa_p16_128_ap4761 (line 795) | def passt_s_swa_p16_128_ap4761(pretrained=False, **kwargs):
  function passt_s_p16_128_ap472 (line 808) | def passt_s_p16_128_ap472(pretrained=False, **kwargs):
  function passt_s_p16_s12_128_ap470 (line 821) | def passt_s_p16_s12_128_ap470(pretrained=False, **kwargs):
  function passt_s_f128_20sec_p16_s10_ap474_swa (line 834) | def passt_s_f128_20sec_p16_s10_ap474_swa(pretrained=False, **kwargs):
  function passt_s_f128_30sec_p16_s10_ap473_swa (line 842) | def passt_s_f128_30sec_p16_s10_ap473_swa(pretrained=False, **kwargs):
  function passt_s_swa_p16_s12_128_ap473 (line 850) | def passt_s_swa_p16_s12_128_ap473(pretrained=False, **kwargs):
  function passt_s_p16_s14_128_ap469 (line 863) | def passt_s_p16_s14_128_ap469(pretrained=False, **kwargs):
  function passt_s_swa_p16_s14_128_ap471 (line 876) | def passt_s_swa_p16_s14_128_ap471(pretrained=False, **kwargs):
  function passt_s_swa_p16_s16_128_ap473 (line 889) | def passt_s_swa_p16_s16_128_ap473(pretrained=False, **kwargs):
  function passt_s_p16_s16_128_ap468 (line 902) | def passt_s_p16_s16_128_ap468(pretrained=False, **kwargs):
  function fix_embedding_layer (line 923) | def fix_embedding_layer(model, embed="default"):
  function lighten_model (line 933) | def lighten_model(model, cut_depth=0):
  function get_model (line 958) | def get_model(arch="passt_s_kd_p16_128_ap486", pretrained=True, n_classe...
  class EnsembelerModel (line 1021) | class EnsembelerModel(nn.Module):
    method __init__ (line 1022) | def __init__(self, models):
    method forward (line 1026) | def forward(self, x):
  function get_ensemble_model (line 1040) | def get_ensemble_model(arch_list=[]):
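The `PaSST` constructor takes `u_patchout`, `s_patchout_t` and `s_patchout_f`, the parameters that control Patchout. Patchout drops input patches during training to reduce compute and regularize the model. The NumPy sketch below illustrates both variants under stated assumptions. Structured patchout removes whole time columns or frequency rows from a `(batch, embed_dim, F, T)` patch grid. Unstructured patchout removes random tokens from the flattened `(batch, n, dim)` sequence. The same indices are dropped for every item in the batch so the result stays rectangular, and the kept indices are sorted to preserve their order. Function names are hypothetical; this is not the body of `forward_features`.

```python
import numpy as np


def structured_patchout(x, s_patchout_t=0, s_patchout_f=0, rng=None):
    """Drop s_patchout_t time columns and s_patchout_f frequency rows from x (B, C, F, T)."""
    rng = rng if rng is not None else np.random.default_rng()
    _, _, f_dim, t_dim = x.shape
    if s_patchout_t:
        keep_t = np.sort(rng.permutation(t_dim)[: t_dim - s_patchout_t])
        x = x[:, :, :, keep_t]
    if s_patchout_f:
        keep_f = np.sort(rng.permutation(f_dim)[: f_dim - s_patchout_f])
        x = x[:, :, keep_f, :]
    return x


def unstructured_patchout(seq, u_patchout=0, rng=None):
    """Drop u_patchout random tokens from a flattened sequence of shape (B, n, dim)."""
    if not u_patchout:
        return seq
    rng = rng if rng is not None else np.random.default_rng()
    n = seq.shape[1]
    keep = np.sort(rng.permutation(n)[: n - u_patchout])
    return seq[:, keep, :]
```

Take a hypothetical grid of 12 frequency by 99 time patches. Setting `s_patchout_f=4` and `s_patchout_t=40` keeps 8 x 59 = 472 of the 1188 patches, so attention runs over roughly 40% of the original sequence length.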

FILE: models/preprocess.py
  class AugmentMelSTFT (line 19) | class AugmentMelSTFT(nn.Module):
    method __init__ (line 20) | def __init__(self, n_mels=128, sr=32000, win_length=800, hopsize=320, ...
    method forward (line 57) | def forward(self, x):
    method extra_repr (line 88) | def extra_repr(self):
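The defaults of `AugmentMelSTFT(n_mels=128, sr=32000, win_length=800, hopsize=320, ...)` in `models/preprocess.py` fix the spectrogram resolution. At 32 kHz, an 800-sample window spans 25 ms and a 320-sample hop spans 10 ms, giving 100 frames per second. The preview also shows that the module imports `conv1d`. The sketch below computes the timing and shows a first-order pre-emphasis filter. It assumes, without confirmation from this excerpt, that `conv1d` is used for such a filter with coefficient 0.97. `frame_timing` and `preemphasis` are hypothetical names.

```python
import numpy as np


def frame_timing(sr=32000, win_length=800, hopsize=320):
    """Window length and hop length in milliseconds for the given STFT settings."""
    return 1000.0 * win_length / sr, 1000.0 * hopsize / sr


def preemphasis(x, coef=0.97):
    """First-order high-pass y[n] = x[n+1] - coef * x[n]; output is one sample shorter.

    This matches a 'valid' 1-D cross-correlation with kernel [-coef, 1].
    """
    x = np.asarray(x, dtype=np.float32)
    return x[1:] - coef * x[:-1]
```

Pre-emphasis boosts high frequencies before the mel projection and almost cancels a constant (DC) signal, as the usage below shows: a constant input of 1.0 becomes 0.03.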

FILE: openmic/dataset.py
  function default_config (line 29) | def default_config():
  function decode_mp3 (line 47) | def decode_mp3(mp3_arr):
  function pad_or_truncate (line 65) | def pad_or_truncate(x, audio_length):
  function get_ir_sample (line 77) | def get_ir_sample(ir_path, _run, ir_augment, cut_irs_offset=None):
  function pydub_augment (line 96) | def pydub_augment(waveform, gain_augment=7, ir_augment=0):
  class MixupDataset (line 107) | class MixupDataset(TorchDataset):
    method __init__ (line 111) | def __init__(self, dataset, beta=2, rate=0.5):
    method __getitem__ (line 117) | def __getitem__(self, index):
    method __len__ (line 140) | def __len__(self):
  class AudioSetDataset (line 145) | class AudioSetDataset(TorchDataset):
    method __init__ (line 146) | def __init__(self, hdf5_file, sample_rate=32000, classes_num=527, clip...
    method open_hdf5 (line 166) | def open_hdf5(self):
    method __len__ (line 169) | def __len__(self):
    method __del__ (line 172) | def __del__(self):
    method __getitem__ (line 177) | def __getitem__(self, index):
    method resample (line 203) | def resample(self, waveform):
  function get_base_training_set (line 221) | def get_base_training_set(openmic_train_hdf5):
  function get_ft_weighted_sampler (line 227) | def get_ft_weighted_sampler(samples_weights=CMD(".get_ft_cls_balanced_sa...
  function get_base_test_set (line 243) | def get_base_test_set(openmic_test_hdf5):
  function get_roll_func (line 249) | def get_roll_func(axis=1, shift=None, shift_range=50):
  function get_training_set (line 266) | def get_training_set(normalize, roll, wavmix=False):
  function get_test_set (line 282) | def get_test_set(normalize):
  function print_conf (line 292) | def print_conf(_config):
  class DistributedSamplerWrapper (line 298) | class DistributedSamplerWrapper(DistributedSampler):
    method __init__ (line 299) | def __init__(
    method __iter__ (line 309) | def __iter__(self):
  function default_command (line 327) | def default_command():
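`openmic/dataset.py`, like `fsd50k/dataset.py`, pairs `get_ft_weighted_sampler` with a `DistributedSamplerWrapper(DistributedSampler)`. This suggests that a weighted sampler is shared across distributed processes. One way to make that work is to have every rank draw the identical weighted sequence from a shared `seed + epoch`, then give each rank a strided slice so the ranks split the sequence between them. The sketch below does this with the standard library only. All function names are hypothetical. It also assumes `num_samples` is divisible by `num_replicas`; PyTorch's `DistributedSampler` pads the index list otherwise.

```python
import random


def weighted_indices(weights, num_samples, seed):
    """Draw num_samples indices with replacement, with probability proportional to weights."""
    rng = random.Random(seed)
    return rng.choices(range(len(weights)), weights=weights, k=num_samples)


def shard_for_rank(indices, rank, num_replicas):
    """Give each process every num_replicas-th position, starting at its rank."""
    return indices[rank::num_replicas]


def distributed_weighted_epoch(weights, num_samples, rank, num_replicas, seed=0, epoch=0):
    # All ranks use the same seed + epoch, so they draw the same global sequence;
    # strided slicing then assigns each position of that sequence to exactly one rank.
    indices = weighted_indices(weights, num_samples, seed + epoch)
    return shard_for_rank(indices, rank, num_replicas)
```

Folding the epoch into the seed changes the draw every epoch while keeping all ranks in agreement without any communication.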

FILE: openmic/prepare_scripts/download_preprocess.py
  function download (line 21) | def download(force=False):
  function untar (line 29) | def untar():
  function process_folder (line 36) | def process_folder(fol="balanced_train_segments"):
  function process_one (line 50) | def process_one(i, f1, f2):
  function make_mp3 (line 56) | def make_mp3():
  function read_metadata (line 60) | def read_metadata(csv_path, classes_num, id_to_ix, openmicf):
  function get_files_labels (line 95) | def get_files_labels(balanced_csv, balanced_audio_path, d_files, openmic...
  function pack (line 118) | def pack():
  function preprocess (line 164) | def preprocess():
Condensed preview: 52 files, each showing path, character count, and a content snippet. The full structured content is 364K chars.
[
  {
    "path": ".github/workflows/pylint.yml",
    "chars": 553,
    "preview": "name: Pylint\n\non: [push]\n\njobs:\n  build:\n    runs-on: ubuntu-latest\n    strategy:\n      matrix:\n        python-version: "
  },
  {
    "path": ".gitignore",
    "chars": 1540,
    "preview": "\n\n# project\n/output/\n\n# Byte-compiled / optimized / DLL files\n__pycache__/\n*.py[cod]\n*$py.class\nexample.py\ntimit_data/\nL"
  },
  {
    "path": ".markdownlint.json",
    "chars": 119,
    "preview": "{\n    \"MD033\": {\n        \"allowed_elements\": [\n            \"p\",\n            \"img\"\n        ]\n    },\n    \"MD013\": false\n}"
  },
  {
    "path": "LICENSE",
    "chars": 11357,
    "preview": "                                 Apache License\n                           Version 2.0, January 2004\n                   "
  },
  {
    "path": "README.md",
    "chars": 14253,
    "preview": "# PaSST: Efficient Training of Audio Transformers with Patchout\n\nThis is the implementation for [Efficient Training of A"
  },
  {
    "path": "__init__.py",
    "chars": 0,
    "preview": ""
  },
  {
    "path": "audioset/README.md",
    "chars": 2323,
    "preview": "# Experiments on Audioset\nAudioset has around 2M segments. The total size of the dataset with wav files with `32khz` sam"
  },
  {
    "path": "audioset/dataset.py",
    "chars": 13659,
    "preview": "import io\nimport os\nimport pathlib\nimport random\n\nimport av\nimport librosa\nimport torchaudio\nfrom torch.utils.data impor"
  },
  {
    "path": "audioset/prepare_scripts/convert_to_mp3.py",
    "chars": 1955,
    "preview": "import argparse\nimport multiprocessing\nimport glob\nimport os\n\n# Replace this with the dataset downloaded using PANN scri"
  },
  {
    "path": "audioset/prepare_scripts/create_h5pymp3_dataset.py",
    "chars": 5839,
    "preview": "# %%\nimport h5py\nimport pandas as pd\nimport numpy as np\nimport csv\nimport os\n\n\n# %%\nbase_dir = \"../../audioset_hdf5s/\"\nb"
  },
  {
    "path": "ba3l/__init__.py",
    "chars": 331,
    "preview": "\"\"\"Package info\"\"\"\n\n__version__ = \"0.0.2\"\n__author__ = \"Koutini et al.\"\n__license__ = \"Apache-2.0\"\n__copyright__ = \"Copy"
  },
  {
    "path": "ba3l/experiment.py",
    "chars": 8033,
    "preview": "import inspect\nfrom importlib import import_module\n\nfrom ba3l.ingredients.datasets import Datasets\nfrom ba3l.ingredients"
  },
  {
    "path": "ba3l/ingredients/__init__.py",
    "chars": 0,
    "preview": ""
  },
  {
    "path": "ba3l/ingredients/datasets.py",
    "chars": 11133,
    "preview": "import inspect\nimport os\nfrom functools import partial\n\nfrom ba3l.util.functions import get_default_kwargs_dict\nfrom sac"
  },
  {
    "path": "ba3l/ingredients/ingredient.py",
    "chars": 4965,
    "preview": "import inspect\nimport os\nfrom functools import partial\n\nfrom ba3l.util.functions import get_default_kwargs_dict\nfrom sac"
  },
  {
    "path": "ba3l/ingredients/models.py",
    "chars": 6967,
    "preview": "import inspect\nimport os\n\nfrom .ingredient import Ingredient\n\nfrom typing import Sequence, Optional, List\n\nfrom sacred.u"
  },
  {
    "path": "ba3l/ingredients/trainer.py",
    "chars": 440,
    "preview": "import inspect\nimport os\n\nfrom ba3l.util.functions import get_default_kwargs_dict\nfrom .datasets import Datasets, raise_"
  },
  {
    "path": "ba3l/module.py",
    "chars": 1158,
    "preview": "import pytorch_lightning as pl\n\nimport warnings\nfrom abc import ABC\n\nimport torch.distributed as dist\nfrom munch import "
  },
  {
    "path": "ba3l/plutils/__init__.py",
    "chars": 0,
    "preview": ""
  },
  {
    "path": "ba3l/plutils/lr_monitor.py",
    "chars": 8110,
    "preview": "# Copyright The PyTorch Lightning team.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may no"
  },
  {
    "path": "ba3l/plutils/progress_bar.py",
    "chars": 6006,
    "preview": "\nimport importlib\nimport io\nimport os\nimport sys\n\n# check if ipywidgets is installed before importing tqdm.auto\n# to ens"
  },
  {
    "path": "ba3l/util/__init__.py",
    "chars": 0,
    "preview": ""
  },
  {
    "path": "ba3l/util/functions.py",
    "chars": 286,
    "preview": "import inspect\nfrom collections import OrderedDict\n\n\ndef get_default_kwargs_dict(f):\n    sig = inspect.signature(f)\n    "
  },
  {
    "path": "ba3l/util/sacred_logger.py",
    "chars": 2507,
    "preview": "from pytorch_lightning.utilities import rank_zero_only\ntry:\n    from pytorch_lightning.loggers import Logger as Lightnin"
  },
  {
    "path": "config_updates.py",
    "chars": 11551,
    "preview": "from sacred.config_helpers import DynamicIngredient, CMD\n\n\ndef add_configs(ex):\n    '''\n    This functions add generic c"
  },
  {
    "path": "environment.yml",
    "chars": 2413,
    "preview": "name: ba3l\nchannels:\n  - defaults\ndependencies:\n  - _libgcc_mutex=0.1\n  - _openmp_mutex=5.1\n  - ca-certificates=2022.10."
  },
  {
    "path": "environment_old_2021.yml",
    "chars": 3487,
    "preview": "name: ba3l\nchannels:\n  - conda-forge\n  - defaults\ndependencies:\n  - _libgcc_mutex=0.1\n  - _openmp_mutex=4.5\n  - _pytorch"
  },
  {
    "path": "esc50/README.md",
    "chars": 1975,
    "preview": "# Experiments on ESC-50 Environmental Sound Classification\n[ESC-50](https://github.com/karolpiczak/ESC-50) consist of 20"
  },
  {
    "path": "esc50/dataset.py",
    "chars": 10201,
    "preview": "import io\nimport os\nimport pathlib\nimport random\n\nimport av\nimport librosa\nimport torchaudio\nfrom torch.utils.data impor"
  },
  {
    "path": "ex_audioset.py",
    "chars": 19493,
    "preview": "import os\nimport sys\nimport PIL\nimport pytorch_lightning\nimport torch\nfrom pytorch_lightning.callbacks import ModelCheck"
  },
  {
    "path": "ex_esc50.py",
    "chars": 15106,
    "preview": "import os\nimport sys\n\nimport torch\nfrom pytorch_lightning.callbacks import ModelCheckpoint\nfrom sacred.config_helpers im"
  },
  {
    "path": "ex_fsd50k.py",
    "chars": 18025,
    "preview": "import os\nimport sys\n\nimport torch\nfrom pytorch_lightning.callbacks import ModelCheckpoint\nfrom sacred.config_helpers im"
  },
  {
    "path": "ex_openmic.py",
    "chars": 17683,
    "preview": "import os\nimport sys\n\nimport torch\nfrom pytorch_lightning.callbacks import ModelCheckpoint\nfrom sacred.config_helpers im"
  },
  {
    "path": "fsd50k/README.md",
    "chars": 3060,
    "preview": "# Experiments on FSD50K\n The FSD50K dataset ([Zenodo](https://zenodo.org/record/4060432))  consists of 51K audio clips a"
  },
  {
    "path": "fsd50k/dataset.py",
    "chars": 12238,
    "preview": "import io\nimport os\nimport pathlib\nimport random\n\nimport av\nimport librosa\nimport torchaudio\nfrom torch.utils.data impor"
  },
  {
    "path": "fsd50k/prepare_scripts/convert_to_mp3.py",
    "chars": 1305,
    "preview": "import multiprocessing\nimport glob\nimport os\nimport sys\n\nif len(sys.argv) > 1:\n    FSD50K_base = sys.argv[1] # the path "
  },
  {
    "path": "fsd50k/prepare_scripts/create_h5pymp3_dataset.py",
    "chars": 2753,
    "preview": "# %%\nimport sys\n\nimport h5py\nimport pandas as pd\nimport numpy as np\nimport csv\nimport os\n\n# %%\nfrom numpy import dtype\n\n"
  },
  {
    "path": "helpers/audiodatasets.py",
    "chars": 7509,
    "preview": "import hashlib\nimport os\nimport time\n\nimport torch\nfrom torch.utils.data import Dataset\n\nfrom os.path import expanduser\n"
  },
  {
    "path": "helpers/mixup.py",
    "chars": 400,
    "preview": "\nimport numpy as np\nimport torch\n\ndef my_mixup(size, alpha):\n    rn_indices = torch.randperm(size)\n    lambd = np.random"
  },
  {
    "path": "helpers/models_size.py",
    "chars": 1069,
    "preview": "\n\n\n\n\n\ndef count_non_zero_params(model):\n    sum_params = 0\n    sum_non_zero = 0\n    desc = \"\"\n\n    def calc_params(model"
  },
  {
    "path": "helpers/ramp.py",
    "chars": 3635,
    "preview": "import numpy as np\nfrom ba3l.ingredients.ingredient import Ingredient\n\n\n# credit: https://github.com/iBelieveCJM/Tricks-"
  },
  {
    "path": "helpers/swa_callback.py",
    "chars": 11325,
    "preview": "# Adapted from PyTorch Lightning so that it only does the averaging\n# Copyright The PyTorch Lightning team.\n#\n# Licensed"
  },
  {
    "path": "helpers/swa_legacy.py",
    "chars": 11761,
    "preview": "# Adapted from PyTorch Lightning so that it only does the averaging\n# Copyright The PyTorch Lightning team.\n#\n# Licensed"
  },
  {
    "path": "helpers/workersinit.py",
    "chars": 251,
    "preview": "import torch\nimport numpy as np\nimport random\n\n\ndef worker_init_fn(x):\n    seed = (torch.initial_seed() + x * 1000) % 2 "
  },
  {
    "path": "models/helpers/vit_helpers.py",
    "chars": 15583,
    "preview": "\"\"\"\nAdapted from https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py\nCredit "
  },
  {
    "path": "models/passt.py",
    "chars": 52057,
    "preview": "\"\"\"\nMost of this code comes from the timm  library.\nWe tried to disentangle from the timm library version.\n\nAdapted from"
  },
  {
    "path": "models/preprocess.py",
    "chars": 3638,
    "preview": "import torch.nn as nn\nimport torchaudio\nfrom torch.nn.functional import conv1d, conv2d\n\nimport torch\n\n\nfrom ba3l.ingredi"
  },
  {
    "path": "openmic/README.md",
    "chars": 884,
    "preview": "# Experiments on OpenMIC-2018\n\n[OpenMIC-2018](https://github.com/cosmir/openmic-2018) ([zenodo](https://zenodo.org/recor"
  },
  {
    "path": "openmic/dataset.py",
    "chars": 10659,
    "preview": "import io\nimport os\nimport pathlib\nimport random\n\nimport av\nimport librosa\nimport torchaudio\nfrom torch.utils.data impor"
  },
  {
    "path": "openmic/prepare_scripts/download_preprocess.py",
    "chars": 6389,
    "preview": "import os\nimport tarfile\nimport multiprocessing\nimport glob\nimport h5py\nimport numpy as np\n\nfrom torch.hub import downlo"
  },
  {
    "path": "pip_list.txt",
    "chars": 3141,
    "preview": "Package                 Version\n----------------------- ------------\nabsl-py                 1.4.0\naiohttp              "
  },
  {
    "path": "requirements.txt",
    "chars": 151,
    "preview": "av>=10.0.0\nh5py>=3.8.0\njsonpickle>=3.0.1\nkk-sacred>=0.8.4\nlibrosa>=0.10.0.post2\ntimm>=0.4.12\ntorchmetrics>=0.2.0\npytorch"
  }
]
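The `ba3l/util/functions.py` preview shows that `get_default_kwargs_dict(f)` starts from `inspect.signature(f)` and that the module imports `OrderedDict`. The `ba3l/ingredients` modules import this function, presumably to expose function defaults as Sacred configuration entries. Below is a minimal sketch of such a helper, assuming it collects only the parameters that declare a default. The `example` function is hypothetical and only mirrors the style of `AudioSetDataset.__init__`.

```python
import inspect
from collections import OrderedDict


def get_default_kwargs_dict(f):
    """Map each parameter of f that has a default value to that default, in order."""
    sig = inspect.signature(f)
    return OrderedDict(
        (name, p.default)
        for name, p in sig.parameters.items()
        if p.default is not inspect.Parameter.empty
    )


def example(hdf5_file, sample_rate=32000, clip_length=10):
    """Hypothetical function used only to demonstrate the helper."""
    return hdf5_file
```

For `example`, the helper yields `sample_rate` and `clip_length` with their defaults and skips the required `hdf5_file`.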

About this extraction

This page contains the full source code of the kkoutini/PaSST GitHub repository, extracted and formatted as plain text. The extraction includes 52 files (341.1 KB), approximately 89.4k tokens, and a symbol index with 461 extracted functions, classes, methods, constants, and types.

Extracted by GitExtract — free GitHub repo to text converter for AI. Built by Nikandr Surkov.