[
  {
    "path": ".github/workflows/pylint.yml",
    "content": "name: Pylint\n\non: [push]\n\njobs:\n  build:\n    runs-on: ubuntu-latest\n    strategy:\n      matrix:\n        python-version: [\"3.8\", \"3.9\", \"3.10\"]\n    steps:\n    - uses: actions/checkout@v3\n    - name: Set up Python ${{ matrix.python-version }}\n      uses: actions/setup-python@v3\n      with:\n        python-version: ${{ matrix.python-version }}\n    - name: Install dependencies\n      run: |\n        python -m pip install --upgrade pip\n        pip install pylint\n    - name: Analysing the code with pylint\n      run: |\n        pylint $(git ls-files '*.py')\n"
  },
  {
    "path": ".gitignore",
    "content": "\n\n# project\n/output/\n\n# Byte-compiled / optimized / DLL files\n__pycache__/\n*.py[cod]\n*$py.class\nexample.py\ntimit_data/\nLJSpeech-1.1/\n\n# C extensions\n*.so\n\n.idea/\n\n# Distribution / packaging\n.Python\nenv/\nide_layouts/\nbuild/\ndevelop-eggs/\ndist/\ndownloads/\neggs/\n.eggs/\nlib/\nlib64/\nparts/\nsdist/\nvar/\nwheels/\n*.egg-info/\n.installed.cfg\n*.egg\n\n# PyInstaller\n#  Usually these files are written by a python script from a template\n#  before PyInstaller builds the exe, so as to inject date/other infos into it.\n*.manifest\n*.spec\n\n# Installer logs\npip-log.txt\npip-delete-this-directory.txt\n\n# Unit test / coverage reports\nhtmlcov/\n.tox/\n.coverage\n.coverage.*\n.cache\nnosetests.xml\ncoverage.xml\n*.cover\n.hypothesis/\n\n# Translations\n*.mo\n*.pot\n\n# Django stuff:\n*.log\nlocal_settings.py\n\n# Flask stuff:\ninstance/\n.webassets-cache\n\n# Scrapy stuff:\n.scrapy\n\n# Sphinx documentation\ndocs/_build/\n\n# PyBuilder\ntarget/\n\n# Jupyter Notebook\n.ipynb_checkpoints\n\n# pyenv\n.python-version\n\n# celery beat schedule file\ncelerybeat-schedule\n\n# SageMath parsed files\n*.sage.py\n\n# dotenv\n.env\n\n# virtualenv\n.venv\nvenv/\nENV/\n\n# Spyder project settings\n.spyderproject\n.spyproject\n\n# Rope project settings\n.ropeproject\n\n# mkdocs documentation\n/site\n\n# mypy\n.mypy_cache/\n/openmic/prepare_scripts/openmic-2018-v1.0.0/\n/openmic/prepare_scripts/openmic-2018-v1.0.0.tgz\n/audioset_hdf5s/mp3/openmic_test.csv_mp3.hdf\n/audioset_hdf5s/mp3/openmic_train.csv_mp3.hdf\nenvironment_builds.yml\n\n# Output\nlightning_logs/*\naudioset_hdf5s/*\n.vscode/settings.json\nwandb/*\n.vscode/launch.json\n"
  },
  {
    "path": ".markdownlint.json",
    "content": "{\n    \"MD033\": {\n        \"allowed_elements\": [\n            \"p\",\n            \"img\"\n        ]\n    },\n    \"MD013\": false\n}"
  },
  {
    "path": "LICENSE",
    "content": "                                 Apache License\n                           Version 2.0, January 2004\n                        http://www.apache.org/licenses/\n\n   TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION\n\n   1. Definitions.\n\n      \"License\" shall mean the terms and conditions for use, reproduction,\n      and distribution as defined by Sections 1 through 9 of this document.\n\n      \"Licensor\" shall mean the copyright owner or entity authorized by\n      the copyright owner that is granting the License.\n\n      \"Legal Entity\" shall mean the union of the acting entity and all\n      other entities that control, are controlled by, or are under common\n      control with that entity. For the purposes of this definition,\n      \"control\" means (i) the power, direct or indirect, to cause the\n      direction or management of such entity, whether by contract or\n      otherwise, or (ii) ownership of fifty percent (50%) or more of the\n      outstanding shares, or (iii) beneficial ownership of such entity.\n\n      \"You\" (or \"Your\") shall mean an individual or Legal Entity\n      exercising permissions granted by this License.\n\n      \"Source\" form shall mean the preferred form for making modifications,\n      including but not limited to software source code, documentation\n      source, and configuration files.\n\n      \"Object\" form shall mean any form resulting from mechanical\n      transformation or translation of a Source form, including but\n      not limited to compiled object code, generated documentation,\n      and conversions to other media types.\n\n      \"Work\" shall mean the work of authorship, whether in Source or\n      Object form, made available under the License, as indicated by a\n      copyright notice that is included in or attached to the work\n      (an example is provided in the Appendix below).\n\n      \"Derivative Works\" shall mean any work, whether in Source or Object\n      
form, that is based on (or derived from) the Work and for which the\n      editorial revisions, annotations, elaborations, or other modifications\n      represent, as a whole, an original work of authorship. For the purposes\n      of this License, Derivative Works shall not include works that remain\n      separable from, or merely link (or bind by name) to the interfaces of,\n      the Work and Derivative Works thereof.\n\n      \"Contribution\" shall mean any work of authorship, including\n      the original version of the Work and any modifications or additions\n      to that Work or Derivative Works thereof, that is intentionally\n      submitted to Licensor for inclusion in the Work by the copyright owner\n      or by an individual or Legal Entity authorized to submit on behalf of\n      the copyright owner. For the purposes of this definition, \"submitted\"\n      means any form of electronic, verbal, or written communication sent\n      to the Licensor or its representatives, including but not limited to\n      communication on electronic mailing lists, source code control systems,\n      and issue tracking systems that are managed by, or on behalf of, the\n      Licensor for the purpose of discussing and improving the Work, but\n      excluding communication that is conspicuously marked or otherwise\n      designated in writing by the copyright owner as \"Not a Contribution.\"\n\n      \"Contributor\" shall mean Licensor and any individual or Legal Entity\n      on behalf of whom a Contribution has been received by Licensor and\n      subsequently incorporated within the Work.\n\n   2. Grant of Copyright License. 
Subject to the terms and conditions of\n      this License, each Contributor hereby grants to You a perpetual,\n      worldwide, non-exclusive, no-charge, royalty-free, irrevocable\n      copyright license to reproduce, prepare Derivative Works of,\n      publicly display, publicly perform, sublicense, and distribute the\n      Work and such Derivative Works in Source or Object form.\n\n   3. Grant of Patent License. Subject to the terms and conditions of\n      this License, each Contributor hereby grants to You a perpetual,\n      worldwide, non-exclusive, no-charge, royalty-free, irrevocable\n      (except as stated in this section) patent license to make, have made,\n      use, offer to sell, sell, import, and otherwise transfer the Work,\n      where such license applies only to those patent claims licensable\n      by such Contributor that are necessarily infringed by their\n      Contribution(s) alone or by combination of their Contribution(s)\n      with the Work to which such Contribution(s) was submitted. If You\n      institute patent litigation against any entity (including a\n      cross-claim or counterclaim in a lawsuit) alleging that the Work\n      or a Contribution incorporated within the Work constitutes direct\n      or contributory patent infringement, then any patent licenses\n      granted to You under this License for that Work shall terminate\n      as of the date such litigation is filed.\n\n   4. Redistribution. 
You may reproduce and distribute copies of the\n      Work or Derivative Works thereof in any medium, with or without\n      modifications, and in Source or Object form, provided that You\n      meet the following conditions:\n\n      (a) You must give any other recipients of the Work or\n          Derivative Works a copy of this License; and\n\n      (b) You must cause any modified files to carry prominent notices\n          stating that You changed the files; and\n\n      (c) You must retain, in the Source form of any Derivative Works\n          that You distribute, all copyright, patent, trademark, and\n          attribution notices from the Source form of the Work,\n          excluding those notices that do not pertain to any part of\n          the Derivative Works; and\n\n      (d) If the Work includes a \"NOTICE\" text file as part of its\n          distribution, then any Derivative Works that You distribute must\n          include a readable copy of the attribution notices contained\n          within such NOTICE file, excluding those notices that do not\n          pertain to any part of the Derivative Works, in at least one\n          of the following places: within a NOTICE text file distributed\n          as part of the Derivative Works; within the Source form or\n          documentation, if provided along with the Derivative Works; or,\n          within a display generated by the Derivative Works, if and\n          wherever such third-party notices normally appear. The contents\n          of the NOTICE file are for informational purposes only and\n          do not modify the License. 
You may add Your own attribution\n          notices within Derivative Works that You distribute, alongside\n          or as an addendum to the NOTICE text from the Work, provided\n          that such additional attribution notices cannot be construed\n          as modifying the License.\n\n      You may add Your own copyright statement to Your modifications and\n      may provide additional or different license terms and conditions\n      for use, reproduction, or distribution of Your modifications, or\n      for any such Derivative Works as a whole, provided Your use,\n      reproduction, and distribution of the Work otherwise complies with\n      the conditions stated in this License.\n\n   5. Submission of Contributions. Unless You explicitly state otherwise,\n      any Contribution intentionally submitted for inclusion in the Work\n      by You to the Licensor shall be under the terms and conditions of\n      this License, without any additional terms or conditions.\n      Notwithstanding the above, nothing herein shall supersede or modify\n      the terms of any separate license agreement you may have executed\n      with Licensor regarding such Contributions.\n\n   6. Trademarks. This License does not grant permission to use the trade\n      names, trademarks, service marks, or product names of the Licensor,\n      except as required for reasonable and customary use in describing the\n      origin of the Work and reproducing the content of the NOTICE file.\n\n   7. Disclaimer of Warranty. Unless required by applicable law or\n      agreed to in writing, Licensor provides the Work (and each\n      Contributor provides its Contributions) on an \"AS IS\" BASIS,\n      WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or\n      implied, including, without limitation, any warranties or conditions\n      of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A\n      PARTICULAR PURPOSE. 
You are solely responsible for determining the\n      appropriateness of using or redistributing the Work and assume any\n      risks associated with Your exercise of permissions under this License.\n\n   8. Limitation of Liability. In no event and under no legal theory,\n      whether in tort (including negligence), contract, or otherwise,\n      unless required by applicable law (such as deliberate and grossly\n      negligent acts) or agreed to in writing, shall any Contributor be\n      liable to You for damages, including any direct, indirect, special,\n      incidental, or consequential damages of any character arising as a\n      result of this License or out of the use or inability to use the\n      Work (including but not limited to damages for loss of goodwill,\n      work stoppage, computer failure or malfunction, or any and all\n      other commercial damages or losses), even if such Contributor\n      has been advised of the possibility of such damages.\n\n   9. Accepting Warranty or Additional Liability. While redistributing\n      the Work or Derivative Works thereof, You may choose to offer,\n      and charge a fee for, acceptance of support, warranty, indemnity,\n      or other liability obligations and/or rights consistent with this\n      License. However, in accepting such obligations, You may act only\n      on Your own behalf and on Your sole responsibility, not on behalf\n      of any other Contributor, and only if You agree to indemnify,\n      defend, and hold each Contributor harmless for any liability\n      incurred by, or claims asserted against, such Contributor by reason\n      of your accepting any such warranty or additional liability.\n\n   END OF TERMS AND CONDITIONS\n\n   APPENDIX: How to apply the Apache License to your work.\n\n      To apply the Apache License to your work, attach the following\n      boilerplate notice, with the fields enclosed by brackets \"[]\"\n      replaced with your own identifying information. 
(Don't include\n      the brackets!)  The text should be enclosed in the appropriate\n      comment syntax for the file format. We also recommend that a\n      file or class name and description of purpose be included on the\n      same \"printed page\" as the copyright notice for easier\n      identification within third-party archives.\n\n   Copyright [yyyy] [name of copyright owner]\n\n   Licensed under the Apache License, Version 2.0 (the \"License\");\n   you may not use this file except in compliance with the License.\n   You may obtain a copy of the License at\n\n       http://www.apache.org/licenses/LICENSE-2.0\n\n   Unless required by applicable law or agreed to in writing, software\n   distributed under the License is distributed on an \"AS IS\" BASIS,\n   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n   See the License for the specific language governing permissions and\n   limitations under the License.\n"
  },
  {
    "path": "README.md",
    "content": "# PaSST: Efficient Training of Audio Transformers with Patchout\n\nThis is the implementation for [Efficient Training of Audio Transformers with Patchout](https://arxiv.org/abs/2110.05069)\n\nPatchout significantly reduces the training time and GPU memory requirements to train transformers on audio spectrograms, while improving their performance.\n\n<p align=\"center\"><img src=\"https://github.com/kkoutini/PaSST/blob/main/.github/speed_mem_map.png?raw=true\" width=\"600\"/></p>\n\nPatchout works by dropping out some of the input patches during training.\n In either an unstructured way (randomly, similar to dropout),\nor entire time-frames or frequency bins of the extracted patches (similar to SpecAugment),\n which corresponds to rows/columns in step 3 of the figure below.  \n\n<p align=\"center\"><img src=\"https://github.com/kkoutini/PaSST/raw/main/.github/passt_diag.png?raw=true\" width=\"600\"/></p>\n\n## Table of contents\n\n- [Pre-trained models for Inference and embeddings extractions](#pre-trained-models-for-inference-and-embeddings-extractions)\n  - [Getting the logits from the pretrained models](#getting-the-logits-from-the-pretrained-models)\n  - [Getting a pre-trained model for fine-tuning](#getting-a-pre-trained-model-for-fine-tuning)\n- [Development environment](#development-environment)\n  - [Setting up the development experiments environment](#setting-up-the-development-experiments-environment)\n  - [Setting up using the exported conda environment](#setting-up-using-the-exported-conda-environment)\n  - [Checking the environment](#checking-the-environment)\n- [Getting started](#getting-started)\n  - [General information](#general-information)\n  - [Configuring the experiment](#configuring-the-experiment)\n- [Training on Audioset](#training-on-audioset)\n- [Examples with Pre-trained models](#examples-with-pre-trained-models)\n- [Examples fine-tuning on downstream datasets](#examples-of-fine-tuning-on-downstream-datasets)\n- 
[Citation](#citation)\n- [Contact](#contact)\n\n## Pre-trained models for Inference and embeddings extractions\n\nIf you only want to use the embeddings generated by the pretrained models, use\nyour own fine-tuning framework, or you need it only for inference, you can find a stripped down version of this repo [here](https://github.com/kkoutini/passt_hear21).\nThe package follows [HEAR 2021 NeurIPS Challenge](https://neuralaudio.ai/hear2021-results.html) API, and can be installed:\n\n```shell\npip install hear21passt\n```\n\nThis repo is a complete framework for training the models and fine-tuning pre-trained models on Audioset on downstream tasks.\n\n### Getting the logits from the pretrained models\n\n```python\nfrom hear21passt.base import get_basic_model,get_model_passt\nimport torch\n# get the PaSST model wrapper, includes Melspectrogram and the default pre-trained transformer\nmodel = get_basic_model(mode=\"logits\")\nprint(model.mel) # Extracts mel spectrogram from raw waveforms.\nprint(model.net) # the transformer network.\n\n# example inference\nmodel.eval()\nmodel = model.cuda()\nwith torch.no_grad():\n    # audio_wave has the shape of [batch, seconds*32000] sampling rate is 32k\n    # example audio_wave of batch=3 and 10 seconds\n    audio = torch.ones((3, 32000 * 10))*0.5\n    audio_wave = audio.cuda()\n    logits=model(audio_wave) \n```\n\n### Getting a pre-trained model for fine tuning\n\n```python\nfrom hear21passt.base import get_basic_model,get_model_passt\nimport torch\n# get the PaSST model wrapper, includes Melspectrogram and the default pre-trained transformer\nmodel = get_basic_model(mode=\"logits\")\nprint(model.mel) # Extracts mel spectrogram from raw waveforms.\n\n# optional replace the transformer with one that has the required number of classes i.e. 
50\nmodel.net = get_model_passt(arch=\"passt_s_swa_p16_128_ap476\",  n_classes=50)\nprint(model.net) # the transformer network.\n\n\n# now model contains mel + the transformer pre-trained model ready to be fine tuned.\n# It's still expecting input of the shape [batch, seconds*32000] sampling rate is 32k\n\nmodel.train()\nmodel = model.cuda()\n\n```\n\n## Development environment\n\nIf you want to use the same environment as in the paper, you can follow the instructions below.\n\n### Setting up the development experiments environment\n\nFor training models from scratch or fine-tuning using the same setup as in the paper:\n\n1. If needed, create a new environment with python 3.8 and activate it:\n\n```bash\nconda create -n passt python=3.8\nconda activate passt\n ```\n\n1. Install pytorch build that suits your system. For example:\n\n```bash\nconda install pytorch==1.11.0 torchvision==0.12.0 torchaudio==0.11.0 cudatoolkit=11.3 -c pytorch\n\n ```\n\n1. Install the requirements:\n\n ```bash\npip install -r requirements.txt\n ```\n\n### Setting up using the exported conda environment\n\nAlternatively, you can use the exported conda environment `environment.yml` to create the environment.\n\nFor setting up [Mamba](https://github.com/mamba-org/mamba) is recommended since it works faster than `conda`:\n\n```shell\nconda install mamba -n base -c conda-forge\n```\n\nNow you can import the environment from `environment.yml`\n\n```shell\nmamba env create -f environment.yml\n```\n\nNow you have an environment named `ba3l`.\n\n### Checking the environment\n\nIn order to check if your environment matched the environment we used in our runs, please check the `environment.yml` and `pip_list.txt` files, which were exported using:\n\n```shell\nconda env export --no-builds | grep -v \"prefix\" > environment.yml\npip list > pip_list.txt\n```\n\n## Getting started\n\nIf you want to use your setup and only use the models from this repo, you can get the models train them from scratch or 
fine-tune them on your own dataset as explained above [Pre-trained models for Inference and embeddings extractions](#pre-trained-models-for-inference-and-embeddings-extractions). The rest of this section explains using this repo for training and fine-tuning the models. For that, first you need to set up the development environment as explained above.\n\n### General information\n\nThe repo is built using [sacred](https://sacred.readthedocs.io/en/) for experiment management and configuration, pytorch-lightning for training, and wandb for logging.\n\nEach dataset has a main experiment file such as `ex_audioset.py` and `ex_openmic.py` and a dataset folder. The experiment file contains the main training and validation logic. The dataset folder contains the dataset-specific code needed to download, preprocess and load the dataset for training.\n\nIn general, you can probe the experiment file for help; this will print the available commands and basic options:\n\n```shell\npython ex_audioset.py help\n```\n\n### Configuring the experiment\n\nEach experiment has a set of default configuration options, defined in the experiment file, e.g. `ex_audioset.py`. You can override any of the configuration using the [sacred syntax](https://sacred.readthedocs.io/en/stable/command_line.html). 
You can use the `print_config` command to print the configuration values without training a model:\n\n```shell\n python ex_audioset.py print_config\n ```\n\nYou can use then use the command line interface to override any of the configuration options ([sacred syntax](https://sacred.readthedocs.io/en/stable/command_line.html)), using `with` e.g.:\n\n```shell\npython ex_audioset.py with trainer.precision=16 \n```\n\nThis will train on Audioset using 16-bit precision.\n\nThe overall configurations look like this:\n\n```yaml\n  ...\n  seed = 542198583                  # the random seed for this experiment\n  slurm_job_id = ''\n  speed_test_batch_size = 100\n  swa = True\n  swa_epoch_start = 50\n  swa_freq = 5\n  use_mixup = True\n  warm_up_len = 5\n  weight_decay = 0.0001\n  basedataset:\n    base_dir = 'audioset_hdf5s/'     # base directory of the dataset, change it or make a link\n    eval_hdf5 = 'audioset_hdf5s/mp3/eval_segments_mp3.hdf'\n    wavmix = 1\n    ....\n    roll_conf:\n      axis = 1\n      shift = None\n      shift_range = 50\n  datasets:\n    test:\n      batch_size = 20\n      dataset = {CMD!}'/basedataset.get_test_set'\n      num_workers = 16\n      validate = True\n    training:\n      batch_size = 12\n      dataset = {CMD!}'/basedataset.get_full_training_set'\n      num_workers = 16\n      sampler = {CMD!}'/basedataset.get_ft_weighted_sampler'\n      shuffle = None\n      train = True\n  models:\n    mel:\n      freqm = 48\n      timem = 192\n      hopsize = 320\n      htk = False\n      n_fft = 1024\n      n_mels = 128\n      norm = 1\n      sr = 32000\n      ...\n    net:\n      arch = 'passt_s_swa_p16_128_ap476'\n      fstride = 10\n      in_channels = 1\n      input_fdim = 128\n      input_tdim = 998\n      n_classes = 527\n      s_patchout_f = 4\n      s_patchout_t = 40\n      tstride = 10\n      u_patchout = 0\n      ...\n  trainer:\n    accelerator = None\n    accumulate_grad_batches = 1\n    amp_backend = 'native'\n    amp_level = 'O2'\n    
auto_lr_find = False\n    auto_scale_batch_size = False\n    ...\n```\n\nThere are many things that can be updated from the command line.\nIn short:\n\n- All the configuration options under `trainer` are pytorch lightning trainer [api](https://pytorch-lightning.readthedocs.io/en/1.4.1/common/trainer.html#trainer-class-api). For example, to turn off cuda benchmarking add `trainer.benchmark=False` to the command line.\n- `wandb` is the wandb configuration. For example, to change the wandb project `wandb.project=\"test_project\"` to the command line.\n- `models.net` are the PaSST (or the chosen NN) options. Examples: `models.net.u_patchout`,  `models.net.s_patchout_f` `models.net.s_patchout_t` control the unstructured patchout and structured patchout over frequency and time. `input_fdim` and `input_tdim` are the input spectrogram dimensions over frequency and time. `models.net.fstride` and `models.net.tstride` are the strides of the input patches over frequency and time, setting these to 16 means no patch overlap.\n- `models.mel` are the preprocessing options (mel spectrograms). `mel.sr` is the sampling rate, `mel.hopsize` is the hop size of the STFT window, `mel.n_mels` is the number of mel bins, `mel.freqm` and `mel.timem` are the frequency and time masking parameters of spec-augment.\n\nThere are many pre-defined configuration bundles (called named_configs) in `config_updates.py`. 
These include different models, setups etc...\nYou can list these configurations with:\n\n```shell\npython ex_audioset.py print_named_configs\n```\n\nFor example, `passt_s_20sec` is a configuration bundle that sets the model to PaSST-S pre-trained on Audioset, and accepts up to 20 second clips.\n\n## Training on Audioset\n\nDownload and prepare the dataset as explained in the [audioset page](audioset/)\n\nThe base PaSST model can be trained for example like this:\n\n```bash\npython ex_audioset.py with trainer.precision=16  models.net.arch=passt_deit_bd_p16_384 -p\n```\n\nFor example using only unstructured patchout of 400:\n\n```bash\npython ex_audioset.py with trainer.precision=16  models.net.arch=passt_deit_bd_p16_384  models.net.u_patchout=400  models.net.s_patchout_f=0 models.net.s_patchout_t=0 -p\n```\n\nMulti-gpu training can be enabled by setting the environment variable `DDP`, for example with 2 gpus:\n\n```shell\n DDP=2 python ex_audioset.py with trainer.precision=16  models.net.arch=passt_deit_bd_p16_384 -p -c \"PaSST base 2 GPU\"\n```\n\n## Examples with Pre-trained models\n\nPlease check the [releases page](https://github.com/kkoutini/PaSST/releases/), to download pre-trained models.\nIn general, you can get a pretrained model on Audioset using\n\n```python\nfrom models.passt import get_model\nmodel  = get_model(arch=\"passt_s_swa_p16_128_ap476\", pretrained=True, n_classes=527, in_channels=1,\n                   fstride=10, tstride=10,input_fdim=128, input_tdim=998,\n                   u_patchout=0, s_patchout_t=40, s_patchout_f=4)\n```\n\nThis will automatically download a PaSST model pretrained on Audioset with mAP of ```0.476```. 
The model was trained with ```s_patchout_t=40, s_patchout_f=4``` but you can change these to better fit your task / computational needs.\n\nThere are several pretrained models available with different strides (overlap) and with/without using SWA: `passt_s_p16_s16_128_ap468, passt_s_swa_p16_s16_128_ap473, passt_s_swa_p16_s14_128_ap471, passt_s_p16_s14_128_ap469, passt_s_swa_p16_s12_128_ap473, passt_s_p16_s12_128_ap470`.\nFor example, in `passt_s_swa_p16_s16_128_ap473`: `p16` means the patch size is `16x16`, `s16` means no overlap (stride=16), 128 mel bands, and `ap473` refers to the performance of this model on Audioset, mAP=0.473.\n\nIn general, you can get a pretrained model using:\n\n```python\nfrom models.passt import get_model\npasst = get_model(arch=\"passt_s_swa_p16_s16_128_ap473\", fstride=16, tstride=16)\n```\n\nUsing the framework, you can evaluate this model using:\n\n```shell\npython ex_audioset.py evaluate_only with  trainer.precision=16  passt_s_swa_p16_s16_128_ap473 -p\n```\n\nEnsembles of these models are provided as well:\nA large ensemble giving `mAP=.4956`\n\n```shell\npython ex_audioset.py evaluate_only with  trainer.precision=16 ensemble_many\n```\n\nAn ensemble of 2 models with `stride=14` and `stride=16` giving `mAP=.4858`\n\n```shell\npython ex_audioset.py evaluate_only with  trainer.precision=16 ensemble_s16_14\n```\n\nAs well as other ensembles `ensemble_4`, `ensemble_5`\n\n## Examples of fine-tuning on downstream datasets\n\n1. [ESC-50: Dataset for Environmental Sound Classification](esc50/)\n2. [OpenMIC-2018 dataset](openmic/)\n3. 
[FSD50K](fsd50k/)\n\n## Citation\n\nThe citation to the accepted paper in Interspeech 2022:\n\n```bib\n@inproceedings{koutini22passt,\n  author       = {Khaled Koutini and\n                  Jan Schl{\\\"{u}}ter and\n                  Hamid Eghbal{-}zadeh and\n                  Gerhard Widmer},\n  title        = {Efficient Training of Audio Transformers with Patchout},\n  booktitle    = {Interspeech 2022, 23rd Annual Conference of the International Speech\n                  Communication Association, Incheon, Korea, 18-22 September 2022},\n  pages        = {2753--2757},\n  publisher    = {{ISCA}},\n  year         = {2022},\n  url          = {https://doi.org/10.21437/Interspeech.2022-227},\n  doi          = {10.21437/Interspeech.2022-227},\n}\n```\n\n## Contact\n\nThe repo will be updated, in the meantime if you have any questions or problems feel free to open an issue on GitHub, or contact the authors directly.\n"
  },
  {
    "path": "__init__.py",
    "content": ""
  },
  {
    "path": "audioset/README.md",
    "content": "# Experiments on Audioset\nAudioset has around 2M segments. The total size of the dataset with wav files with `32khz` sampling rate is around 1.2 TB. In our setup, this results in a huge IO bottleneck that slows down the training process significantly.\nTherefore, we encode the dataset to mp3, pack the mp3s into HDF5 format and decode the mp3s on the fly. If you have enough cpu cores (10-16 dataloading workers) you should not notice any slowdowns.\n\nIn the `dataset.py` file, we read the samples from the hdf files, decode the mp3, apply waveform augmentations and return the raw waveform for the model.\n`AudioSetDataset` is the main class that reads from the hdf files.\n\n\n\n## Preparing the dataset\n###  Downloading Audioset\nWe used the scripts provided by [PANNS](https://github.com/qiuqiangkong/audioset_tagging_cnn) to download the dataset.\n###  Converting to mp3\nOnce the datasets are downloaded we convert all the files to mp3 using the script:\n`prepare_scripts/convert_to_mp3.py`.\n\n```bash\npython convert_to_mp3.py --source pann_download_folder --out mp3_folder\n```\n\nThis will significantly reduce the size of the dataset and overcome the IO bottleneck in our setup. The trade-off is that more cpu is needed during training to decode the mp3s. 
\nWe use the [av](https://pypi.org/project/av/) (check `decode_mp3` in `dataset.py`) library to decode the mp3 in the data loading workers; this is much faster than calling ffmpeg.\nAs a result, approximately 10 decoding threads should be enough to keep a 2080ti busy.\n\nYou can test how much time it takes to load and decode one epoch on your system:\n```bash\npython ex_audioset.py test_loaders_train_speed\n```\n\nThis step is not necessary if you have a more powerful setup and the `decode_mp3` also supports other ffmpeg codecs.\n\n###  packing to HDF5 files\n\nFinally, you need to pack the mp3 files into a single HDF5 file using `create_h5pymp3_dataset.py`.\nYou just need to set the paths in the script to match your local paths. The script goes through the csv files and checks if the corresponding mp3 file exists, then stores it in the h5py file.\nThe output of this step should be 3 files: `balanced_train_segments_mp3.hdf`, `eval_segments_mp3.hdf` and `unbalanced_train_segments_mp3.hdf`.\nFor each of these files, make sure the paths match the default config in `dataset.py`"
  },
  {
    "path": "audioset/dataset.py",
    "content": "import io\nimport os\nimport pathlib\nimport random\n\nimport av\nimport librosa\nimport torchaudio\nfrom torch.utils.data import Dataset as TorchDataset, ConcatDataset, DistributedSampler, WeightedRandomSampler\n\nimport torch\nfrom ba3l.ingredients.datasets import Dataset\nfrom sacred.config import DynamicIngredient, CMD\nfrom scipy.signal import convolve\nimport numpy as np\nfrom helpers.audiodatasets import  PreprocessDataset\nimport h5py\n\n\nLMODE = os.environ.get(\"LMODE\", False)\n#$TMPDIR\ndataset = Dataset('audiodataset')\n\n\n@dataset.config\ndef default_config():\n    name = 'audioset'  # dataset name\n    normalize = False  # normalize dataset\n    subsample = False  # subsample squares from the dataset\n    roll = True  # apply roll augmentation\n    fold = 1\n    base_dir = \"audioset_hdf5s/\"  # base directory of the dataset, change it or make a link\n    if LMODE:\n        base_dir = \"/system/user/publicdata/CP/audioset/audioset_hdf5s/\"\n\n    balanced_train_hdf5 = base_dir + \"mp3/balanced_train_segments_mp3.hdf\"\n    eval_hdf5 = base_dir + \"mp3/eval_segments_mp3.hdf\"\n    unbalanced_train_hdf5 = base_dir + \"mp3/unbalanced_train_segments_mp3.hdf\"\n    if LMODE:\n        balanced_train_hdf5 = balanced_train_hdf5.replace(base_dir, os.environ.get(\"TMPDIR\", base_dir)+\"/\")\n        unbalanced_train_hdf5 = unbalanced_train_hdf5.replace(base_dir, os.environ.get(\"TMPDIR\", base_dir)+\"/\")\n        eval_hdf5 = eval_hdf5.replace(base_dir, os.environ.get(\"TMPDIR\", base_dir)+\"/\")\n    ir_path = base_dir + \"irs/\"\n    num_of_classes = 527\n\nif LMODE:\n    @dataset.config\n    def LMODE_default_config():\n        cache_root_path = \"/system/user/publicdata/CP/DCASE/cached_datasets/\"\n\n\n\n\n\ndef decode_mp3(mp3_arr):\n    \"\"\"\n    decodes an array if uint8 representing an mp3 file\n    :rtype: np.array\n    \"\"\"\n    container = av.open(io.BytesIO(mp3_arr.tobytes()))\n    stream = next(s for s in container.streams if 
s.type == 'audio')\n    # print(stream)\n    a = []\n    for i, packet in enumerate(container.demux(stream)):\n        for frame in packet.decode():\n            a.append(frame.to_ndarray().reshape(-1))\n    waveform = np.concatenate(a)\n    if waveform.dtype != 'float32':\n        raise RuntimeError(\"Unexpected wave type\")\n    return waveform\n\n\ndef pad_or_truncate(x, audio_length):\n    \"\"\"Pad all audio to specific length.\"\"\"\n    if len(x) <= audio_length:\n        return np.concatenate((x, np.zeros(audio_length - len(x), dtype=np.float32)), axis=0)\n    else:\n        return x[0: audio_length]\n\n\nirs_arr = None\n\n\n@dataset.command\ndef get_ir_sample(ir_path, _run, ir_augment, cut_irs_offset=None):\n    if not ir_augment:\n        return\n    global irs_arr\n    if irs_arr is None:\n        all_paths = [path for path in pathlib.Path(os.path.expanduser(ir_path)).rglob('*.wav')]\n        all_paths = sorted(all_paths)\n        if cut_irs_offset is not None:\n            all_paths = all_paths[cut_irs_offset:cut_irs_offset + 10]\n        all_paths_name = [str(p).rsplit(\"/\", 1)[-1] for p in all_paths]\n        print(\"will use these IRs:\")\n        for i in range(len(all_paths_name)):\n            print(i, \": \", all_paths_name[i])\n        _run.info[\"ir_devices\"] = all_paths_name\n        irs_arr = [librosa.load(p, sr=32000)[0] for p in all_paths]\n    return irs_arr[int(np.random.randint(0, len(irs_arr)))]\n\n\n@dataset.command\ndef pydub_augment(waveform, gain_augment=7, ir_augment=0):\n    if ir_augment and torch.rand(1) < ir_augment:\n        ir = get_ir_sample()\n        waveform = convolve(waveform, ir, 'full')\n    if gain_augment:\n        gain = torch.randint(gain_augment * 2, (1,)).item() - gain_augment\n        amp = 10 ** (gain / 20)\n        waveform = waveform * amp\n    return waveform\n\n\nclass MixupDataset(TorchDataset):\n    \"\"\" Mixing Up wave forms\n    \"\"\"\n\n    def __init__(self, dataset, beta=2, rate=0.5):\n        
self.beta = beta\n        self.rate = rate\n        self.dataset = dataset\n        print(f\"Mixing up waveforms from dataset of len {len(dataset)}\")\n\n    def __getitem__(self, index):\n        if torch.rand(1) < self.rate:\n            x1, f1, y1 = self.dataset[index]\n            idx2 = torch.randint(len(self.dataset), (1,)).item()\n            x2, f2, y2 = self.dataset[idx2]\n            l = np.random.beta(self.beta, self.beta)\n            l = max(l, 1. - l)\n            x1 = x1-x1.mean()\n            x2 = x2-x2.mean()\n            x = (x1 * l + x2 * (1. - l))\n            x = x - x.mean()\n            return x, f1, (y1 * l + y2 * (1. - l))\n        return self.dataset[index]\n\n    def __len__(self):\n        return len(self.dataset)\n\n\nclass AudioSetDataset(TorchDataset):\n    def __init__(self, hdf5_file, sample_rate=32000, classes_num=527, clip_length=10, augment=False, in_mem=False):\n        \"\"\"\n        Reads the mp3 bytes from HDF file decodes using av and returns a fixed length audio wav\n        \"\"\"\n        self.sample_rate = sample_rate\n        self.hdf5_file = hdf5_file\n        if in_mem:\n            print(\"\\nPreloading in memory\\n\")\n            with open(hdf5_file, 'rb') as f:\n                self.hdf5_file = io.BytesIO(f.read())\n        with h5py.File(hdf5_file, 'r') as f:\n            self.length = len(f['audio_name'])\n            print(f\"Dataset from {hdf5_file} with length {self.length}.\")\n        self.dataset_file = None  # lazy init\n        self.clip_length = clip_length * sample_rate\n        self.classes_num = classes_num\n        self.augment = augment\n        if augment:\n            print(f\"Will agument data from {hdf5_file}\")\n\n    def open_hdf5(self):\n        self.dataset_file = h5py.File(self.hdf5_file, 'r')\n\n    def __len__(self):\n        return self.length\n\n    def __del__(self):\n        if self.dataset_file is not None:\n            self.dataset_file.close()\n            self.dataset_file = 
None\n\n    def __getitem__(self, index):\n        \"\"\"Load waveform and target of an audio clip.\n\n        Args:\n          meta: {\n            'hdf5_path': str,\n            'index_in_hdf5': int}\n        Returns:\n          data_dict: {\n            'audio_name': str,\n            'waveform': (clip_samples,),\n            'target': (classes_num,)}\n        \"\"\"\n        if self.dataset_file is None:\n            self.open_hdf5()\n\n        audio_name = self.dataset_file['audio_name'][index].decode()\n        waveform = decode_mp3(self.dataset_file['mp3'][index])\n        if self.augment:\n            waveform = pydub_augment(waveform)\n        waveform = pad_or_truncate(waveform, self.clip_length)\n        waveform = self.resample(waveform)\n        target = self.dataset_file['target'][index]\n        target = np.unpackbits(target, axis=-1,\n                               count=self.classes_num).astype(np.float32)\n        return waveform.reshape(1, -1), audio_name, target\n\n    def resample(self, waveform):\n        \"\"\"Resample.\n        Args:\n          waveform: (clip_samples,)\n        Returns:\n          (resampled_clip_samples,)\n        \"\"\"\n        if self.sample_rate == 32000:\n            return waveform\n        elif self.sample_rate == 16000:\n            return waveform[0:: 2]\n        elif self.sample_rate == 8000:\n            return waveform[0:: 4]\n        else:\n            raise Exception('Incorrect sample rate!')\n\n\n@dataset.command\ndef get_base_training_set(balanced_train_hdf5):\n    ds = AudioSetDataset(balanced_train_hdf5, augment=True)\n    return ds\n\n\n@dataset.command\ndef get_unbalanced_training_set(unbalanced_train_hdf5):\n    ds = AudioSetDataset(unbalanced_train_hdf5, augment=True)\n    return ds\n\n\n\n@dataset.command\ndef get_norms_dataset(unbalanced_train_hdf5, balanced_train_hdf5):\n    ds = ConcatDataset(\n        [AudioSetDataset(balanced_train_hdf5, augment=False), AudioSetDataset(unbalanced_train_hdf5, 
augment=False)])\n    return ds\n\n\n@dataset.command\ndef get_base_full_training_set():\n    sets = [get_base_training_set(), get_unbalanced_training_set()]\n    ds = ConcatDataset(sets)\n    return ds\n\n\n@dataset.command\ndef preload_mp3(balanced_train_hdf5, unbalanced_train_hdf5, num_of_classes):\n    for hdf5_file in [balanced_train_hdf5, unbalanced_train_hdf5]:\n        print(f\"\\n \\n will now preload {hdf5_file} \\n\\n \")\n        with h5py.File(hdf5_file, 'r') as dataset_file:\n            target = dataset_file['mp3'][:]\n            print(len(target))\n            print(f\"\\n \\n done with  {hdf5_file} \\n\\n \")\n    return target[1000]\n\n\n@dataset.command\ndef get_ft_cls_balanced_sample_weights(balanced_train_hdf5, unbalanced_train_hdf5, num_of_classes,\n                                       sample_weight_offset=100, sample_weight_sum=True):\n    \"\"\"\n    :return: float tenosr of shape len(full_training_set) representing the weights of each sample.\n    \"\"\"\n    # the order of balanced_train_hdf5,unbalanced_train_hdf5 is important.\n    # should match get_full_training_set\n    all_y = []\n    for hdf5_file in [balanced_train_hdf5, unbalanced_train_hdf5]:\n        with h5py.File(hdf5_file, 'r') as dataset_file:\n            target = dataset_file['target']\n            target = np.unpackbits(target, axis=-1,\n                                   count=num_of_classes)\n            all_y.append(target)\n    all_y = np.concatenate(all_y, axis=0)\n    all_y = torch.as_tensor(all_y)\n    per_class = all_y.long().sum(0).float().reshape(1, -1)  # frequencies per class\n\n    per_class = sample_weight_offset + per_class  # offset low freq classes\n    if sample_weight_offset > 0:\n        print(f\"Warning: sample_weight_offset={sample_weight_offset} minnow={per_class.min()}\")\n    per_class_weights = 1000. 
/ per_class\n    all_weight = all_y * per_class_weights\n    # print(all_weight.shape)\n    # print(all_weight[1510])\n    if sample_weight_sum:\n        print(\"\\nsample_weight_sum\\n\")\n        all_weight = all_weight.sum(dim=1)\n    else:\n        all_weight, _ = all_weight.max(dim=1)\n    # print(all_weight.shape)\n    # print(all_weight[1510])\n    return all_weight\n\n\n@dataset.command\ndef get_ft_weighted_sampler(samples_weights=CMD(\".get_ft_cls_balanced_sample_weights\"),\n                            epoch_len=100000, sampler_replace=False):\n    num_nodes = int(os.environ.get('num_nodes', 1))\n    ddp = int(os.environ.get('DDP', 1))\n    num_nodes = max(ddp, num_nodes)\n    print(\"num_nodes= \", num_nodes)\n    rank = int(os.environ.get('NODE_RANK', 0))\n    return DistributedSamplerWrapper(sampler=WeightedRandomSampler(samples_weights,\n                                                                   num_samples=epoch_len, replacement=sampler_replace),\n                                     dataset=range(epoch_len),\n                                     num_replicas=num_nodes,\n                                     rank=rank,\n                                     )\n\n\n@dataset.command\ndef get_base_test_set(eval_hdf5):\n    ds = AudioSetDataset(eval_hdf5)\n    return ds\n\n\n@dataset.command(prefix='roll_conf')\ndef get_roll_func(axis=1, shift=None, shift_range=50):\n    print(\"rolling...\")\n\n    def roll_func(b):\n        x, i, y = b\n        x = torch.as_tensor(x)\n        sf = shift\n        if shift is None:\n            sf = int(np.random.random_integers(-shift_range, shift_range))\n        global FirstTime\n\n        return x.roll(sf, axis), i, y\n\n    return roll_func\n\n\n@dataset.command\ndef get_training_set(normalize, roll, wavmix=False):\n    ds = get_base_training_set()\n    get_ir_sample()\n    if normalize:\n        print(\"normalized train!\")\n        fill_norms()\n        ds = PreprocessDataset(ds, norm_func)\n    if roll:\n   
     ds = PreprocessDataset(ds, get_roll_func())\n    if wavmix:\n        ds = MixupDataset(ds)\n\n    return ds\n\n\n@dataset.command\ndef get_full_training_set(normalize, roll, wavmix=False):\n    ds = get_base_full_training_set()\n    get_ir_sample()\n    if normalize:\n        print(\"normalized train!\")\n        fill_norms()\n        ds = PreprocessDataset(ds, norm_func)\n    if roll:\n        ds = PreprocessDataset(ds, get_roll_func())\n    if wavmix:\n        ds = MixupDataset(ds)\n    return ds\n\n\n\n@dataset.command\ndef get_test_set(normalize):\n    ds = get_base_test_set()\n    if normalize:\n        print(\"normalized test!\")\n        fill_norms()\n        ds = PreprocessDataset(ds, norm_func)\n    return ds\n\n\n@dataset.command\ndef print_conf(_config):\n    print(\"Config of \", dataset.path, id(dataset))\n    print(_config)\n    print()\n\n\nclass DistributedSamplerWrapper(DistributedSampler):\n    def __init__(\n            self, sampler, dataset,\n            num_replicas=None,\n            rank=None,\n            shuffle: bool = True):\n        super(DistributedSamplerWrapper, self).__init__(\n            dataset, num_replicas, rank, shuffle)\n        # source: @awaelchli https://github.com/PyTorchLightning/pytorch-lightning/issues/3238\n        self.sampler = sampler\n\n    def __iter__(self):\n        if self.sampler.generator is None:\n            self.sampler.generator = torch.Generator()\n        self.sampler.generator.manual_seed(self.seed + self.epoch)\n        indices = list(self.sampler)\n        if self.epoch == 0:\n            print(f\"\\n DistributedSamplerWrapper :  {indices[:10]} \\n\\n\")\n        indices = indices[self.rank:self.total_size:self.num_replicas]\n        return iter(indices)\n\n\nif __name__ == \"__main__\":\n    from sacred import Experiment\n\n    ex = Experiment(\"test_dataset\", ingredients=[dataset])\n\n\n    @ex.automain\n    def default_command():\n        
ex.current_run.get_command_function(\"print_config\")()\n        get_base_training_set()\n        ds = get_test_set()\n        print(ds[0])\n        ds = get_training_set()\n        print(ds[0])\n        print(\"get_base_training_set\", len(get_base_training_set()))\n        print(\"get_base_test_set\", len(get_base_test_set()))\n        print(\"get_training_set\", len(get_training_set()))\n        print(\"get_test_set\", len(get_test_set()))\n"
  },
  {
    "path": "audioset/prepare_scripts/convert_to_mp3.py",
    "content": "import argparse\nimport multiprocessing\nimport glob\nimport os\n\n# Replace this with the dataset downloaded using PANN scripts\nsource_path = \"/share/cp/datasets/full_audioset/audioset201906/audios/\"\n\n# Replace with the output directory\nout_path = \"./mp3_audio/\"\n\nall_num = 0\n\n\ndef process_folder(fol=\"balanced_train_segments\"):\n    print(\"now working on \", fol)\n    os.makedirs(out_path + fol, exist_ok=True)\n    all_files = list(glob.glob(source_path + fol + \"/*.wav\"))\n    print(f\"it has {len(all_files)}\")\n    global all_num\n    all_num = len(all_files)\n    cmds = [(i, file, out_path + fol + \"/\" + os.path.basename(file)[:-3]) for i, file in enumerate(all_files)]\n    print(cmds[0])\n    with multiprocessing.Pool(processes=20) as pool:\n        pool.starmap(process_one, cmds)\n\n\ndef process_one(i, f1, f2):\n    if i % 100 == 0:\n        print(f\"{i}/{all_num} \\t\", f1)\n    os.system(f\"ffmpeg  -hide_banner -nostats -loglevel error -n -i {f1} -codec:a mp3 -ar 32000 {f2}mp3\")\n\n\nif __name__ == '__main__':\n    parser = argparse.ArgumentParser()\n    parser.add_argument('--source', type=str, required=False, default=None,\n                        help='Path of folder containing the wave files, expected format to be as downloaded '\n                             'using the PANNs script, containing balanced_train_segments, eval_segments, '\n                             'unbalanced_train_segments folders.')\n    parser.add_argument('--out', type=str, required=False, default=None,\n                        help='Directory to save out the converted mp3s.')\n\n    args = parser.parse_args()\n\n    source_path = args.source or source_path\n    out_path = args.out or out_path\n    folders = ['balanced_train_segments', 'eval_segments'] + [\"unbalanced_train_segments/\" + x for x in sorted(\n        os.listdir(source_path + \"unbalanced_train_segments/\"))]\n\n    print(\"I will work on these folders:\")\n    print(folders)\n\n    
for fol in folders:\n        process_folder(fol)\n"
  },
  {
    "path": "audioset/prepare_scripts/create_h5pymp3_dataset.py",
    "content": "# %%\nimport h5py\nimport pandas as pd\nimport numpy as np\nimport csv\nimport os\n\n\n# %%\nbase_dir = \"../../audioset_hdf5s/\"\nbalanced_csv= base_dir+ \"metadata/balanced_train_segments.csv\"\neval_csv= base_dir+ \"metadata/eval_segments.csv\"\nmp3_path = \"../../mp3_audio/\"\n\n\n# %%\n\ndef read_metadata(csv_path, classes_num, id_to_ix):\n    \"\"\"Read metadata of AudioSet from a csv file.\n    source: https://github.com/qiuqiangkong/audioset_tagging_cnn/blob/d2f4b8c18eab44737fcc0de1248ae21eb43f6aa4/utils/utilities.py#L59\n    Args:\n      csv_path: str\n    Returns:\n      meta_dict: {'audio_name': (audios_num,), 'target': (audios_num, classes_num)}\n    \"\"\"\n\n    with open(csv_path, 'r') as fr:\n        lines = fr.readlines()\n        lines = lines[3:]   # Remove heads\n\n    audios_num = len(lines)\n    targets = np.zeros((audios_num, classes_num), dtype=np.bool)\n    audio_names = []\n \n    for n, line in enumerate(lines):\n        items = line.split(', ')\n        \"\"\"items: ['--4gqARaEJE', '0.000', '10.000', '\"/m/068hy,/m/07q6cd_,/m/0bt9lr,/m/0jbk\"\\n']\"\"\"\n\n        audio_name = 'Y{}.mp3'.format(items[0])   # Audios are started with an extra 'Y' when downloading\n        label_ids = items[3].split('\"')[1].split(',')\n\n        audio_names.append(audio_name)\n\n        # Target\n        for id in label_ids:\n            ix = id_to_ix[id]\n            targets[n, ix] = 1\n\n    meta_dict = {'audio_name': np.array(audio_names), 'target': targets}\n    return meta_dict\n\n# Load label\nwith open(base_dir+'metadata/class_labels_indices.csv', 'r') as f:\n    reader = csv.reader(f, delimiter=',')\n    lines = list(reader)\n\nlabels = []\nids = []    # Each label has a unique id such as \"/m/068hy\"\nfor i1 in range(1, len(lines)):\n    id = lines[i1][1]\n    label = lines[i1][2]\n    ids.append(id)\n    labels.append(label)\n\nclasses_num = len(labels)\n\nlb_to_ix = {label : i for i, label in enumerate(labels)}\nix_to_lb = {i : 
label for i, label in enumerate(labels)}\n\nid_to_ix = {id : i for i, id in enumerate(ids)}\nix_to_id = {i : id for i, id in enumerate(ids)}\n\n# %%\n\ndef check_available(balanced_csv,balanced_audio_path,prefix=None):\n    meta_csv = read_metadata(balanced_csv,classes_num,id_to_ix)\n    audios_num = len(meta_csv['audio_name'])\n    found=0\n    notfound=0\n    available_files=[]\n    available_targets=[]\n    if prefix is None:\n        prefix = os.path.basename(balanced_csv)[:-4]\n    for n in range(audios_num):\n        audio_path =  meta_csv['audio_name'][n]\n        #print(balanced_audio_path + f\"{prefix}/{audio_path}\")\n        if os.path.isfile(balanced_audio_path + f\"{prefix}/{audio_path}\" ):\n            found+=1\n            available_files.append(meta_csv['audio_name'][n])\n            available_targets.append(meta_csv['target'][n])\n        else:\n            notfound+=1\n    print(f\"Found {found} . not found {notfound}\")\n    return available_files,available_targets\n# %%\n\n# %%\n\n# %%\n\n\n\nfor read_file,prefix in [(balanced_csv,\"balanced_train_segments/\"), (eval_csv,\"eval_segments/\"),]:\n    print(\"now working on \",read_file,prefix)\n    #files, y = torch.load(read_file+\".pth\")\n    files, y = check_available(read_file, mp3_path)\n    y = np.packbits(y, axis=-1)\n    packed_len = y.shape[1]\n    print(files[0], \"classes: \",packed_len, y.dtype)\n    available_size = len(files)\n    f = files[0][:-3]+\"mp3\"\n    a = np.fromfile(mp3_path+prefix + \"/\"+f, dtype='uint8')\n\n    dt = h5py.vlen_dtype(np.dtype('uint8'))\n    save_file = prefix.split(\"/\")[0]\n    with h5py.File(base_dir+ \"mp3/\" + save_file+\"_mp3.hdf\", 'w') as hf:\n        audio_name = hf.create_dataset('audio_name', shape=((available_size,)), dtype='S20')\n        waveform = hf.create_dataset('mp3', shape=((available_size,)), dtype=dt)\n        target = hf.create_dataset('target', shape=((available_size, packed_len)), dtype=y.dtype)\n        for i,file in 
enumerate(files):\n            if i%1000==0:\n                print(f\"{i}/{available_size}\")\n            f = file[:-3] + \"mp3\"\n            a = np.fromfile(mp3_path + prefix  + f, dtype='uint8')\n            audio_name[i]=f\n            waveform[i] = a\n            target[i] = y[i]\n\n    print(a.shape)\n    print(\"Done!\" , prefix)\n\n\n# %%\nprint(\"working on unbalanced...\")\n\n\n\nall_x,all_y = None, None\nfor idx in  range(41):\n    print(\"working on \",idx)\n    tmp_csv = base_dir+ f\"metadata/unbalanced_partial_csvs/unbalanced_train_segments_part{idx:02}.csv\"\n    prefix = f\"unbalanced_train_segments/unbalanced_train_segments_part{idx:02}\"\n    x,y = check_available(tmp_csv,mp3_path,prefix=prefix)\n    x = np.array([f\"{idx:02}/\"+one for one in x])\n    y=np.packbits(y, axis=-1)\n    print(\"x,y\",x.shape, y.shape)\n    if all_x is None:\n        all_x = x\n        all_y = y\n    else:\n        all_x = np.concatenate((all_x,x))\n        all_y = np.concatenate((all_y,y))\n    print(f\"done {idx}! 
all x,y\",all_x.shape, all_y.shape)\n\n\n\nprint(\"now working on packing  unbalanced\")\nprefix = \"unbalanced_train_segments/unbalanced_train_segments_part\"\nfiles = all_x\ny = all_y\npacked_len = y.shape[1]\nprint(files[0], \"classes: \",packed_len, y.dtype)\navailable_size = len(files)\nf = files[0][:-3]+\"mp3\"\na = np.fromfile(mp3_path+prefix + f, dtype='uint8')\n\ndt = h5py.vlen_dtype(np.dtype('uint8'))\nsave_file = prefix.split(\"/\")[0]\nwith h5py.File(base_dir+ \"mp3/\" + save_file+\"_mp3.hdf\", 'w') as hf:\n    audio_name = hf.create_dataset('audio_name', shape=((available_size,)), dtype='S20')\n    waveform = hf.create_dataset('mp3', shape=((available_size,)), dtype=dt)\n    target = hf.create_dataset('target', shape=((available_size, packed_len)), dtype=y.dtype)\n    for i,file in enumerate(files):\n        if i%1000==0:\n            print(f\"{i}/{available_size}\")\n        f = file[:-3] + \"mp3\"\n        a = np.fromfile(mp3_path + prefix  + f, dtype='uint8')\n        audio_name[i]=f\n        waveform[i] = a\n        target[i] = y[i]\n\nprint(a.shape)\nprint(\"Done!\" , prefix)\n\n\n"
  },
  {
    "path": "ba3l/__init__.py",
    "content": "\"\"\"Package info\"\"\"\n\n__version__ = \"0.0.2\"\n__author__ = \"Koutini et al.\"\n__license__ = \"Apache-2.0\"\n__copyright__ = \"Copyright (c) 2019-, %s.\" % __author__\n__homepage__ = \"https://github.com//\"\n__docs__ = (\n    \"The researcher friendly pytorch environment. Ba3l= sacred+pytorch-lightening.\"\n)\n__author_email__ = \"first.last@jku.at\"\n"
  },
  {
    "path": "ba3l/experiment.py",
    "content": "import inspect\nfrom importlib import import_module\n\nfrom ba3l.ingredients.datasets import Datasets\nfrom ba3l.ingredients.models import Models, Model\n#from ba3l.trainer import Trainer\nfrom ba3l.util.sacred_logger import SacredLogger\nfrom sacred import Experiment as Sacred_Experiment, Ingredient\nfrom typing import Sequence, Optional, List\n\nfrom sacred.commandline_options import CLIOption\nfrom sacred.config import CMD\nfrom sacred.host_info import HostInfoGetter\nfrom sacred.utils import PathType, optional_kwargs_decorator\nfrom pytorch_lightning import loggers as pl_loggers\nfrom ba3l.util.functions import get_default_kwargs_dict\n\n\ndef ingredients_recursive_apply(ing, fn):\n    fn(ing)\n    for kid in ing.ingredients:\n        ingredients_recursive_apply(kid, fn)\n\ndef config_recursive_apply(conf, fn):\n    for k,v in conf.items():\n        if isinstance(v, dict):\n            config_recursive_apply(v,fn)\n        else:\n            fn(k,v)\n\n\ndef get_loggers(use_tensorboard_logger=False, use_sacred_logger=False):\n    loggers = []\n    if use_sacred_logger:\n        loggers.append( SacredLogger(expr))\n    if use_tensorboard_logger:\n        loggers.append(pl_loggers.TensorBoardLogger(sacred_logger.name))\n    \n    return loggers\n\n\n\nclass Experiment(Sacred_Experiment):\n    \"\"\"\n    Main Ba3l Experiment class overrides sacred experiments.\n    \"\"\"\n\n    def __init__(\n        self,\n        name: Optional[str] = None,\n        ingredients: Sequence[Ingredient] = (),\n        datasets: Optional[Ingredient] = None,\n        models: Optional[Ingredient] = None,\n        interactive: bool = False,\n        base_dir: Optional[PathType] = None,\n        additional_host_info: Optional[List[HostInfoGetter]] = None,\n        additional_cli_options: Optional[Sequence[CLIOption]] = None,\n        save_git_info: bool = True,\n    ):\n        \"\"\"\n        Create a new experiment with the given name and optional ingredients. 
(from Sacred)\n\n\n        Parameters\n        ----------\n        name\n            Optional name of this experiment, defaults to the filename.\n            (Required in interactive mode)\n\n        ingredients : list[sacred.Ingredient], optional\n            A list of ingredients to be used with this experiment.\n\n        interactive\n            If set to True will allow the experiment to be run in interactive\n            mode (e.g. IPython or Jupyter notebooks).\n            However, this mode is discouraged since it won't allow storing the\n            source-code or reliable reproduction of the runs.\n\n        base_dir\n            Optional full path to the base directory of this experiment. This\n            will set the scope for automatic source file discovery.\n\n        additional_host_info\n            Optional dictionary containing as keys the names of the pieces of\n            host info you want to collect, and as\n            values the functions collecting those pieces of information.\n\n        save_git_info:\n            Optionally save the git commit hash and the git state\n            (clean or dirty) for all source files. 
This requires the GitPython\n            package.\n        \"\"\"\n        if models is None:\n            models = Models.get_instance()\n        self.models = models\n        if datasets is None:\n            datasets = Datasets.get_instance()\n        self.datasets = datasets\n        if ingredients is None:\n            ingredients = []\n        ingredients = list(ingredients) + [models, datasets]\n        caller_globals = inspect.stack()[1][0].f_globals\n        self.last_default_configuration_position = 0\n        super().__init__(\n            name=name,\n            ingredients=ingredients,\n            interactive=interactive,\n            base_dir=base_dir,\n            additional_host_info=additional_host_info,\n            additional_cli_options=additional_cli_options,\n            save_git_info=save_git_info,\n            caller_globals=caller_globals\n        )\n\n\n    def get_run_identifier(self):\n        return str(self.current_run.db_identifier) \\\n            + \"_\" + str(self.current_run._id)\n\n\n    def get_dataloaders(self, filter={}):\n        results = {}\n        for ds in self.datasets.get_datasets(filter):\n            results[ds.name] = ds.get_iterator()\n        if len(results) == 1:\n            for k, v in results.items():\n                return v\n        return results\n\n    def get_train_dataloaders(self):\n        return self.get_dataloaders(dict(train=True))\n\n    def get_val_dataloaders(self):\n        return self.get_dataloaders(dict(validate=True))\n\n    def _create_run(\n        self,\n        command_name=None,\n        config_updates=None,\n        named_configs=(),\n        info=None,\n        meta_info=None,\n        options=None,\n        dry_run=False,\n    ):\n        if self.current_run is not None:\n            # @todo replace with logger\n            print(\"Warning: multiple runs are not yet supported\")\n\n\n        run = super()._create_run(\n            command_name,\n            config_updates,\n        
    named_configs,\n            info,\n            meta_info,\n            options,\n            dry_run=False,\n        )\n        # self.current_run=run\n        # def update_current_run(ing):\n        #     ing.current_run = run\n        #\n        # ingredients_recursive_apply(self, update_current_run)\n\n        return run\n\n    @optional_kwargs_decorator\n    def command(\n            self, function=None, prefix=None, unobserved=False, add_default_args_config=True, static_args={},\n            **extra_args\n    ):\n        \"\"\"\n        Decorator to define a new Command.\n\n        a command is a function whose parameters are filled automatically by sacred.\n\n        The command can be given a prefix, to restrict its configuration space\n        to a subtree. (see ``capture`` for more information)\n\n        A command can be made unobserved (i.e. ignoring all observers) by\n        passing the unobserved=True keyword argument.\n        :param add_default_args_config: wether to add the default arguments of the function to the config automatically.\n        :param function: the function to return a Dataset Object\n        :param prefix: sacred configuration prefix\n        :param unobserved: sacred unobserved\n        :param static_args: static Args to be passed to the function, these arg need not to be serlizable and\n         are not stored in the config\n        :param extra_args: explicit arguments to be add to the config, you can these to override the function default\n        values, for example wraping a config with CMD, then the parameter will be filled with excuting the command\n        specified by CMD string value. 
CMD string have special context\n        :return:\n\n\n        \"\"\"\n        add_default_args_config = (not unobserved) and add_default_args_config\n        if add_default_args_config:\n            self.add_default_args_config(function, prefix, extra_args, static_args=static_args)\n        captured_f = self.capture(function, prefix=prefix, static_args=static_args)\n        captured_f.unobserved = unobserved\n        self.commands[function.__name__] = captured_f\n        return captured_f\n\n\n    def add_default_args_config(self, function, prefix, extra_args={}, static_args={}):\n        \"\"\"\n        adds the default parameters of a function to the ingredient config at lowest priority!\n        Default args config is meant remove the need to declare all the configurations manually.\n        :param f: the function\n        \"\"\"\n        # @todo get the doc of the params as well\n        config_candidate = {**get_default_kwargs_dict(function), **extra_args}\n        # remove \"static_args\" from config\n        for k in static_args:\n            config_candidate.pop(k, None)\n        # respect the prefix for the added default parameters\n        if prefix is not None:\n            for pr in prefix.split('.')[::-1]:\n                config_candidate={pr: config_candidate}\n\n        self.configurations.insert(self.last_default_configuration_position, self._create_config_dict(config_candidate, None))\n        self.last_default_configuration_position += 1\n"
  },
  {
    "path": "ba3l/ingredients/__init__.py",
    "content": ""
  },
  {
    "path": "ba3l/ingredients/datasets.py",
    "content": "import inspect\nimport os\nfrom functools import partial\n\nfrom ba3l.util.functions import get_default_kwargs_dict\nfrom sacred.config import CMD\nfrom .ingredient import Ingredient\n\nfrom typing import Sequence, Optional, List\n\nfrom sacred.utils import PathType, optional_kwargs_decorator\nfrom munch import DefaultFactoryMunch, Munch\n\ndef raise_(ex):\n    raise ex\n\n\nclass Dataset(Ingredient):\n    \"\"\"\n    The class that annotates a Dateset of Ba3l experiment\n    a Dataset can be\n\n\n    \"\"\"\n\n    DATASET_STRING_PREFIX = \"get_dataset\"\n    ITER_STRING_PREFIX = \"get_iterator\"\n\n    def __init__(\n        self,\n        name: str,\n        ingredients: Sequence[Ingredient] = (),\n        interactive: bool = False,\n        base_dir: Optional[PathType] = None,\n        save_git_info: bool = True,\n    ):\n        \"\"\"\n        Create a new experiment with the given name and optional ingredients.\n\n        Parameters\n        ----------\n        name\n            Optional name of this experiment, defaults to the filename.\n            (Required in interactive mode)\n\n        ingredients : list[sacred.Ingredient], optional\n            A list of ingredients to be used with this experiment.\n\n        interactive\n            If set to True will allow the experiment to be run in interactive\n            mode (e.g. IPython or Jupyter notebooks).\n            However, this mode is discouraged since it won't allow storing the\n            source-code or reliable reproduction of the runs.\n\n        base_dir\n            Optional full path to the base directory of this experiment. 
This\n            will set the scope for automatic source file discovery.\n\n        additional_host_info\n            Optional dictionary containing as keys the names of the pieces of\n            host info you want to collect, and as\n            values the functions collecting those pieces of information.\n\n        save_git_info:\n            Optionally save the git commit hash and the git state\n            (clean or dirty) for all source files. This requires the GitPython\n            package.\n        \"\"\"\n\n        caller_globals = inspect.stack()[1][0].f_globals\n        if name is None:\n            name = \"dataset\"\n        self.name = name.rsplit(\".\", 1)[-1]\n        super().__init__(\n            path=name,\n            ingredients=ingredients,\n            interactive=interactive,\n            base_dir=base_dir,\n            _caller_globals=caller_globals,\n            save_git_info=save_git_info,\n        )\n\n        self.get_dataset_command = None\n        self.get_dataset_iterator_command = None\n        self.current_run = None\n        self.get_dataset = lambda: raise_(\n            NotImplementedError(\n                \"Use dataset.dataset_name.dataset to annotate the  \"\n                \"get_dataset function!.\"\n            )\n        )\n        self.get_iter = lambda: raise_(\n            NotImplementedError(\n                \"Use dataset.dataset_name.iter to annotate the  \" \"get_iter function!.\"\n            )\n        )\n\n    @optional_kwargs_decorator\n    def dataset(\n        self, function=None, prefix=None, unobserved=False, static_args={}, **extra_args\n    ):\n        \"\"\"\n        Decorator to define a new Dataset.\n\n        The name of the dataset is used to get an instance of the dataset, it will register a command\n\n        Datasets are sacred commands.\n\n        The command can be given a prefix, to restrict its configuration space\n        to a subtree. 
(see ``capture`` for more information)\n\n        A command can be made unobserved (i.e. ignoring all observers) by\n        passing the unobserved=True keyword argument.\n        :param function: the function to return a Dataset Object\n        :param prefix: sacred configuration prefix\n        :param unobserved: sacred unobserved\n        :param static_args: static Args to be passed to the function, these arg need not to be serlizable and\n         are not stored in the config\n        :param extra_args: explicit arguments to be add to the config, you can these to override the function default\n        values, for example wraping a config with CMD, then the parameter will be filled with excuting the command\n        specified by CMD string value. CMD string have special context\n        :return:\n\n\n        \"\"\"\n        self.add_default_args_config(function, prefix, extra_args, static_args=static_args)\n        captured_f = self.capture(function, prefix=prefix, static_args=static_args)\n        captured_f.unobserved = unobserved\n        self.commands[Dataset.DATASET_STRING_PREFIX] = captured_f\n        self.get_dataset = captured_f\n        self.add_config(dataset=CMD(\"get_dataset\"))\n        return captured_f\n\n    @optional_kwargs_decorator\n    def iter(\n        self, function=None, prefix=None, unobserved=False, static_args={}, **extra_args\n    ):\n        \"\"\"\n        Decorator to define a new Iterator.\n\n        The name of the iterator is used to get an instance of the iterator, it will register a command\n\n        iterator are sacred commands.\n\n        The command can be given a prefix, to restrict its configuration space\n        to a subtree. (see ``capture`` for more information)\n\n        A command can be made unobserved (i.e. 
ignoring all observers) by\n        passing the unobserved=True keyword argument.\n                :param function: the function to return a Dataset Object\n        :param prefix: sacred configuration prefix\n        :param unobserved: sacred unobserved\n        :param static_args: static Args to be passed to the function, these arg need not to be serlizable and\n         are not stored in the config\n        :param extra_args: explicit arguments to be add to the config, you can these to override the function default\n        values, for example wraping a config with CMD, then the parameter will be filled with excuting the command\n        specified by CMD string value. CMD string have special context\n        \"\"\"\n        self.add_default_args_config(function, prefix, extra_args, static_args=static_args)\n\n        captured_f = self.capture(function, prefix=prefix, static_args=static_args)\n        captured_f.unobserved = unobserved\n        self.commands[Dataset.ITER_STRING_PREFIX] = captured_f\n        self.get_iter = captured_f\n        return captured_f\n\n    # def get_dataset(self):\n    #     assert self.current_run is not None, \"Can only be called during a run.\"\n    #     return self.commands[Datasets.DATASET_STRING_PREFIX + name]()\n    #     # return self.current_run.get_command_function(\n    #     #     self.path + \".\" + Datasets.DATASET_STRING_PREFIX + name)()\n    #     #\n\n    def __getattr__(self, k):\n        if k == \"iterator\":\n            return self.__getattribute__(\"iter\")\n        if k == \"get_iterator\":\n            return self.__getattribute__(\"get_iter\")\n        super().__getattribute__(k)\n        # @todo maybe run commands from here after running\n\n\nclass Datasets(Ingredient, Munch):\n    \"\"\"\n    The class that encapsulates all the datasets in an experiment\n\n\n    \"\"\"\n\n    __instance = None\n\n    @classmethod\n    def get_instance(cls):\n        if Datasets.__instance is None:\n            
Datasets.__instance = Datasets()\n        return Datasets.__instance\n\n    def __init__(\n        self,\n        name: Optional[str] = None,\n        ingredients: Sequence[Ingredient] = (),\n        interactive: bool = False,\n        base_dir: Optional[PathType] = None,\n        save_git_info: bool = True,\n    ):\n        \"\"\"\n        Create a new experiment with the given name and optional ingredients.\n\n        Parameters\n        ----------\n        name\n            Optional name of this experiment, defaults to the filename.\n            (Required in interactive mode)\n\n        ingredients : list[sacred.Ingredient], optional\n            A list of ingredients to be used with this experiment.\n\n        interactive\n            If set to True will allow the experiment to be run in interactive\n            mode (e.g. IPython or Jupyter notebooks).\n            However, this mode is discouraged since it won't allow storing the\n            source-code or reliable reproduction of the runs.\n\n        base_dir\n            Optional full path to the base directory of this experiment. This\n            will set the scope for automatic source file discovery.\n\n        additional_host_info\n            Optional dictionary containing as keys the names of the pieces of\n            host info you want to collect, and as\n            values the functions collecting those pieces of information.\n\n        save_git_info:\n            Optionally save the git commit hash and the git state\n            (clean or dirty) for all source files. 
This requires the GitPython\n            package.\n        \"\"\"\n\n        caller_globals = inspect.stack()[1][0].f_globals\n        if name is None:\n            name = \"datasets\"\n\n        Ingredient.__init__(\n            self,\n            path=name,\n            ingredients=ingredients,\n            interactive=interactive,\n            base_dir=base_dir,\n            _caller_globals=caller_globals,\n            save_git_info=save_git_info,\n        )\n\n        self.get_datasets_list_command = None\n        self.get_dataset_command = None\n        self.get_dataset_iterator_command = None\n        self.current_run = None\n        self.get_dataset = None\n\n        # self.command(get_dataset_iterator_command, unobserved=True)\n\n    def __getattr__(self, k):\n        \"\"\" Gets key if it exists, otherwise returns the default value.\"\"\"\n        try:\n            return Munch.__getattr__(self, k)\n        except AttributeError:\n            return self.__getitem__(k)\n\n    def __setattr__(self, k, v):\n        try:\n            # Throws exception if not in prototype chain\n            object.__getattribute__(self, k)\n        except AttributeError:\n            try:\n                self[k] = v\n            except:\n                raise AttributeError(k)\n        else:\n            object.__setattr__(self, k, v)\n\n    def __getitem__(self, k):\n        \"\"\" Gets key if it exists, otherwise returns the default value.\"\"\"\n        try:\n            return Munch.__getitem__(self, k)\n        except KeyError:\n            self[k] = Dataset(\n                self.path + \".\" + k,\n                base_dir=self.base_dir,\n                save_git_info=self.save_git_info,\n            )\n            assert self\n            self.ingredients.append(self[k])\n            return self[k]\n\n    def __hash__(self):\n        return Ingredient.__hash__(self)\n\n    def get_datasets(self, config_conditions={}, return_datasets_names=False):\n        \"\"\" 
Return all the datasets whose configuration matches config_conditions.\"\"\"\n        results = []\n        for dataset in self.ingredients:\n            all_ok = True\n            for cond_k, cond_v in config_conditions.items():\n                if (\n                    self.current_run.get_config_path_value(dataset.path + \".\" + cond_k)\n                    != cond_v\n                ):\n                    all_ok = False\n                    break\n            if all_ok:\n                if return_datasets_names:\n                    results.append((dataset))\n                results.append(dataset)\n        return results\n"
  },
  {
    "path": "ba3l/ingredients/ingredient.py",
    "content": "import inspect\nimport os\nfrom functools import partial\n\nfrom ba3l.util.functions import get_default_kwargs_dict\nfrom sacred import Ingredient as sacred_Ingredient\n\nfrom typing import Sequence, Optional, List\n\nfrom sacred.utils import PathType, optional_kwargs_decorator\nfrom munch import DefaultFactoryMunch, Munch\n\n\ndef raise_(ex):\n    raise ex\n\n\nclass Ingredient(sacred_Ingredient):\n    \"\"\"\n    The class that annotates a Dateset of Ba3l experiment\n    a Dataset can be\n\n\n    \"\"\"\n\n    def __init__(\n        self,\n        path: str,\n        ingredients: Sequence[sacred_Ingredient] = (),\n        interactive: bool = False,\n        _caller_globals: Optional[dict] = None,\n        base_dir: Optional[PathType] = None,\n        save_git_info: bool = True,\n    ):\n        \"\"\"\n        The Base Ingredient of all Ba3l ingredients\n\n         Parameters\n         ----------\n         path\n             Optional name of this experiment, defaults to the filename.\n             (Required in interactive mode)\n\n         ingredients : list[sacred.Ingredient], optional\n             A list of ingredients to be used with this experiment.\n\n         interactive\n             If set to True will allow the experiment to be run in interactive\n             mode (e.g. IPython or Jupyter notebooks).\n             However, this mode is discouraged since it won't allow storing the\n             source-code or reliable reproduction of the runs.\n\n         base_dir\n             Optional full path to the base directory of this experiment. 
This\n             will set the scope for automatic source file discovery.\n\n         additional_host_info\n             Optional dictionary containing as keys the names of the pieces of\n             host info you want to collect, and as\n             values the functions collecting those pieces of information.\n\n         save_git_info:\n             Optionally save the git commit hash and the git state\n             (clean or dirty) for all source files. This requires the GitPython\n             package.\n        \"\"\"\n\n        _caller_globals = _caller_globals or inspect.stack()[1][0].f_globals\n        if path is None:\n            path = \"Ingredient\"\n\n        super().__init__(\n            path=path,\n            ingredients=ingredients,\n            interactive=interactive,\n            base_dir=base_dir,\n            _caller_globals=_caller_globals,\n            save_git_info=save_git_info,\n        )\n\n        self.current_run = None\n        self.last_default_configuration_position = 0\n\n    def add_default_args_config(self, function, prefix, extra_args={}, static_args={}):\n        \"\"\"\n        adds the default parameters of a function to the ingredient config at lowest priority!\n        Default args config is meant remove the need to declare all the configurations manually.\n        :param f: the function\n        \"\"\"\n        # @todo get the doc of the params as well\n        config_candidate = {**get_default_kwargs_dict(function), **extra_args}\n        # remove \"static_args\" from config\n        for k in static_args:\n            config_candidate.pop(k, None)\n        if prefix is not None:\n            for pr in prefix.split('.')[::-1]:\n                config_candidate={pr: config_candidate}\n        self.configurations.insert(self.last_default_configuration_position, self._create_config_dict(config_candidate, None))\n        self.last_default_configuration_position += 1\n\n    @optional_kwargs_decorator\n    def command(\n       
 self, function=None, prefix=None, unobserved=False, add_default_args_config=True, static_args={}, **extra_args\n    ):\n        \"\"\"\n        Decorator to define a new Command.\n\n        a command is a function whose parameters are filled automatically by sacred.\n\n        The command can be given a prefix, to restrict its configuration space\n        to a subtree. (see ``capture`` for more information)\n\n        A command can be made unobserved (i.e. ignoring all observers) by\n        passing the unobserved=True keyword argument.\n        :param function: the function to return a Dataset Object\n        :param prefix: sacred configuration prefix\n        :param unobserved: sacred unobserved\n        :param static_args: static Args to be passed to the function, these arg need not to be serlizable and\n         are not stored in the config\n        :param extra_args: explicit arguments to be add to the config, you can these to override the function default\n        values, for example wraping a config with CMD, then the parameter will be filled with excuting the command\n        specified by CMD string value. CMD string have special context\n        :return:\n\n\n        \"\"\"\n        if add_default_args_config:\n            self.add_default_args_config(function, prefix, extra_args, static_args=static_args)\n        captured_f = self.capture(function, prefix=prefix, static_args=static_args)\n        captured_f.unobserved = unobserved\n        self.commands[function.__name__] = captured_f\n        return captured_f\n"
  },
  {
    "path": "ba3l/ingredients/models.py",
    "content": "import inspect\nimport os\n\nfrom .ingredient import Ingredient\n\nfrom typing import Sequence, Optional, List\n\nfrom sacred.utils import PathType\n\nimport inspect\nimport os\nfrom functools import partial\n\nfrom ba3l.util.functions import get_default_kwargs_dict\nfrom sacred.config import CMD\nfrom .ingredient import Ingredient\n\nfrom typing import Sequence, Optional, List\n\nfrom sacred.utils import PathType, optional_kwargs_decorator\nfrom munch import DefaultFactoryMunch, Munch\n\n\nclass Models(Ingredient):\n    \"\"\"\n    The class that annotates the models of Ba3l experiment\n\n\n    \"\"\"\n\n    __instance = None\n\n    @classmethod\n    def get_instance(cls):\n        if Models.__instance is None:\n            Models.__instance = Models()\n        return Models.__instance\n\n    def __init__(\n        self,\n        name: Optional[str] = None,\n        ingredients: Sequence[Ingredient] = (),\n        interactive: bool = False,\n        base_dir: Optional[PathType] = None,\n        save_git_info: bool = True,\n    ):\n        \"\"\"\n        Create a new experiment with the given name and optional ingredients.\n\n        Parameters\n        ----------\n        name\n            Optional name of this experiment, defaults to the filename.\n            (Required in interactive mode)\n\n        ingredients : list[sacred.Ingredient], optional\n            A list of ingredients to be used with this experiment.\n\n        interactive\n            If set to True will allow the experiment to be run in interactive\n            mode (e.g. IPython or Jupyter notebooks).\n            However, this mode is discouraged since it won't allow storing the\n            source-code or reliable reproduction of the runs.\n\n        base_dir\n            Optional full path to the base directory of this experiment. 
This\n            will set the scope for automatic source file discovery.\n\n        additional_host_info\n            Optional dictionary containing as keys the names of the pieces of\n            host info you want to collect, and as\n            values the functions collecting those pieces of information.\n\n        save_git_info:\n            Optionally save the git commit hash and the git state\n            (clean or dirty) for all source files. This requires the GitPython\n            package.\n        \"\"\"\n\n        caller_globals = inspect.stack()[1][0].f_globals\n        if name is None:\n            name = \"models\"\n\n        super().__init__(\n            path=name,\n            ingredients=ingredients,\n            interactive=interactive,\n            base_dir=base_dir,\n            _caller_globals=caller_globals,\n            save_git_info=save_git_info,\n        )\n\n        self.get_models_command = None\n        self.current_run = None\n        # self.command(print_config, unobserved=True)\n\n\ndef raise_(ex):\n    raise ex\n\n\nclass Model(Ingredient):\n    \"\"\"\n    The class that annotates a Dateset of Ba3l experiment\n    a Dataset can be\n\n\n    \"\"\"\n\n    MODEL_STRING_PREFIX = \"get_instance\"\n\n    def __init__(\n        self,\n        name: str,\n        ingredients: Sequence[Ingredient] = (),\n        interactive: bool = False,\n        base_dir: Optional[PathType] = None,\n        save_git_info: bool = True,\n    ):\n        \"\"\"\n        Create a new experiment with the given name and optional ingredients.\n\n        Parameters\n        ----------\n        name\n            Optional name of this experiment, defaults to the filename.\n            (Required in interactive mode)\n\n        ingredients : list[sacred.Ingredient], optional\n            A list of ingredients to be used with this experiment.\n\n        interactive\n            If set to True will allow the experiment to be run in interactive\n            mode (e.g. 
IPython or Jupyter notebooks).\n            However, this mode is discouraged since it won't allow storing the\n            source-code or reliable reproduction of the runs.\n\n        base_dir\n            Optional full path to the base directory of this experiment. This\n            will set the scope for automatic source file discovery.\n\n        additional_host_info\n            Optional dictionary containing as keys the names of the pieces of\n            host info you want to collect, and as\n            values the functions collecting those pieces of information.\n\n        save_git_info:\n            Optionally save the git commit hash and the git state\n            (clean or dirty) for all source files. This requires the GitPython\n            package.\n        \"\"\"\n\n        caller_globals = inspect.stack()[1][0].f_globals\n        if name is None:\n            name = \"model\"\n        self.name = name.rsplit(\".\", 1)[-1]\n        super().__init__(\n            path=name,\n            ingredients=ingredients,\n            interactive=interactive,\n            base_dir=base_dir,\n            _caller_globals=caller_globals,\n            save_git_info=save_git_info,\n        )\n\n        self.get_instance_command = None\n        self.current_run = None\n        self.get_instance = lambda: raise_(\n            NotImplementedError(\n                \"Use dataset.dataset_name.dataset to annotate the  \"\n                \"get_dataset function!.\"\n            )\n        )\n\n    @optional_kwargs_decorator\n    def instance(\n        self, function=None, prefix=None, unobserved=False, static_args={}, **extra_args\n    ):\n        \"\"\"\n        Decorator to define a new model.\n\n        The name of the model is used to get an instance of the model, it will register a command\n\n\n        The command can be given a prefix, to restrict its configuration space\n        to a subtree. 
(see ``capture`` for more information)\n\n        A command can be made unobserved (i.e. ignoring all observers) by\n        passing the unobserved=True keyword argument.\n        :param function: the function to return a Dataset Object\n        :param prefix: sacred configuration prefix\n        :param unobserved: sacred unobserved\n        :param static_args: static Args to be passed to the function, these arg need not to be serlizable and\n         are not stored in the config\n        :param extra_args: explicit arguments to be add to the config, you can these to override the function default\n        values, for example wraping a config with CMD, then the parameter will be filled with excuting the command\n        specified by CMD string value. CMD string have special context\n        :return:\n\n\n        \"\"\"\n        self.add_default_args_config(function, prefix, extra_args, static_args=static_args)\n        captured_f = self.capture(function, prefix=prefix, static_args=static_args)\n        captured_f.unobserved = unobserved\n        self.commands[Model.MODEL_STRING_PREFIX] = captured_f\n        self.get_instance = captured_f\n        self.add_config(get_instance=CMD(\"get_instance\"))\n        return captured_f\n\n    def __getattr__(self, k):\n        if k == \"get_instance\":\n            return self.__getattribute__(\"get_instance\")\n        super().__getattribute__(k)\n        # @todo maybe run commands from here after running\n"
  },
  {
    "path": "ba3l/ingredients/trainer.py",
    "content": "import inspect\nimport os\n\nfrom ba3l.util.functions import get_default_kwargs_dict\nfrom .datasets import Datasets, raise_\nfrom .models import Models\nfrom sacred import Ingredient\n\nfrom typing import Sequence, Optional, List\n\nfrom sacred.utils import PathType, optional_kwargs_decorator\n\n\nclass Trainer(Ingredient):\n    \"\"\"\n    The class that annotates the main Trainer of Ba3l experiment\n\n\n    \"\"\"\n\n    TRAINER_STRING_PREFIX = \"get_trainer\"\n"
  },
  {
    "path": "ba3l/module.py",
    "content": "import pytorch_lightning as pl\n\nimport warnings\nfrom abc import ABC\n\nimport torch.distributed as dist\nfrom munch import DefaultMunch\n\ntry:\n    # loading for pyTorch 1.3\n    from torch.utils.data import IterableDataset\nexcept ImportError:\n    # loading for pyTorch 1.1\n    import torch\n\n    warnings.warn(\n        \"Your version of pyTorch %s does not support `IterableDataset`,\"\n        \" please upgrade to 1.2+\" % torch.__version__,\n        ImportWarning,\n    )\n    EXIST_ITER_DATASET = False\nelse:\n    EXIST_ITER_DATASET = True\n\ntry:\n    from apex import amp\n\n    APEX_AVAILABLE = True\nexcept ImportError:\n    APEX_AVAILABLE = False\n\n\nclass Ba3lModule(pl.LightningModule):\n    def __init__(self, experiment):\n        super(Ba3lModule, self).__init__()\n        self.experiment = experiment\n        self.run =  experiment.current_run\n        self.config = DefaultMunch.fromDict(experiment.current_run.config)\n        for key,model in experiment.current_run.config['models'].items():\n            setattr(self, key, experiment.current_run.get_command_function(\"models.\"+key+\".\"+model['instance_cmd'])())\n        self.save_hyperparameters(self.config)\n        \n\n"
  },
  {
    "path": "ba3l/plutils/__init__.py",
    "content": ""
  },
  {
    "path": "ba3l/plutils/lr_monitor.py",
    "content": "# Copyright The PyTorch Lightning team.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\nr\"\"\"\n\nLearning Rate Monitor\n=====================\n\nMonitor and logs learning rate for lr schedulers during training.\n\n\"\"\"\n\nfrom typing import Dict, List, Optional\n\nfrom pytorch_lightning.callbacks.base import Callback\nfrom pytorch_lightning.utilities import rank_zero_warn\nfrom pytorch_lightning.utilities.exceptions import MisconfigurationException\n\n\nclass LearningRateMonitor(Callback):\n    r\"\"\"\n    Automatically monitor and logs learning rate for learning rate schedulers during training.\n\n    Args:\n        logging_interval: set to ``'epoch'`` or ``'step'`` to log ``lr`` of all optimizers\n            at the same interval, set to ``None`` to log at individual interval\n            according to the ``interval`` key of each scheduler. Defaults to ``None``.\n        log_momentum: option to also log the momentum values of the optimizer, if the optimizer\n            has the ``momentum`` or ``betas`` attribute. 
Defaults to ``False``.\n\n    Raises:\n        MisconfigurationException:\n            If ``logging_interval`` is none of ``\"step\"``, ``\"epoch\"``, or ``None``.\n\n    Example::\n\n        >>> from pytorch_lightning import Trainer\n        >>> from pytorch_lightning.callbacks import LearningRateMonitor\n        >>> lr_monitor = LearningRateMonitor(logging_interval='step')\n        >>> trainer = Trainer(callbacks=[lr_monitor])\n\n    Logging names are automatically determined based on optimizer class name.\n    In case of multiple optimizers of same type, they will be named ``Adam``,\n    ``Adam-1`` etc. If a optimizer has multiple parameter groups they will\n    be named ``Adam/pg1``, ``Adam/pg2`` etc. To control naming, pass in a\n    ``name`` keyword in the construction of the learning rate schdulers\n\n    Example::\n\n        def configure_optimizer(self):\n            optimizer = torch.optim.Adam(...)\n            lr_scheduler = {\n                'scheduler': torch.optim.lr_scheduler.LambdaLR(optimizer, ...)\n                'name': 'my_logging_name'\n            }\n            return [optimizer], [lr_scheduler]\n\n    \"\"\"\n\n    def __init__(self, logging_interval: Optional[str] = None, log_momentum: bool = False):\n        if logging_interval not in (None, 'step', 'epoch'):\n            raise MisconfigurationException('logging_interval should be `step` or `epoch` or `None`.')\n\n        self.logging_interval = logging_interval\n        self.log_momentum = log_momentum\n        self.lrs = None\n        self.lr_sch_names = []\n\n    def on_train_start(self, trainer, *args, **kwargs):\n        \"\"\"\n        Called before training, determines unique names for all lr\n        schedulers in the case of multiple of the same type or in\n        the case of multiple parameter groups\n\n        Raises:\n            MisconfigurationException:\n                If ``Trainer`` has no ``logger``.\n        \"\"\"\n        if not trainer.logger:\n            raise 
MisconfigurationException(\n                'Cannot use `LearningRateMonitor` callback with `Trainer` that has no logger.'\n            )\n\n        if not trainer.lr_schedulers:\n            rank_zero_warn(\n                'You are using `LearningRateMonitor` callback with models that'\n                ' have no learning rate schedulers. Please see documentation'\n                ' for `configure_optimizers` method.', RuntimeWarning\n            )\n\n        if self.log_momentum:\n\n            def _check_no_key(key):\n                return any(key not in sch['scheduler'].optimizer.defaults for sch in trainer.lr_schedulers)\n\n            if _check_no_key('momentum') and _check_no_key('betas'):\n                rank_zero_warn(\n                    \"You have set log_momentum=True, but some optimizers do not\"\n                    \" have momentum. This will log a value 0 for the momentum.\", RuntimeWarning\n                )\n\n        # Find names for schedulers\n        names = self._find_names(trainer.lr_schedulers)\n\n        # Initialize for storing values\n        self.lrs = {name: [] for name in names}\n        self.last_momentum_values = {name + \"-momentum\": None for name in names}\n\n    def on_train_batch_start(self, trainer, *args, **kwargs):\n        if not self._should_log(trainer):\n            return\n\n        if self.logging_interval != 'epoch':\n            interval = 'step' if self.logging_interval is None else 'any'\n            latest_stat = self._extract_stats(trainer, interval)\n\n            if latest_stat:\n                trainer.logger.log_metrics(latest_stat, step=trainer.global_step)\n\n    def on_train_epoch_start(self, trainer, *args, **kwargs):\n        if self.logging_interval != 'step':\n            interval = 'epoch' if self.logging_interval is None else 'any'\n            latest_stat = self._extract_stats(trainer, interval)\n\n            if latest_stat:\n                trainer.logger.log_metrics(latest_stat, 
step=trainer.current_epoch)\n\n    def _extract_stats(self, trainer, interval: str) -> Dict[str, float]:\n        latest_stat = {}\n\n        for name, scheduler in zip(self.lr_sch_names, trainer.lr_schedulers):\n            if scheduler['interval'] == interval or interval == 'any':\n                opt = scheduler['scheduler'].optimizer\n                param_groups = opt.param_groups\n                use_betas = 'betas' in opt.defaults\n\n                for i, pg in enumerate(param_groups):\n                    suffix = f'/pg{i + 1}' if len(param_groups) > 1 else ''\n                    lr = self._extract_lr(param_group=pg, name=f'{name}{suffix}')\n                    latest_stat.update(lr)\n                    momentum = self._extract_momentum(\n                        param_group=pg, name=f'{name}-momentum{suffix}', use_betas=use_betas\n                    )\n                    latest_stat.update(momentum)\n\n        return latest_stat\n\n    def _extract_lr(self, param_group, name: str) -> Dict[str, float]:\n        lr = param_group.get('lr')\n        self.lrs[name].append(lr)\n        return {name: lr}\n\n    def _extract_momentum(self, param_group, name: str, use_betas: bool) -> Dict[str, float]:\n        if not self.log_momentum:\n            return {}\n\n        momentum = param_group.get('betas')[0] if use_betas else param_group.get('momentum', 0)\n        self.last_momentum_values[name] = momentum\n        return {name: momentum}\n\n    def _find_names(self, lr_schedulers) -> List[str]:\n        # Create uniqe names in the case we have multiple of the same learning\n        # rate schduler + multiple parameter groups\n        names = []\n        for scheduler in lr_schedulers:\n            sch = scheduler['scheduler']\n            if scheduler['name'] is not None:\n                name = scheduler['name']\n            else:\n                opt_name = 'lr-' + sch.optimizer.__class__.__name__\n                i, name = 1, opt_name\n\n                # 
Multiple schduler of the same type\n                while True:\n                    if name not in names:\n                        break\n                    i, name = i + 1, f'{opt_name}-{i}'\n\n            # Multiple param groups for the same schduler\n            param_groups = sch.optimizer.param_groups\n\n            if len(param_groups) != 1:\n                for i, pg in enumerate(param_groups):\n                    temp = f'{name}/pg{i + 1}'\n                    names.append(temp)\n            else:\n                names.append(name)\n\n            self.lr_sch_names.append(name)\n\n        return names\n\n    @staticmethod\n    def _should_log(trainer) -> bool:\n        should_log = ((trainer.global_step + 1) % trainer.log_every_n_steps == 0 or trainer.should_stop)\n\n        return should_log\n"
  },
  {
    "path": "ba3l/plutils/progress_bar.py",
    "content": "\nimport importlib\nimport io\nimport os\nimport sys\n\n# check if ipywidgets is installed before importing tqdm.auto\n# to ensure it won't fail and a progress bar is displayed\nfrom typing import Optional, Union\n\n\nfrom pytorch_lightning.callbacks import Callback,ProgressBarBase , ProgressBar as PlProgressBar\nfrom pytorch_lightning.callbacks.progress import tqdm\n\n\n\nclass ProgressBar(PlProgressBar):\n    r\"\"\"\n    This is the default progress bar used by Lightning. It prints to `stdout` using the\n    :mod:`tqdm` package and shows up to four different bars:\n\n    - **sanity check progress:** the progress during the sanity check run\n    - **main progress:** shows training + validation progress combined. It also accounts for\n      multiple validation runs during training when\n      :paramref:`~pytorch_lightning.trainer.trainer.Trainer.val_check_interval` is used.\n    - **validation progress:** only visible during validation;\n      shows total progress over all validation datasets.\n    - **test progress:** only active when testing; shows total progress over all test datasets.\n\n    For infinite datasets, the progress bar never ends.\n\n    If you want to customize the default ``tqdm`` progress bars used by Lightning, you can override\n    specific methods of the callback class and pass your custom implementation to the\n    :class:`~pytorch_lightning.trainer.trainer.Trainer`:\n\n    Example::\n\n        class LitProgressBar(ProgressBar):\n\n            def init_validation_tqdm(self):\n                bar = super().init_validation_tqdm()\n                bar.set_description('running validation ...')\n                return bar\n\n        bar = LitProgressBar()\n        trainer = Trainer(callbacks=[bar])\n\n    Args:\n        refresh_rate:\n            Determines at which rate (in number of batches) the progress bars get updated.\n            Set it to ``0`` to disable the display. 
By default, the\n            :class:`~pytorch_lightning.trainer.trainer.Trainer` uses this implementation of the progress\n            bar and sets the refresh rate to the value provided to the\n            :paramref:`~pytorch_lightning.trainer.trainer.Trainer.progress_bar_refresh_rate` argument in the\n            :class:`~pytorch_lightning.trainer.trainer.Trainer`.\n        process_position:\n            Set this to a value greater than ``0`` to offset the progress bars by this many lines.\n            This is useful when you have progress bars defined elsewhere and want to show all of them\n            together. This corresponds to\n            :paramref:`~pytorch_lightning.trainer.trainer.Trainer.process_position` in the\n            :class:`~pytorch_lightning.trainer.trainer.Trainer`.\n\n    \"\"\"\n\n\n    def init_validation_tqdm(self) -> tqdm:\n        \"\"\" Override this to customize the tqdm bar for validation. \"\"\"\n        # The main progress bar doesn't exist in `trainer.validate()`\n        has_main_bar = self.main_progress_bar is not None\n        has_main_bar = False\n        bar = tqdm(\n            desc='Validating',\n            position=(2 * self.process_position + has_main_bar),\n            disable=self.is_disabled,\n            leave=False,\n            dynamic_ncols=True,\n            file=sys.stdout\n        )\n        return bar\n\n\n\n    def on_epoch_start(self, trainer, pl_module):\n        ProgressBarBase.on_epoch_start(self, trainer, pl_module)\n        self.main_progress_bar = self.init_train_tqdm()\n        total_train_batches = self.total_train_batches\n        total_val_batches = self.total_val_batches\n        if total_train_batches != float('inf'):\n            # val can be checked multiple times per epoch\n            val_checks_per_epoch = total_train_batches // trainer.val_check_batch\n            total_val_batches = total_val_batches * val_checks_per_epoch\n        total_val_batches = 0\n        total_batches = 
total_train_batches + total_val_batches\n        reset(self.main_progress_bar, total_batches)\n        self.main_progress_bar.set_description(f'Epoch {trainer.current_epoch}')\n\n    def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx):\n        ProgressBarBase.on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx)\n        if self._should_update(self.train_batch_idx, self.total_train_batches):\n            self._update_bar(self.main_progress_bar)\n            self.main_progress_bar.set_postfix(trainer.progress_bar_dict)\n\n    def on_validation_start(self, trainer, pl_module):\n        ProgressBarBase.on_validation_start(self, trainer, pl_module)\n        if trainer.sanity_checking:\n            reset(self.val_progress_bar, sum(trainer.num_sanity_val_batches))\n        else:\n            if self.main_progress_bar is not None:\n                self.main_progress_bar.close()\n            self.val_progress_bar = self.init_validation_tqdm()\n            reset(self.val_progress_bar, self.total_val_batches)\n\n    def on_validation_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx):\n        ProgressBarBase.on_validation_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx)\n        if self._should_update(self.val_batch_idx, self.total_val_batches):\n            self._update_bar(self.val_progress_bar)\n            #self._update_bar(self.main_progress_bar)\n\n    def on_validation_end(self, trainer, pl_module):\n        ProgressBarBase.on_validation_end(self, trainer, pl_module)\n        if self.main_progress_bar is not None:\n            self.main_progress_bar.set_postfix(trainer.progress_bar_dict)\n        self.val_progress_bar.close()\n\n\n\ndef convert_inf(x: Optional[Union[int, float]]) -> Optional[Union[int, float]]:\n    \"\"\" The tqdm doesn't support inf values. We have to convert it to None. 
\"\"\"\n    if x == float('inf'):\n        return None\n    return x\n\n\ndef reset(bar: tqdm, total: Optional[int] = None) -> None:\n    \"\"\" Resets the tqdm bar to 0 progress with a new total, unless it is disabled. \"\"\"\n    if not bar.disable:\n        bar.reset(total=convert_inf(total))\n"
  },
  {
    "path": "ba3l/util/__init__.py",
    "content": ""
  },
  {
    "path": "ba3l/util/functions.py",
    "content": "import inspect\nfrom collections import OrderedDict\n\n\ndef get_default_kwargs_dict(f):\n    sig = inspect.signature(f)\n    return OrderedDict(\n        [\n            (p.name, p.default)\n            for p in sig.parameters.values()\n            if p.default != inspect._empty\n        ]\n    )\n"
  },
  {
    "path": "ba3l/util/sacred_logger.py",
    "content": "from pytorch_lightning.utilities import rank_zero_only\ntry:\n    from pytorch_lightning.loggers import Logger as LightningLoggerBase\n    from pytorch_lightning.loggers.logger import rank_zero_experiment\nexcept ImportError:\n    from pytorch_lightning.loggers import LightningLoggerBase\n    from pytorch_lightning.loggers.base import rank_zero_experiment\n\nfrom warnings import warn\n\n\nfrom logging import getLogger\n\ntry:\n    import sacred\nexcept ImportError:\n    raise ImportError(\"Missing sacred package.  Run `pip install sacred`\")\n\n\nlogger = getLogger(__name__)\n\n\nclass SacredLogger(LightningLoggerBase):\n    def __init__(self, sacred_experiment):\n        \"\"\"Initialize a sacred logger.\n        :param sacred.experiment.Experiment sacred_experiment: Required. Experiment object with desired observers\n        already appended.\n        source: https://github.com/expectopatronum/pytorch-lightning/blob/9fcb238ec03e3f0b0378fd058119f1563a11650c/pytorch_lightning/logging/sacred.py\n        \"\"\"\n        super().__init__()\n        self.sacred_experiment = sacred_experiment\n        self.experiment_name = sacred_experiment.path\n        self._run_id = None\n        warn('SacredLogger is deprecated', DeprecationWarning, stacklevel=2)\n\n    @property\n    def experiment(self):\n        return self.sacred_experiment\n\n    @property\n    def run_id(self):\n        if self._run_id is not None:\n            return self._run_id\n        self._run_id = self.sacred_experiment.get_run_identifier()\n        return self._run_id\n\n    @rank_zero_only\n    def log_hyperparams(self, params):\n        # probably not needed bc. 
it is dealt with by sacred\n        pass\n\n    @rank_zero_only\n    def log_metrics(self, metrics, step=None):\n        for k, v in metrics.items():\n            if isinstance(v, str):\n                logger.warning(f\"Discarding metric with string value {k}={v}\")\n                continue\n            self.experiment.log_scalar(k, v, step)\n\n    @property\n    def name(self):\n        return self.experiment_name\n\n    @property\n    def version(self):\n        return self.run_id\n\n    @rank_zero_only\n    def save(self):\n        # Optional. Any code necessary to save logger data goes here\n        # If you implement this, remember to call `super().save()`\n        # at the start of the method (important for aggregation of metrics)\n        super().save()\n\n    @rank_zero_only\n    def finalize(self, status):\n        # Optional. Any code that needs to be run after training\n        # finishes goes here\n        pass\n"
  },
  {
    "path": "config_updates.py",
    "content": "from sacred.config_helpers import DynamicIngredient, CMD\n\n\ndef add_configs(ex):\n    '''\n    This functions add generic configuration for the experiments, such as mix-up, architectures, etc...\n    @param ex: Ba3l Experiment\n    @return:\n    '''\n\n    @ex.named_config\n    def nomixup():\n        'Don\\'t apply mix-up (spectrogram level).'\n        use_mixup = False\n        mixup_alpha = 0.3\n\n    @ex.named_config\n    def mixup():\n        ' Apply mix-up (spectrogram level).'\n        use_mixup = True\n        mixup_alpha = 0.3\n\n    @ex.named_config\n    def mini_train():\n        'limit training/validation to 5 batches for debbuging.'\n        trainer = dict(limit_train_batches=5, limit_val_batches=5)\n\n    @ex.named_config\n    def passt():\n        'use PaSST model'\n        models = {\n            \"net\": DynamicIngredient(\"models.passt.model_ing\")\n        }\n\n    @ex.named_config\n    def passt_s_20sec():\n        'use PaSST model pretrained on Audioset (with SWA) ap=476; time encodings for up to 20 seconds'\n        # python ex_audioset.py evaluate_only with passt_s_ap476\n        models = {\n            \"net\": DynamicIngredient(\"models.passt.model_ing\", arch=\"passt_s_f128_20sec_p16_s10_ap474\", fstride=10,\n                                     tstride=10, input_tdim=2000)\n        }\n        basedataset = dict(clip_length=20)\n\n    @ex.named_config\n    def passt_s_30sec():\n        'use PaSST model pretrained on Audioset (with SWA) ap=476; time encodings for up to 30 seconds'\n        # python ex_audioset.py evaluate_only with passt_s_ap476\n        models = {\n            \"net\": DynamicIngredient(\"models.passt.model_ing\", arch=\"passt_s_f128_30sec_p16_s10_ap473\", fstride=10,\n                                     tstride=10, input_tdim=3000)\n        }\n        basedataset = dict(clip_length=20)\n\n    @ex.named_config\n    def passt_s_ap476():\n        'use PaSST model pretrained on Audioset (with SWA) ap=476'\n 
       # python ex_audioset.py evaluate_only with passt_s_ap476\n        models = {\n            \"net\": DynamicIngredient(\"models.passt.model_ing\", arch=\"passt_s_swa_p16_128_ap476\", fstride=10,\n                                     tstride=10)\n        }\n\n    @ex.named_config\n    def passt_s_ap4763():\n        'use PaSST model pretrained on Audioset (with SWA) ap=4763'\n        # test with: python ex_audioset.py evaluate_only with passt_s_ap4763\n        models = {\n            \"net\": DynamicIngredient(\"models.passt.model_ing\", arch=\"passt_s_swa_p16_128_ap4763\", fstride=10,\n                                     tstride=10)\n        }\n\n    @ex.named_config\n    def passt_s_ap472():\n        'use PaSST model pretrained on Audioset (no SWA) ap=472'\n        # test with: python ex_audioset.py evaluate_only with passt_s_ap472\n        models = {\n            \"net\": DynamicIngredient(\"models.passt.model_ing\", arch=\"passt_s_p16_128_ap472\", fstride=10,\n                                     tstride=10)\n        }\n\n    @ex.named_config\n    def passt_s_p16_s16_128_ap468():\n        'use PaSST model pretrained on Audioset (no SWA) ap=468 NO overlap'\n        # test with: python ex_audioset.py evaluate_only with passt_s_p16_s16_128_ap468\n        models = {\n            \"net\": DynamicIngredient(\"models.passt.model_ing\", arch=\"passt_s_p16_s16_128_ap468\", fstride=16,\n                                     tstride=16)\n        }\n\n    @ex.named_config\n    def passt_s_swa_p16_s16_128_ap473():\n        'use PaSST model pretrained on Audioset (SWA) ap=473 NO overlap'\n        # test with: python ex_audioset.py evaluate_only with passt_s_swa_p16_s16_128_ap473\n        models = {\n            \"net\": DynamicIngredient(\"models.passt.model_ing\", arch=\"passt_s_swa_p16_s16_128_ap473\", fstride=16,\n                                     tstride=16)\n        }\n\n    @ex.named_config\n    def passt_s_swa_p16_s14_128_ap471():\n        'use PaSST model 
pretrained on Audioset stride=14 (SWA) ap=471 '\n        # test with: python ex_audioset.py evaluate_only with passt_s_swa_p16_s14_128_ap471\n        models = {\n            \"net\": DynamicIngredient(\"models.passt.model_ing\", arch=\"passt_s_swa_p16_s14_128_ap471\", fstride=14,\n                                     tstride=14)\n        }\n\n    @ex.named_config\n    def passt_s_p16_s14_128_ap469():\n        'use PaSST model pretrained on Audioset stride=14 (No SWA) ap=469 '\n        # test with: python ex_audioset.py evaluate_only with passt_s_p16_s14_128_ap469\n        models = {\n            \"net\": DynamicIngredient(\"models.passt.model_ing\", arch=\"passt_s_p16_s14_128_ap469\", fstride=14,\n                                     tstride=14)\n        }\n\n    @ex.named_config\n    def passt_s_swa_p16_s12_128_ap473():\n        'use PaSST model pretrained on Audioset stride=12 (SWA) ap=473 '\n        # test with: python ex_audioset.py evaluate_only with passt_s_swa_p16_s12_128_ap473\n        models = {\n            \"net\": DynamicIngredient(\"models.passt.model_ing\", arch=\"passt_s_swa_p16_s12_128_ap473\", fstride=12,\n                                     tstride=12)\n        }\n\n    @ex.named_config\n    def passt_s_p16_s12_128_ap470():\n        'use PaSST model pretrained on Audioset stride=12 (No SWA) ap=4670 '\n        # test with: python ex_audioset.py evaluate_only with passt_s_p16_s12_128_ap470\n        models = {\n            \"net\": DynamicIngredient(\"models.passt.model_ing\", arch=\"passt_s_p16_s12_128_ap470\", fstride=12,\n                                     tstride=12)\n        }\n\n    @ex.named_config\n    def ensemble_s10():\n        'use ensemble of PaSST models pretrained on Audioset  with S10 mAP=.4864'\n        # test with: python ex_audioset.py evaluate_only with  trainer.precision=16 ensemble_s10\n        models = {\n            \"net\": DynamicIngredient(\"models.passt.model_ing\", arch=\"ensemble_s10\", fstride=None,\n                 
                    tstride=None, instance_cmd=\"get_ensemble_model\",\n                                     # don't call get_model but rather get_ensemble_model\n                                     arch_list=[\n                                         (\"passt_s_swa_p16_128_ap476\", 10, 10),\n                                         (\"passt_s_swa_p16_128_ap4761\", 10, 10),\n                                         (\"passt_s_p16_128_ap472\", 10, 10),\n                                     ]\n                                     )\n        }\n\n    @ex.named_config\n    def ensemble_many():\n        'use ensemble of PaSST models pretrained on Audioset  with different strides mAP=.4956'\n        # test with: python ex_audioset.py evaluate_only with  trainer.precision=16 ensemble_many\n        models = {\n            \"net\": DynamicIngredient(\"models.passt.model_ing\", arch=\"ensemble_many\", fstride=None,\n                                     tstride=None, instance_cmd=\"get_ensemble_model\",\n                                     # don't call get_model but rather get_ensemble_model\n                                     arch_list=[\n                                         (\"passt_s_swa_p16_128_ap476\", 10, 10),\n                                         (\"passt_s_swa_p16_128_ap4761\", 10, 10),\n                                         (\"passt_s_p16_128_ap472\", 10, 10),\n                                         (\"passt_s_p16_s12_128_ap470\", 12, 12),\n                                         (\"passt_s_swa_p16_s12_128_ap473\", 12, 12),\n                                         (\"passt_s_p16_s14_128_ap469\", 14, 14),\n                                         (\"passt_s_swa_p16_s14_128_ap471\", 14, 14),\n                                         (\"passt_s_swa_p16_s16_128_ap473\", 16, 16),\n                                         (\"passt_s_p16_s16_128_ap468\", 16, 16),\n                                     ]\n                                     )\n        
}\n\n    @ex.named_config\n    def ensemble_4():\n        'use ensemble of PaSST models pretrained on Audioset  with different strides mAP=.4926'\n        # test with: python ex_audioset.py evaluate_only with  trainer.precision=16 ensemble_many\n        models = {\n            \"net\": DynamicIngredient(\"models.passt.model_ing\", arch=\"ensemble_many\", fstride=None,\n                                     tstride=None, instance_cmd=\"get_ensemble_model\",\n                                     # don't call get_model but rather get_ensemble_model\n                                     arch_list=[\n                                         (\"passt_s_swa_p16_128_ap476\", 10, 10),\n                                         (\"passt_s_swa_p16_s12_128_ap473\", 12, 12),\n                                         (\"passt_s_swa_p16_s14_128_ap471\", 14, 14),\n                                         (\"passt_s_swa_p16_s16_128_ap473\", 16, 16),\n                                     ]\n                                     )\n        }\n\n    @ex.named_config\n    def ensemble_5():\n        'use ensemble of PaSST models pretrained on Audioset  with different strides mAP=.49459'\n        # test with: python ex_audioset.py evaluate_only with  trainer.precision=16 ensemble_many\n        models = {\n            \"net\": DynamicIngredient(\"models.passt.model_ing\", arch=\"ensemble_many\", fstride=None,\n                                     tstride=None, instance_cmd=\"get_ensemble_model\",\n                                     # don't call get_model but rather get_ensemble_model\n                                     arch_list=[\n                                         (\"passt_s_swa_p16_128_ap476\", 10, 10),\n                                         (\"passt_s_swa_p16_128_ap4761\", 10, 10),\n                                         (\"passt_s_swa_p16_s12_128_ap473\", 12, 12),\n                                         (\"passt_s_swa_p16_s14_128_ap471\", 14, 14),\n                      
                   (\"passt_s_swa_p16_s16_128_ap473\", 16, 16),\n                                     ]\n                                     )\n        }\n\n    @ex.named_config\n    def ensemble_s16_14():\n        'use ensemble of two PaSST models pretrained on Audioset  with stride 16 and 14 mAP=.48579'\n        # test with: python ex_audioset.py evaluate_only with  trainer.precision=16 ensemble_s16_14\n        models = {\n            \"net\": DynamicIngredient(\"models.passt.model_ing\", arch=\"ensemble_s16\", fstride=None,\n                                     tstride=None, instance_cmd=\"get_ensemble_model\",\n                                     # don't call get_model but rather get_ensemble_model\n                                     arch_list=[\n                                         (\"passt_s_swa_p16_s14_128_ap471\", 14, 14),\n                                         (\"passt_s_swa_p16_s16_128_ap473\", 16, 16),\n                                     ]\n                                     )\n        }\n\n    @ex.named_config\n    def dynamic_roll():\n        # dynamically roll the spectrograms/waveforms\n        # updates the dataset config\n        basedataset = dict(roll=True, roll_conf=dict(axis=1, shift_range=10000)\n                           )\n\n    # extra commands\n\n    @ex.command\n    def test_loaders_train_speed():\n        # test how fast data is being loaded from the data loaders.\n        itr = ex.datasets.training.get_iter()\n        import time\n        start = time.time()\n        print(\"hello\")\n        for i, b in enumerate(itr):\n            if i % 20 == 0:\n                print(f\"{i}/{len(itr)}\", end=\"\\r\")\n        end = time.time()\n        print(\"totoal time:\", end - start)\n        start = time.time()\n        print(\"retry:\")\n        for i, b in enumerate(itr):\n            if i % 20 == 0:\n                print(f\"{i}/{len(itr)}\", end=\"\\r\")\n        end = time.time()\n        print(\"totoal time:\", end - 
start)\n"
  },
  {
    "path": "environment.yml",
    "content": "name: ba3l\nchannels:\n  - defaults\ndependencies:\n  - _libgcc_mutex=0.1\n  - _openmp_mutex=5.1\n  - ca-certificates=2022.10.11\n  - certifi=2022.12.7\n  - ld_impl_linux-64=2.38\n  - libffi=3.4.2\n  - libgcc-ng=11.2.0\n  - libgomp=11.2.0\n  - libstdcxx-ng=11.2.0\n  - ncurses=6.3\n  - openssl=1.1.1s\n  - pip=22.3.1\n  - python=3.8.16\n  - readline=8.2\n  - setuptools=65.6.3\n  - sqlite=3.40.1\n  - tk=8.6.12\n  - wheel=0.37.1\n  - xz=5.2.10\n  - zlib=1.2.13\n  - pip:\n    - absl-py==1.4.0\n    - aiohttp==3.8.4\n    - aiosignal==1.3.1\n    - appdirs==1.4.4\n    - async-timeout==4.0.2\n    - attrs==22.2.0\n    - audioread==3.0.0\n    - av==10.0.0\n    - cachetools==5.3.0\n    - cffi==1.15.1\n    - charset-normalizer==3.0.1\n    - colorama==0.4.6\n    - decorator==5.1.1\n    - docopt==0.6.2\n    - frozenlist==1.3.3\n    - fsspec==2023.3.0\n    - future==0.18.3\n    - gitdb==4.0.10\n    - gitpython==3.1.31\n    - google-auth==2.17.0\n    - google-auth-oauthlib==0.4.6\n    - grpcio==1.53.0\n    - h5py==3.8.0\n    - idna==3.4\n    - imageio==2.27.0\n    - importlib-metadata==6.1.0\n    - joblib==1.2.0\n    - jsonpickle==3.0.1\n    - kk-sacred==0.8.4\n    - lazy-loader==0.2\n    - librosa==0.10.0.post2\n    - llvmlite==0.39.1\n    - markdown==3.4.3\n    - markupsafe==2.1.2\n    - msgpack==1.0.5\n    - multidict==6.0.4\n    - munch==2.5.0\n    - numba==0.56.4\n    - numpy==1.23.5\n    - nvidia-cublas-cu11==11.10.3.66\n    - nvidia-cuda-nvrtc-cu11==11.7.99\n    - nvidia-cuda-runtime-cu11==11.7.99\n    - nvidia-cudnn-cu11==8.5.0.96\n    - oauthlib==3.2.2\n    - packaging==23.0\n    - pandas==1.5.3\n    - pillow==9.4.0\n    - pooch==1.6.0\n    - protobuf==4.22.1\n    - py-cpuinfo==9.0.0\n    - pyasn1==0.4.8\n    - pyasn1-modules==0.2.8\n    - pycparser==2.21\n    - python-dateutil==2.8.2\n    - pytorch-lightning==1.3.0.dev0\n    - pytz==2023.3\n    - pyyaml==6.0\n    - requests==2.28.2\n    - requests-oauthlib==1.3.1\n    - rsa==4.9\n    - scikit-learn==1.2.2\n 
   - scipy==1.10.1\n    - six==1.16.0\n    - smmap==5.0.0\n    - soundfile==0.12.1\n    - soxr==0.3.4\n    - tensorboard==2.12.0\n    - tensorboard-data-server==0.7.0\n    - tensorboard-plugin-wit==1.8.1\n    - test-tube==0.7.5\n    - threadpoolctl==3.1.0\n    - timm==0.4.12\n    - torch==1.13.1\n    - torchaudio==0.13.1\n    - torchmetrics==0.2.0\n    - torchvision==0.14.1\n    - tqdm==4.65.0\n    - typing-extensions==4.4.0\n    - urllib3==1.26.14\n    - werkzeug==2.2.3\n    - wrapt==1.15.0\n    - yarl==1.8.2\n    - zipp==3.15.0\n"
  },
  {
    "path": "environment_old_2021.yml",
    "content": "name: ba3l\nchannels:\n  - conda-forge\n  - defaults\ndependencies:\n  - _libgcc_mutex=0.1\n  - _openmp_mutex=4.5\n  - _pytorch_select=0.1\n  - appdirs=1.4.4\n  - audioread=2.1.9\n  - blas=1.0\n  - brotlipy=0.7.0\n  - bzip2=1.0.8\n  - c-ares=1.17.1\n  - ca-certificates=2020.12.5\n  - cached-property=1.5.2\n  - cached_property=1.5.2\n  - certifi=2020.12.5\n  - cffi=1.14.5\n  - chardet=4.0.0\n  - colorama=0.4.4\n  - cryptography=3.4.6\n  - cycler=0.10.0\n  - decorator=4.4.2\n  - docopt=0.6.2\n  - ffmpeg=4.3.1\n  - freetype=2.10.4\n  - gettext=0.19.8.1\n  - gitdb=4.0.5\n  - gitpython=3.1.14\n  - gmp=6.2.1\n  - gnutls=3.6.13\n  - h5py=3.1.0\n  - hdf5=1.10.6\n  - idna=2.10\n  - importlib-metadata=3.7.3\n  - importlib_metadata=3.7.3\n  - intel-openmp=2020.2\n  - joblib=1.0.1\n  - jpeg=9d\n  - jsonpickle=1.4.1\n  - kiwisolver=1.3.1\n  - krb5=1.17.2\n  - lame=3.100\n  - lcms2=2.12\n  - ld_impl_linux-64=2.35.1\n  - libblas=3.9.0\n  - libcblas=3.9.0\n  - libcurl=7.75.0\n  - libedit=3.1.20191231\n  - libev=4.33\n  - libffi=3.3\n  - libflac=1.3.3\n  - libgcc-ng=9.3.0\n  - libgfortran-ng=9.3.0\n  - libgfortran5=9.3.0\n  - libgomp=9.3.0\n  - liblapack=3.9.0\n  - libllvm10=10.0.1\n  - libnghttp2=1.43.0\n  - libogg=1.3.4\n  - libopenblas=0.3.12\n  - libopus=1.3.1\n  - libpng=1.6.37\n  - librosa=0.8.0\n  - libsndfile=1.0.31\n  - libssh2=1.9.0\n  - libstdcxx-ng=9.3.0\n  - libtiff=4.2.0\n  - libvorbis=1.3.7\n  - libwebp-base=1.2.0\n  - llvm-openmp=11.1.0\n  - llvmlite=0.36.0\n  - lz4-c=1.9.3\n  - matplotlib-base=3.3.4\n  - mkl=2020.2\n  - mkl-service=2.3.0\n  - munch=2.5.0\n  - ncurses=6.2\n  - nettle=3.6\n  - ninja=1.10.2\n  - numba=0.53.0\n  - numpy=1.20.1\n  - olefile=0.46\n  - openblas=0.3.12\n  - openh264=2.1.1\n  - openssl=1.1.1k\n  - packaging=20.9\n  - pandas=1.2.3\n  - pillow=8.1.2\n  - pip=21.0.1\n  - pooch=1.3.0\n  - py-cpuinfo=7.0.0\n  - pycparser=2.20\n  - pyopenssl=20.0.1\n  - pyparsing=2.4.7\n  - pysocks=1.7.1\n  - pysoundfile=0.10.3.post1\n  - 
python=3.7.10\n  - python-dateutil=2.8.1\n  - python_abi=3.7\n  - pytz=2021.1\n  - readline=8.0\n  - requests=2.25.1\n  - resampy=0.2.2\n  - scikit-learn=0.24.1\n  - scipy=1.6.1\n  - setuptools=49.6.0\n  - six=1.15.0\n  - smmap=3.0.5\n  - sqlite=3.34.0\n  - threadpoolctl=2.1.0\n  - tk=8.6.10\n  - tornado=6.1\n  - typing_extensions=3.7.4.3\n  - urllib3=1.26.4\n  - wrapt=1.12.1\n  - x264=1!161.3030\n  - xz=5.2.5\n  - zipp=3.4.1\n  - zlib=1.2.11\n  - zstd=1.4.9\n  - pip:\n    - absl-py==0.12.0\n    - aiohttp==3.7.4.post0\n    - async-timeout==3.0.1\n    - attrs==20.3.0\n    - av==8.0.3\n    - black==20.8b1\n    - cachetools==4.2.1\n    - click==7.1.2\n    - einops==0.3.0\n    - fsspec==0.8.7\n    - future==0.18.2\n    - google-auth==1.28.0\n    - google-auth-oauthlib==0.4.3\n    - gpuinfo==1.0.0a7\n    - grpcio==1.36.1\n    - imageio==2.9.0\n    - jedi==0.18.0\n    - libtmux==0.8.5\n    - markdown==3.3.4\n    - multidict==5.1.0\n    - mypy-extensions==0.4.3\n    - oauthlib==3.1.0\n    - parso==0.8.2\n    - pathspec==0.8.1\n    - prompt-toolkit==3.0.18\n    - protobuf==3.15.6\n    - ptpython==3.0.17\n    - pyasn1==0.4.8\n    - pyasn1-modules==0.2.8\n    - pydub==0.25.1\n    - pygments==2.8.1\n    - pymongo==3.11.3\n    - pyyaml==5.3.1\n    - regex==2021.4.4\n    - requests-oauthlib==1.3.0\n    - rsa==4.7.2\n    - tensorboard==2.4.1\n    - tensorboard-plugin-wit==1.8.0\n    - test-tube==0.7.5\n    - timm==0.4.12\n    - toml==0.10.2\n    - torch==1.8.1+cu111\n    - torchaudio==0.8.1\n    - torchmetrics==0.2.0\n    - torchvision==0.6.0\n    - tqdm==4.59.0\n    - typed-ast==1.4.3\n    - wcwidth==0.2.5\n    - werkzeug==1.0.1\n    - wheel==0.36.2\n    - yarl==1.6.3\n\n"
  },
  {
    "path": "esc50/README.md",
    "content": "# Experiments on ESC-50 Environmental Sound Classification\n[ESC-50](https://github.com/karolpiczak/ESC-50) consist of 2000 5-second recordings to be classified to 50 semantical classes.\n\n## Setting up the fine-tuning experiments\n- Download the prerpocessed (resampled) dataset [esc50.zip](https://github.com/kkoutini/PaSST/releases/download/v.0.0.6/esc50.zip) from the [releases page](https://github.com/kkoutini/PaSST/releases/tag/v.0.0.6) and \nunpack the zip file into a directory (the default path is `./audioset_hdf5s/`). The `base_dir` config in the dataset file ([here](https://github.com/kkoutini/PaSST/blob/main/esc50/dataset.py#L35)) should point to the extracted contents of the dataset zip file.\n- Running the experiments using the common configurations (similar to Audioset)\n```shell\npython3 ex_esc50.py with models.net.s_patchout_t=10 models.net.s_patchout_f=5  basedataset.fold=1 -p\n```\n## Pre-trained models\n\nPre-trained models on ESC-50 can be found here [here](https://github.com/kkoutini/PaSST/releases/tag/v.0.0.6). 
\n\nIn order to use the pre-trained models, for fine-tuning or inference, using a minimal dependencies, refer to the [PaSST-HEAR](https://github.com/kkoutini/passt_hear21), as an example after installing passt_hear21 :\n```python\nfrom hear21passt.base import get_basic_model,get_model_passt\nimport torch\n# model wrapper, includes Melspectrogram and the default transformer\nmodel = get_basic_model(mode=\"logits\")\n# replace the transformer with one that outputs 50 classes\nmodel.net = get_model_passt(arch=\"passt_s_swa_p16_128_ap476\",  n_classes=50)\n\n# load the pre-trained model state dict\nstate_dict = torch.load('/home/khaled/esc50-passt-s-n-f128-p16-s10-fold1-acc.967.pt')\n# load the weights into the transformer\nmodel.net.load_state_dict(state_dict)\n\n# example inference\nmodel.eval()\nmodel = model.cuda()\nwith torch.no_grad():\n    # audio_wave has the shape of [batch, seconds*32000] sampling rate is 32k\n    logits=model(audio_wave) \n```\n\n"
  },
  {
    "path": "esc50/dataset.py",
    "content": "import io\nimport os\nimport pathlib\nimport random\n\nimport av\nimport librosa\nimport torchaudio\nfrom torch.utils.data import Dataset as TorchDataset, ConcatDataset, DistributedSampler, WeightedRandomSampler\n\nimport torch\nfrom ba3l.ingredients.datasets import Dataset\nimport pandas as pd\nfrom sacred.config import DynamicIngredient, CMD\nfrom scipy.signal import convolve\nfrom sklearn import preprocessing\nfrom torch.utils.data import Dataset as TorchDataset\nimport numpy as np\nimport h5py\nfrom helpers.audiodatasets import  PreprocessDataset\n\n\nLMODE = os.environ.get(\"LMODE\", False)\n\ndataset = Dataset('Esc50')\n\n\n@dataset.config\ndef default_config():\n    name = 'esc50'  # dataset name\n    normalize = False  # normalize dataset\n    subsample = False  # subsample squares from the dataset\n    roll = True  # apply roll augmentation\n    fold = 1\n    base_dir = \"audioset_hdf5s/esc50/\"  # base directory of the dataset as downloaded\n    if LMODE:\n        base_dir = \"/system/user/publicdata/CP/audioset/audioset_hdf5s/esc50/\"\n    meta_csv = base_dir + \"meta/esc50.csv\"\n    audio_path = base_dir + \"audio_32k/\"\n    ir_path = base_dir + \"irs/\"\n    num_of_classes = 50\n\n\n\n\n\ndef decode_mp3(mp3_arr):\n    \"\"\"\n    decodes an array if uint8 representing an mp3 file\n    :rtype: np.array\n    \"\"\"\n    container = av.open(io.BytesIO(mp3_arr.tobytes()))\n    stream = next(s for s in container.streams if s.type == 'audio')\n    # print(stream)\n    a = []\n    for i, packet in enumerate(container.demux(stream)):\n        for frame in packet.decode():\n            a.append(frame.to_ndarray().reshape(-1))\n    waveform = np.concatenate(a)\n    if waveform.dtype != 'float32':\n        raise RuntimeError(\"Unexpected wave type\")\n    return waveform\n\n\ndef pad_or_truncate(x, audio_length):\n    \"\"\"Pad all audio to specific length.\"\"\"\n    if len(x) <= audio_length:\n        return np.concatenate((x, 
np.zeros(audio_length - len(x), dtype=np.float32)), axis=0)\n    else:\n        return x[0: audio_length]\n\n\nirs_arr = None\n\n\n@dataset.command\ndef get_ir_sample(ir_path, _run, ir_augment, cut_irs_offset=None):\n    if not ir_augment:\n        return\n    global irs_arr\n    if irs_arr is None:\n        all_paths = [path for path in pathlib.Path(os.path.expanduser(ir_path)).rglob('*.wav')]\n        all_paths = sorted(all_paths)\n        if cut_irs_offset is not None:\n            all_paths = all_paths[cut_irs_offset:cut_irs_offset + 10]\n        all_paths_name = [str(p).rsplit(\"/\", 1)[-1] for p in all_paths]\n        print(\"will use these IRs:\")\n        for i in range(len(all_paths_name)):\n            print(i, \": \", all_paths_name[i])\n        _run.info[\"ir_devices\"] = all_paths_name\n        irs_arr = [librosa.load(p, sr=32000)[0] for p in all_paths]\n    return irs_arr[int(np.random.randint(0, len(irs_arr)))]\n\n\n@dataset.command\ndef pydub_augment(waveform, gain_augment=7, ir_augment=0):\n    if ir_augment and torch.rand(1) < ir_augment:\n        ir = get_ir_sample()\n        waveform = convolve(waveform, ir, 'full')\n    if gain_augment:\n        gain = torch.randint(gain_augment * 2, (1,)).item() - gain_augment\n        amp = 10 ** (gain / 20)\n        waveform = waveform * amp\n    return waveform\n\n\nclass MixupDataset(TorchDataset):\n    \"\"\" Mixing Up wave forms\n    \"\"\"\n\n    def __init__(self, dataset, beta=2, rate=0.5):\n        self.beta = beta\n        self.rate = rate\n        self.dataset = dataset\n        print(f\"Mixing up waveforms from dataset of len {len(dataset)}\")\n\n    def __getitem__(self, index):\n        if torch.rand(1) < self.rate:\n            x1, f1, y1 = self.dataset[index]\n            idx2 = torch.randint(len(self.dataset), (1,)).item()\n            x2, f2, y2 = self.dataset[idx2]\n            l = np.random.beta(self.beta, self.beta)\n            l = max(l, 1. 
- l)\n            x1 = x1 - x1.mean()\n            x2 = x2 - x2.mean()\n            x = (x1 * l + x2 * (1. - l))\n            x = x - x.mean()\n            return x, f1, (y1 * l + y2 * (1. - l))\n        return self.dataset[index]\n\n    def __len__(self):\n        return len(self.dataset)\n\n\n\n\nclass AudioSetDataset(TorchDataset):\n    def __init__(self, meta_csv,  audiopath, fold, train=False, sample_rate=32000, classes_num=527,\n                 clip_length=5, augment=False):\n        \"\"\"\n        Reads the mp3 bytes from HDF file decodes using av and returns a fixed length audio wav\n        \"\"\"\n        self.sample_rate = sample_rate\n        self.meta_csv = meta_csv\n        self.df = pd.read_csv(meta_csv)\n        if train:  # training all except this\n            print(f\"Dataset training fold {fold} selection out of {len(self.df)}\")\n            self.df = self.df[self.df.fold != fold]\n            print(f\" for training remains {len(self.df)}\")\n        else:\n            print(f\"Dataset testing fold {fold} selection out of {len(self.df)}\")\n            self.df = self.df[self.df.fold == fold]\n            print(f\" for testing remains {len(self.df)}\")\n\n        self.clip_length = clip_length * sample_rate\n        self.sr = sample_rate\n        self.classes_num = classes_num\n        self.augment = augment\n        self.audiopath=audiopath\n        if augment:\n            print(f\"Will agument data from {meta_csv}\")\n\n    def __len__(self):\n        return len(self.df)\n\n    def __getitem__(self, index):\n        \"\"\"Load waveform and target of an audio clip.\n\n        Args:\n          meta: {\n            'hdf5_path': str,\n            'index_in_hdf5': int}\n        Returns:\n          data_dict: {\n            'audio_name': str,\n            'waveform': (clip_samples,),\n            'target': (classes_num,)}\n        \"\"\"\n        row = self.df.iloc[index]\n\n        #waveform = decode_mp3(np.fromfile(self.audiopath + 
row.filename, dtype='uint8'))\n        waveform, _ = librosa.load(self.audiopath + row.filename, sr=self.sr, mono=True)\n        if self.augment:\n            waveform = pydub_augment(waveform)\n        waveform = pad_or_truncate(waveform, self.clip_length)\n        waveform = self.resample(waveform)\n        target = row.target\n        return waveform.reshape(1, -1),  row.filename, target\n\n    def resample(self, waveform):\n        \"\"\"Resample.\n        Args:\n          waveform: (clip_samples,)\n        Returns:\n          (resampled_clip_samples,)\n        \"\"\"\n        if self.sample_rate == 32000:\n            return waveform\n        elif self.sample_rate == 16000:\n            return waveform[0:: 2]\n        elif self.sample_rate == 8000:\n            return waveform[0:: 4]\n        else:\n            raise Exception('Incorrect sample rate!')\n\n\n\n@dataset.command\ndef get_base_training_set(meta_csv, audio_path, fold=1):\n    ds = AudioSetDataset(meta_csv, audio_path, fold,  train=True, augment=True)\n    return ds\n\n\n@dataset.command\ndef get_ft_weighted_sampler(samples_weights=CMD(\".get_ft_cls_balanced_sample_weights\"),\n                            epoch_len=100000, sampler_replace=False):\n    num_nodes = int(os.environ.get('num_nodes', 1))\n    ddp = int(os.environ.get('DDP', 1))\n    num_nodes = max(ddp, num_nodes)\n    print(\"num_nodes= \", num_nodes)\n    rank = int(os.environ.get('NODE_RANK', 0))\n    return DistributedSamplerWrapper(sampler=WeightedRandomSampler(samples_weights,\n                                                                   num_samples=epoch_len, replacement=sampler_replace),\n                                     dataset=range(epoch_len),\n                                     num_replicas=num_nodes,\n                                     rank=rank,\n                                     )\n\n\n@dataset.command\ndef get_base_test_set(meta_csv, audio_path, fold=1):\n    ds = AudioSetDataset(meta_csv, audio_path, 
fold,  train=False)\n    return ds\n\n\n\n@dataset.command(prefix='roll_conf')\ndef get_roll_func(axis=1, shift=None, shift_range=50):\n    print(\"rolling...\")\n\n    def roll_func(b):\n        x, i, y = b\n        x = torch.as_tensor(x)\n        sf = shift\n        if shift is None:\n            sf = int(np.random.random_integers(-shift_range, shift_range))\n        global FirstTime\n\n        return x.roll(sf, axis), i, y\n\n    return roll_func\n\n\n@dataset.command\ndef get_training_set(normalize, roll, wavmix=False):\n    ds = get_base_training_set()\n    get_ir_sample()\n    if normalize:\n        print(\"normalized train!\")\n        fill_norms()\n        ds = PreprocessDataset(ds, norm_func)\n    if roll:\n        ds = PreprocessDataset(ds, get_roll_func())\n    if wavmix:\n        ds = MixupDataset(ds)\n\n    return ds\n\n\n@dataset.command\ndef get_test_set(normalize):\n    ds = get_base_test_set()\n    if normalize:\n        print(\"normalized test!\")\n        fill_norms()\n        ds = PreprocessDataset(ds, norm_func)\n    return ds\n\n\n@dataset.command\ndef print_conf(_config):\n    print(\"Config of \", dataset.path, id(dataset))\n    print(_config)\n    print()\n\n\nclass DistributedSamplerWrapper(DistributedSampler):\n    def __init__(\n            self, sampler, dataset,\n            num_replicas=None,\n            rank=None,\n            shuffle: bool = True):\n        super(DistributedSamplerWrapper, self).__init__(\n            dataset, num_replicas, rank, shuffle)\n        # source: @awaelchli https://github.com/PyTorchLightning/pytorch-lightning/issues/3238\n        self.sampler = sampler\n\n    def __iter__(self):\n        if self.sampler.generator is None:\n            self.sampler.generator = torch.Generator()\n        self.sampler.generator.manual_seed(self.seed + self.epoch)\n        indices = list(self.sampler)\n        if self.epoch == 0:\n            print(f\"\\n DistributedSamplerWrapper :  {indices[:10]} \\n\\n\")\n        
indices = indices[self.rank:self.total_size:self.num_replicas]\n        return iter(indices)\n\n\nif __name__ == \"__main__\":\n    from sacred import Experiment\n\n    ex = Experiment(\"test_dataset\", ingredients=[dataset])\n\n\n    @ex.automain\n    def default_command():\n        ex.current_run.get_command_function(\"print_config\")()\n        get_base_training_set()\n        ds = get_test_set()\n        print(ds[0])\n        ds = get_training_set()\n        print(ds[0])\n        print(\"get_base_training_set\", len(get_base_training_set()))\n        print(\"get_base_test_set\", len(get_base_test_set()))\n        print(\"get_training_set\", len(get_training_set()))\n        print(\"get_test_set\", len(get_test_set()))\n"
  },
  {
    "path": "ex_audioset.py",
    "content": "import os\nimport sys\nimport PIL\nimport pytorch_lightning\nimport torch\nfrom pytorch_lightning.callbacks import ModelCheckpoint\nfrom sacred.config_helpers import DynamicIngredient, CMD\nfrom torch.nn import functional as F\nimport numpy as np\nimport wandb\n\nfrom ba3l.experiment import Experiment\nfrom ba3l.module import Ba3lModule\n\nfrom torch.utils.data import DataLoader\n\nfrom config_updates import add_configs\nfrom helpers.mixup import my_mixup\nfrom helpers.models_size import count_non_zero_params\nfrom helpers.ramp import exp_warmup_linear_down, cosine_cycle\nfrom helpers.workersinit import worker_init_fn\nfrom sklearn import metrics\nfrom pytorch_lightning import Trainer as plTrainer\nfrom pytorch_lightning.loggers import WandbLogger\nfrom pytorch_lightning.callbacks import LearningRateMonitor\n\nex = Experiment(\"audioset\")\n\n# Example call with all the default config:\n# python ex_audioset.py with trainer.precision=16 models.net.arch=passt_deit_bd_p16_384 -p -m mongodb_server:27000:audioset21_balanced -c \"PaSST base\"\n# with 2 gpus:\n# DDP=2 python ex_audioset.py with trainer.precision=16  models.net.arch=passt_deit_bd_p16_384 -p -m mongodb_server:27000:audioset21_balanced -c \"PaSST base 2 GPU\"\n\n# capture the config of the trainer with the prefix \"trainer\", this allows to use sacred to update PL trainer config\n# now you can use in the cmd trainer.precision=16 for example\nget_trainer = ex.command(plTrainer, prefix=\"trainer\")\n# capture the WandbLogger and prefix it with \"wandb\", this allows to use sacred to update WandbLogger config from the command line\nget_logger = ex.command(WandbLogger, prefix=\"wandb\")\nwandb_logger = None\n\n# define datasets and loaders\nget_train_loader = ex.datasets.training.iter(DataLoader, static_args=dict(worker_init_fn=worker_init_fn), train=True, batch_size=12,\n                                             num_workers=16, shuffle=None, 
dataset=CMD(\"/basedataset.get_full_training_set\"),\n                                             sampler=CMD(\"/basedataset.get_ft_weighted_sampler\"))\n\nget_validate_loader = ex.datasets.test.iter(DataLoader, static_args=dict(worker_init_fn=worker_init_fn),\n                                            validate=True, batch_size=20, num_workers=16,\n                                            dataset=CMD(\"/basedataset.get_test_set\"))\n\n\n@ex.config\ndef default_conf():\n    cmd = \" \".join(sys.argv)  # command line arguments\n    saque_cmd = os.environ.get(\"SAQUE_CMD\", \"\").strip()\n    saque_id = os.environ.get(\"SAQUE_ID\", \"\").strip()\n    slurm_job_id = os.environ.get(\"SLURM_JOB_ID\", \"\").strip()\n    if os.environ.get(\"SLURM_ARRAY_JOB_ID\", False):\n        slurm_job_id = os.environ.get(\"SLURM_ARRAY_JOB_ID\", \"\").strip() + \"_\" + os.environ.get(\"SLURM_ARRAY_TASK_ID\",\n                                                                                               \"\").strip()\n    process_id = os.getpid()\n    models = {\n        \"net\": DynamicIngredient(\"models.passt.model_ing\", arch=\"passt_deit_bd_p16_384\", n_classes=527, s_patchout_t=40,\n                                 s_patchout_f=4),  # network config\n        \"mel\": DynamicIngredient(\"models.preprocess.model_ing\",\n                                 instance_cmd=\"AugmentMelSTFT\",\n                                 n_mels=128, sr=32000, win_length=800, hopsize=320, n_fft=1024, freqm=48,\n                                 timem=192,\n                                 htk=False, fmin=0.0, fmax=None, norm=1, fmin_aug_range=10,\n                                 fmax_aug_range=2000)\n    }\n    basedataset = DynamicIngredient(\"audioset.dataset.dataset\", wavmix=1)\n    wandb = dict(project=\"passt_audioset2\", log_model=True)\n    watch_model = True\n    trainer = dict(max_epochs=130, gpus=1, weights_summary='full', benchmark=True, num_sanity_val_steps=0, precision=16,\n           
        reload_dataloaders_every_epoch=True)\n    lr = 0.00002  # learning rate\n    use_mixup = True\n    mixup_alpha = 0.3\n    compile = True # compile the model, requires pytorch >= 2.0\n\n\n# register extra possible configs\nadd_configs(ex)\n\n\n@ex.command\ndef get_scheduler_lambda(warm_up_len=5, ramp_down_start=50, ramp_down_len=50, last_lr_value=0.01,\n                         schedule_mode=\"exp_lin\"):\n    if schedule_mode == \"exp_lin\":\n        return exp_warmup_linear_down(warm_up_len, ramp_down_len, ramp_down_start, last_lr_value)\n    if schedule_mode == \"cos_cyc\":\n        return cosine_cycle(warm_up_len, ramp_down_start, last_lr_value)\n    raise RuntimeError(\n        f\"schedule_mode={schedule_mode} Unknown for a lambda funtion.\")\n\n\n@ex.command\ndef get_lr_scheduler(optimizer, schedule_mode):\n    if schedule_mode in {\"exp_lin\", \"cos_cyc\"}:\n        return torch.optim.lr_scheduler.LambdaLR(optimizer, get_scheduler_lambda())\n    raise RuntimeError(f\"schedule_mode={schedule_mode} Unknown.\")\n\n\n@ex.command\ndef get_optimizer(params, lr, adamw=True, weight_decay=0.0001):\n    if adamw:\n        print(f\"\\nUsing adamw weight_decay={weight_decay}!\\n\")\n        return torch.optim.AdamW(params, lr=lr, weight_decay=weight_decay)\n    return torch.optim.Adam(params, lr=lr)\n\n\nclass M(Ba3lModule):\n    def __init__(self, experiment):\n        self.mel = None\n        self.da_net = None\n        super(M, self).__init__(experiment)\n\n        self.use_mixup = self.config.use_mixup or False\n        self.mixup_alpha = self.config.mixup_alpha\n\n        desc, sum_params, sum_non_zero = count_non_zero_params(self.net)\n        self.experiment.info[\"start_sum_params\"] = sum_params\n        self.experiment.info[\"start_sum_params_non_zero\"] = sum_non_zero\n\n        # in case we need embedings for the DA\n        self.net.return_embed = True\n        self.dyn_norm = self.config.dyn_norm\n        self.do_swa = False\n\n        
self.distributed_mode = self.config.trainer.num_nodes > 1\n        \n        if self.config.compile:\n            # pt 2 magic\n            print(\"\\n\\nCompiling the model pytorch 2... \\n\\n\")\n            self.net = torch.compile(self.net)\n            # compile only the net, not the mel\n            #self.mel = torch.compile(self.mel)\n\n    def forward(self, x):\n        return self.net(x)\n\n    def mel_forward(self, x):\n        old_shape = x.size()\n        x = x.reshape(-1, old_shape[2])\n        x = self.mel(x)\n        x = x.reshape(old_shape[0], old_shape[1], x.shape[1], x.shape[2])\n        if self.dyn_norm:\n            if not hasattr(self, \"tr_m\") or not hasattr(self, \"tr_std\"):\n                tr_m, tr_std = get_dynamic_norm(self)\n                self.register_buffer('tr_m', tr_m)\n                self.register_buffer('tr_std', tr_std)\n            x = (x - self.tr_m) / self.tr_std\n        return x\n\n    def training_step(self, batch, batch_idx):\n        # REQUIRED\n        x, f, y = batch\n        if self.mel:\n            x = self.mel_forward(x)\n        \n        if self.global_step < 5:\n            images = [ wandb.Image(\n                PIL.Image.fromarray((i * 255).astype(np.uint8)).convert(\"L\"),\n                caption=\"spectrograms\",\n            ) for i in x[:, 0, :, :].cpu().numpy()]\n            wandb.log({\"spectrograms\": images})\n            # wandb_logger.log_image(key=\"spectrograms\", images=[i for i in x[:,0,:,:].cpu().numpy()])\n\n        orig_x = x\n        batch_size = len(y)\n\n        rn_indices, lam = None, None\n        if self.use_mixup:\n            rn_indices, lam = my_mixup(batch_size, self.mixup_alpha)\n            lam = lam.to(x.device)\n            x = x * lam.reshape(batch_size, 1, 1, 1) + \\\n                x[rn_indices] * (1. 
- lam.reshape(batch_size, 1, 1, 1))\n\n        y_hat, embed = self.forward(x)\n\n        if self.use_mixup:\n            y_mix = y * lam.reshape(batch_size, 1) + \\\n                y[rn_indices] * (1. - lam.reshape(batch_size, 1))\n            samples_loss = F.binary_cross_entropy_with_logits(\n                y_hat, y_mix, reduction=\"none\")\n            loss = samples_loss.mean()\n            samples_loss = samples_loss.detach()\n        else:\n            samples_loss = F.binary_cross_entropy_with_logits(\n                y_hat, y, reduction=\"none\")\n            loss = samples_loss.mean()\n            samples_loss = samples_loss.detach()\n\n        results = {\"loss\": loss, }\n        self.log('trainer/lr', self.trainer.optimizers[0].param_groups[0]['lr'])\n        self.log('epoch', self.current_epoch)\n        self.log(\"training.loss\", loss.detach())\n        return results\n\n    def training_epoch_end(self, outputs):\n        avg_loss = torch.stack([x['loss'] for x in outputs]).mean()\n\n        logs = {'train.loss': avg_loss, # 'step': self.current_epoch\n                }\n\n        self.log_dict(logs, sync_dist=True)\n\n    def predict(self, batch, batch_idx: int, dataloader_idx: int = None):\n        x, f, y = batch\n        if self.mel:\n            x = self.mel_forward(x)\n\n        y_hat, _ = self.forward(x)\n        return f, y_hat\n\n    def validation_step(self, batch, batch_idx):\n        x, f, y = batch\n        if self.mel:\n            x = self.mel_forward(x)\n\n        if self.global_step < 5:\n            images = [ wandb.Image(\n                PIL.Image.fromarray((i * 255).astype(np.uint8)).convert(\"L\"),\n                caption=\"validation_spectrograms\",\n            ) for i in x[:, 0, :, :].cpu().numpy()]\n            wandb.log({\"validation_spectrograms\": images})\n            # wandb_logger.log_image(key=\"validation_spectrograms\", images=[\n            #                        i for i in x[:, 0, :, :].cpu().numpy()])\n\n    
    results = {}\n        model_name = [(\"\", self.net)]\n        if self.do_swa:\n            model_name = model_name + [(\"swa_\", self.net_swa)]\n        for net_name, net in model_name:\n            y_hat, _ = net(x)\n            samples_loss = F.binary_cross_entropy_with_logits(y_hat, y)\n            loss = samples_loss.mean()\n            out = torch.sigmoid(y_hat.detach())\n            # self.log(\"validation.loss\", loss, prog_bar=True, on_epoch=True, on_step=False)\n            results = {**results, net_name + \"val_loss\": loss,\n                       net_name + \"out\": out, net_name + \"target\": y.detach()}\n        results = {k: v.cpu() for k, v in results.items()}\n        return results\n\n    def validation_epoch_end(self, outputs):\n        model_name = [(\"\", self.net)]\n        if self.do_swa:\n            model_name = model_name + [(\"swa_\", self.net_swa)]\n        for net_name, net in model_name:\n            avg_loss = torch.stack([x[net_name + 'val_loss']\n                                   for x in outputs]).mean()\n            out = torch.cat([x[net_name + 'out'] for x in outputs], dim=0)\n            target = torch.cat([x[net_name + 'target']\n                               for x in outputs], dim=0)\n            try:\n                average_precision = metrics.average_precision_score(\n                    target.float().numpy(), out.float().numpy(), average=None)\n            except ValueError:\n                average_precision = np.array([np.nan] * 527)\n            try:\n                roc = metrics.roc_auc_score(\n                    target.numpy(), out.numpy(), average=None)\n            except ValueError:\n                roc = np.array([np.nan] * 527)\n            logs = {net_name + 'val.loss': torch.as_tensor(avg_loss).cuda(),\n                    net_name + 'ap': torch.as_tensor(average_precision.mean()).cuda(),\n                    net_name + 'roc': torch.as_tensor(roc.mean()).cuda(),}\n                    #'step': 
torch.as_tensor(self.current_epoch).cuda()}\n            \n            # torch.save(average_precision,\n            #            f\"ap_perclass_{average_precision.mean()}.pt\")\n            # print(average_precision)\n            self.log_dict(logs, sync_dist=True)\n            if self.distributed_mode:\n                allout = self.all_gather(out)\n                alltarget = self.all_gather(target)\n\n                average_precision = metrics.average_precision_score(\n                    alltarget.reshape(-1, alltarget.shape[-1]).cpu().numpy(),\n                    allout.reshape(-1, allout.shape[-1]).cpu().numpy(), average=None)\n                if self.trainer.is_global_zero:\n                    logs = {net_name + \"allap\": torch.as_tensor(average_precision.mean()).cuda(),\n                           # 'step': torch.as_tensor(self.current_epoch).cuda()\n                            }\n                    self.log_dict(logs, sync_dist=False)\n            else:\n                self.log_dict(\n                    {net_name + \"allap\": logs[net_name + 'ap'], \n                     #'step': logs['step']\n                      }\n                    , sync_dist=True)\n\n    def configure_optimizers(self):\n        # REQUIRED\n        # can return multiple optimizers and learning_rate schedulers\n        # (LBFGS it is automatically supported, no need for closure function)\n        optimizer = get_optimizer(self.parameters())\n        # torch.optim.Adam(self.parameters(), lr=self.config.lr)\n        return {\n            'optimizer': optimizer,\n            'lr_scheduler': get_lr_scheduler(optimizer)\n        }\n\n    def configure_callbacks(self):\n        return get_extra_checkpoint_callback() + get_extra_swa_callback() + [LearningRateMonitor(logging_interval='epoch')]\n\n\n@ex.command\ndef get_dynamic_norm(model, dyn_norm=False):\n    if not dyn_norm:\n        return None, None\n    raise RuntimeError('no dynamic norm supported yet.')\n\n\n@ex.command\ndef 
get_extra_checkpoint_callback(save_last_n=None):\n    if save_last_n is None:\n        return []\n    return [ModelCheckpoint(monitor=\"step\", verbose=True, save_top_k=save_last_n, mode='max')]\n\n\n@ex.command\ndef get_extra_swa_callback(swa=True, swa_epoch_start=50,\n                           swa_freq=5):\n    if not swa:\n        return []\n    print(\"\\n Using swa!\\n\")\n    if pytorch_lightning.__version__ <= \"1.5\":\n        from helpers.swa_legacy import StochasticWeightAveraging\n    else:\n        from helpers.swa_callback import StochasticWeightAveraging\n    return [StochasticWeightAveraging(swa_epoch_start=swa_epoch_start, swa_freq=swa_freq)]\n\n\n@ex.command\ndef main(_run, _config, _log, _rnd, _seed, watch_model=True):\n    global wandb_logger \n    wandb_logger = get_logger()\n    trainer = get_trainer(logger=wandb_logger)\n    train_loader = get_train_loader()\n    val_loader = get_validate_loader()\n\n    modul = M(ex)\n    if watch_model:\n        # change log frequency of gradients and parameters (100 steps by default)\n        wandb_logger.watch(modul, log_freq=1000, log=\"all\" )\n\n    if pytorch_lightning.__version__ <= \"1.5\":\n        trainer.fit(\n            modul,\n            train_dataloader=train_loader,\n            val_dataloaders=val_loader,\n        )\n    else:\n        trainer.fit(\n            modul,\n            train_dataloaders=train_loader,\n            val_dataloaders=val_loader,\n        )\n\n    return {\"done\": True}\n\n\n@ex.command\ndef model_speed_test(_run, _config, _log, _rnd, _seed, speed_test_batch_size=12):\n    '''\n    Test training speed of a model\n    @param _run:\n    @param _config:\n    @param _log:\n    @param _rnd:\n    @param _seed:\n    @param speed_test_batch_size: the batch size during the test\n    @return:\n    '''\n\n    modul = M(ex)\n    modul = modul.cuda()\n    batch_size = speed_test_batch_size\n    print(f\"\\nBATCH SIZE : {batch_size}\\n\")\n    test_length = 100\n    
print(f\"\\ntest_length : {test_length}\\n\")\n\n    x = torch.ones([batch_size, 1, 128, 998]).cuda()\n    target = torch.ones([batch_size, 527]).cuda()\n    # one passe\n    net = modul.net\n    # net(x)\n    scaler = torch.cuda.amp.GradScaler()\n    torch.backends.cudnn.benchmark = True\n    # net = torch.jit.trace(net,(x,))\n    net = torch.compile(net)\n    optimizer = torch.optim.SGD(net.parameters(), lr=0.001)\n\n    print(\"warmup\")\n    import time\n    torch.cuda.synchronize()\n    t1 = time.time()\n    for i in range(10):\n        with torch.cuda.amp.autocast():\n            y_hat, embed = net(x)\n            loss = F.binary_cross_entropy_with_logits(\n                y_hat, target, reduction=\"none\").mean()\n        scaler.scale(loss).backward()\n        scaler.step(optimizer)\n        scaler.update()\n    torch.cuda.synchronize()\n    t2 = time.time()\n    print('warmup done:', (t2 - t1))\n    torch.cuda.synchronize()\n    t1 = time.time()\n    print(\"testing speed\")\n\n    for i in range(test_length):\n        with torch.cuda.amp.autocast():\n            y_hat, embed = net(x)\n            loss = F.binary_cross_entropy_with_logits(\n                y_hat, target, reduction=\"none\").mean()\n        scaler.scale(loss).backward()\n        scaler.step(optimizer)\n        scaler.update()\n    torch.cuda.synchronize()\n    t2 = time.time()\n    print('test done:', (t2 - t1))\n    print(\"average speed: \", (test_length * batch_size) /\n          (t2 - t1), \" specs/second\")\n\n\n@ex.command\ndef evaluate_only(_run, _config, _log, _rnd, _seed):\n    # force overriding the config, not logged = not recommended\n    trainer = get_trainer(logger=get_logger())\n    val_loader = get_validate_loader()\n\n    modul = M(ex)\n    modul.val_dataloader = None\n    #trainer.val_dataloaders = None\n    print(f\"\\n\\nValidation len={len(val_loader)}\\n\")\n    res = trainer.validate(modul, dataloaders=val_loader)\n    print(\"\\n\\n Validtaion:\")\n    
print(res)\n\n\n@ex.command\ndef test_loaders():\n    '''\n    get one sample from each loader for debbuging\n    @return:\n    '''\n    for i, b in enumerate(ex.datasets.training.get_iter()):\n        print(b)\n        break\n\n    for i, b in enumerate(ex.datasets.test.get_iter()):\n        print(b)\n        break\n\n\ndef set_default_json_pickle(obj):\n    if isinstance(obj, set):\n        return list(obj)\n    raise TypeError\n\n\n@ex.command\ndef preload_mp3(all_y=CMD(\"/basedataset.preload_mp3\")):\n    '''\n    read the dataset sequentially, useful if you have a network cache\n    @param all_y: the dataset preload command\n    @return:\n    '''\n    print(all_y.shape)\n\n\ndef multiprocessing_run(rank, word_size):\n    print(\"rank \", rank, os.getpid())\n    print(\"word_size \", word_size)\n    os.environ['NODE_RANK'] = str(rank)\n    os.environ['CUDA_VISIBLE_DEVICES'] = os.environ['CUDA_VISIBLE_DEVICES'].split(\",\")[\n        rank]\n    argv = sys.argv\n    if rank != 0:\n        print(f\"Unobserved {os.getpid()} with rank {rank}\")\n        argv = argv + [\"-u\"]  # only rank 0 is observed\n    if \"with\" not in argv:\n        argv = argv + [\"with\"]\n\n    argv = argv + \\\n        [f\"trainer.num_nodes={word_size}\", f\"trainer.accelerator=ddp\"]\n    print(argv)\n\n    @ex.main\n    def default_command():\n        return main()\n\n    ex.run_commandline(argv)\n\n\nif __name__ == '__main__':\n    # set DDP=2 forks two processes to run on two GPUs\n    # the environment variable \"DDP\" define the number of processes to fork\n    # With two 2x 2080ti you can train the full model to .47 in around 24 hours\n    # you may need to set NCCL_P2P_DISABLE=1\n    word_size = os.environ.get(\"DDP\", None)\n    if word_size:\n        import random\n\n        word_size = int(word_size)\n        print(f\"\\n\\nDDP TRAINING WITH WORD_SIZE={word_size}\\n\\n\")\n        os.environ['MASTER_ADDR'] = '127.0.0.1'\n        # plz no collisions\n        
os.environ['MASTER_PORT'] = f\"{9999 + random.randint(0, 9999)}\"\n        os.environ['PL_IN_DDP_SUBPROCESS'] = '1'\n\n        for rank in range(word_size):\n            pid = os.fork()\n            if pid == 0:\n                print(\"Child Forked \")\n                multiprocessing_run(rank, word_size)\n                exit(0)\n\n        pid, exit_code = os.wait()\n        print(pid, exit_code)\n        exit(0)\n\nprint(\"__main__ is running pid\", os.getpid(), \"in module main: \", __name__)\n\n\n@ex.automain\ndef default_command():\n    return main()\n"
  },
  {
    "path": "ex_esc50.py",
    "content": "import os\nimport sys\n\nimport torch\nfrom pytorch_lightning.callbacks import ModelCheckpoint\nfrom sacred.config_helpers import DynamicIngredient, CMD\nfrom torch.nn import functional as F\nimport numpy as np\n\nfrom ba3l.experiment import Experiment\nfrom ba3l.module import Ba3lModule\n\nfrom torch.utils.data import DataLoader\n\nfrom config_updates import add_configs\nfrom helpers.mixup import my_mixup\nfrom helpers.models_size import count_non_zero_params\nfrom helpers.ramp import exp_warmup_linear_down, cosine_cycle\nfrom helpers.workersinit import worker_init_fn\nfrom sklearn import metrics\nfrom pytorch_lightning import Trainer as plTrainer\nfrom pytorch_lightning.loggers import WandbLogger\n\n\n\nex = Experiment(\"passt_esc50\")\n\n# Example call with all the default config:\n# python ex_esc50.py with  trainer.precision=16  -p -m mongodb_server:27000:audioset21_balanced -c \"ESC50 PaSST base\"\n# with 2 gpus:\n# DDP=2 python ex_esc50.py with  trainer.precision=16  -p -m mongodb_server:27000:audioset21_balanced -c \"ESC50 PaSST base\"\n\n# capture the config of the trainer with the prefix \"trainer\", this allows to use sacred to update PL trainer config\nget_trainer = ex.command(plTrainer, prefix=\"trainer\")\n# capture the WandbLogger and prefix it with \"wandb\", this allows to use sacred to update WandbLogger config from the command line\nget_logger = ex.command(WandbLogger, prefix=\"wandb\")\n\n\n# define datasets and loaders\nget_train_loader = ex.datasets.training.iter(DataLoader, static_args=dict(worker_init_fn=worker_init_fn), train=True, batch_size=12,\n                          num_workers=16, shuffle=None, dataset=CMD(\"/basedataset.get_training_set\"),\n                          )\n\nget_validate_loader = ex.datasets.test.iter(DataLoader, static_args=dict(worker_init_fn=worker_init_fn),\n                                            validate=True, batch_size=20, num_workers=16,\n                                            
dataset=CMD(\"/basedataset.get_test_set\"))\n\n\n@ex.config\ndef default_conf():\n    cmd = \" \".join(sys.argv)\n    saque_cmd = os.environ.get(\"SAQUE_CMD\", \"\").strip()\n    saque_id = os.environ.get(\"SAQUE_ID\", \"\").strip()\n    slurm_job_id = os.environ.get(\"SLURM_JOB_ID\", \"\").strip()\n    if os.environ.get(\"SLURM_ARRAY_JOB_ID\", False):\n        slurm_job_id = os.environ.get(\"SLURM_ARRAY_JOB_ID\", \"\").strip() + \"_\" + os.environ.get(\"SLURM_ARRAY_TASK_ID\",\n                                                                                               \"\").strip()\n    process_id = os.getpid()\n    models = {\n        \"net\": DynamicIngredient(\"models.passt.model_ing\", n_classes=50, s_patchout_t=10, s_patchout_f=3),\n        \"mel\": DynamicIngredient(\"models.preprocess.model_ing\",\n                                 instance_cmd=\"AugmentMelSTFT\",\n                                 n_mels=128, sr=32000, win_length=800, hopsize=320, n_fft=1024, freqm=48,\n                                 timem=80,\n                                 htk=False, fmin=0.0, fmax=None, norm=1, fmin_aug_range=10,\n                                 fmax_aug_range=2000)\n    }\n    basedataset = DynamicIngredient(\"esc50.dataset.dataset\")\n    trainer = dict(max_epochs=10, gpus=1, weights_summary='full', benchmark=True, num_sanity_val_steps=0,\n                   reload_dataloaders_every_epoch=True)\n    wandb = dict(project=\"passt_esc50\", log_model=True)\n    lr = 0.00001\n    use_mixup = True\n    mixup_alpha = 0.3\n\n\n# register extra possible configs\nadd_configs(ex)\n\n\n@ex.command\ndef get_scheduler_lambda(warm_up_len=5, ramp_down_start=50, ramp_down_len=50, last_lr_value=0.01,\n                         schedule_mode=\"exp_lin\"):\n    if schedule_mode == \"exp_lin\":\n        return exp_warmup_linear_down(warm_up_len, ramp_down_len, ramp_down_start, last_lr_value)\n    if schedule_mode == \"cos_cyc\":\n        return cosine_cycle(warm_up_len, 
ramp_down_start, last_lr_value)\n    raise RuntimeError(f\"schedule_mode={schedule_mode} Unknown for a lambda funtion.\")\n\n\n@ex.command\ndef get_lr_scheduler(optimizer, schedule_mode):\n    if schedule_mode in {\"exp_lin\", \"cos_cyc\"}:\n        return torch.optim.lr_scheduler.LambdaLR(optimizer, get_scheduler_lambda())\n    raise RuntimeError(f\"schedule_mode={schedule_mode} Unknown.\")\n\n\n@ex.command\ndef get_optimizer(params, lr, adamw=True, weight_decay=0.0001):\n    if adamw:\n        print(f\"\\nUsing adamw weight_decay={weight_decay}!\\n\")\n        return torch.optim.AdamW(params, lr=lr, weight_decay=weight_decay)\n    return torch.optim.Adam(params, lr=lr)\n\n\nclass M(Ba3lModule):\n    def __init__(self, experiment):\n        self.mel = None\n        self.da_net = None\n        super(M, self).__init__(experiment)\n\n        self.use_mixup = self.config.use_mixup or False\n        self.mixup_alpha = self.config.mixup_alpha\n\n        desc, sum_params, sum_non_zero = count_non_zero_params(self.net)\n        self.experiment.info[\"start_sum_params\"] = sum_params\n        self.experiment.info[\"start_sum_params_non_zero\"] = sum_non_zero\n\n        # in case we need embedings for the DA\n        self.net.return_embed = True\n        self.dyn_norm = self.config.dyn_norm\n        self.do_swa = False\n\n        self.distributed_mode = self.config.trainer.num_nodes > 1\n\n    def forward(self, x):\n        return self.net(x)\n\n    def mel_forward(self, x):\n        old_shape = x.size()\n        x = x.reshape(-1, old_shape[2])\n        x = self.mel(x)\n        x = x.reshape(old_shape[0], old_shape[1], x.shape[1], x.shape[2])\n        if self.dyn_norm:\n            if not hasattr(self, \"tr_m\") or not hasattr(self, \"tr_std\"):\n                tr_m, tr_std = get_dynamic_norm(self)\n                self.register_buffer('tr_m', tr_m)\n                self.register_buffer('tr_std', tr_std)\n            x = (x - self.tr_m) / self.tr_std\n        return x\n\n  
  def training_step(self, batch, batch_idx):\n        # REQUIRED\n        x, f, y = batch\n        if self.mel:\n            x = self.mel_forward(x)\n\n        orig_x = x\n        batch_size = len(y)\n\n        rn_indices, lam = None, None\n        if self.use_mixup:\n            rn_indices, lam = my_mixup(batch_size, self.mixup_alpha)\n            lam = lam.to(x.device)\n            x = x * lam.reshape(batch_size, 1, 1, 1) + x[rn_indices] * (1. - lam.reshape(batch_size, 1, 1, 1))\n\n        y_hat, embed = self.forward(x)\n\n        if self.use_mixup:\n            # y_mix = y * lam.reshape(batch_size, 1) + y[rn_indices] * (1. - lam.reshape(batch_size, 1))\n            samples_loss = (F.cross_entropy(y_hat, y, reduction=\"none\") * lam.reshape(batch_size) +\n                            F.cross_entropy(y_hat, y[rn_indices], reduction=\"none\") * (1. - lam.reshape(batch_size)))\n            loss = samples_loss.mean()\n            loss = samples_loss.mean()\n            samples_loss = samples_loss.detach()\n        else:\n            samples_loss = F.cross_entropy(y_hat, y, reduction=\"none\")\n            loss = samples_loss.mean()\n            samples_loss = samples_loss.detach()\n\n        results = {\"loss\": loss, }\n\n        return results\n\n    def training_epoch_end(self, outputs):\n        avg_loss = torch.stack([x['loss'] for x in outputs]).mean()\n\n        logs = {'train.loss': avg_loss, 'step': self.current_epoch}\n\n        self.log_dict(logs, sync_dist=True)\n\n    def predict(self, batch, batch_idx: int, dataloader_idx: int = None):\n        x, f, y = batch\n        if self.mel:\n            x = self.mel_forward(x)\n\n        y_hat, _ = self.forward(x)\n        return f, y_hat\n\n    def validation_step(self, batch, batch_idx):\n        x, f, y = batch\n        if self.mel:\n            x = self.mel_forward(x)\n\n        results = {}\n        model_name = [(\"\", self.net)]\n        if self.do_swa:\n            model_name = model_name + [(\"swa_\", 
self.net_swa)]\n        for net_name, net in model_name:\n            y_hat, _ = net(x)\n            samples_loss = F.cross_entropy(y_hat, y, reduction=\"none\")\n            loss = samples_loss.mean()\n            _, preds = torch.max(y_hat, dim=1)\n            n_correct_pred_per_sample = (preds == y)\n            n_correct_pred = n_correct_pred_per_sample.sum()\n            # self.log(\"validation.loss\", loss, prog_bar=True, on_epoch=True, on_step=False)\n            results = {**results, net_name + \"val_loss\": loss,\n                       net_name + \"n_correct_pred\": torch.as_tensor(n_correct_pred), net_name + \"n_pred\":torch.as_tensor( len(y)) }\n        results = {k: v.cpu() for k, v in results.items()}\n        return results\n\n    def validation_epoch_end(self, outputs):\n        model_name = [(\"\", self.net)]\n        if self.do_swa:\n            model_name = model_name + [(\"swa_\", self.net_swa)]\n        for net_name, net in model_name:\n            avg_loss = torch.stack([x[net_name + 'val_loss'] for x in outputs]).mean()\n            val_acc = sum([x['n_correct_pred'] for x in outputs]) * 1.0 / sum(x['n_pred'] for x in outputs)\n            logs = {net_name + 'val.loss': torch.as_tensor(avg_loss).cuda(),\n                    net_name + 'acc': torch.as_tensor(val_acc).cuda(),\n                    'step': torch.as_tensor(self.current_epoch).cuda()}\n            self.log_dict(logs, sync_dist=True)\n\n    def configure_optimizers(self):\n        # REQUIRED\n        # can return multiple optimizers and learning_rate schedulers\n        # (LBFGS it is automatically supported, no need for closure function)\n        optimizer = get_optimizer(self.parameters())\n        # torch.optim.Adam(self.parameters(), lr=self.config.lr)\n        return {\n            'optimizer': optimizer,\n            'lr_scheduler': get_lr_scheduler(optimizer)\n        }\n\n    def configure_callbacks(self):\n        return get_extra_checkpoint_callback() + 
get_extra_swa_callback()\n\n\n@ex.command\ndef get_dynamic_norm(model, dyn_norm=False):\n    if not dyn_norm:\n        return None, None\n    raise RuntimeError('no dynamic norm supported yet.')\n\n\n@ex.command\ndef get_extra_checkpoint_callback(save_last_n=None):\n    if save_last_n is None:\n        return []\n    return [ModelCheckpoint(monitor=\"step\", verbose=True, save_top_k=save_last_n, mode='max')]\n\n\n@ex.command\ndef get_extra_swa_callback(swa=True, swa_epoch_start=2,\n                           swa_freq=1):\n    if not swa:\n        return []\n    print(\"\\n Using swa!\\n\")\n    from helpers.swa_callback import StochasticWeightAveraging\n    return [StochasticWeightAveraging(swa_epoch_start=swa_epoch_start, swa_freq=swa_freq)]\n\n\n@ex.command\ndef main(_run, _config, _log, _rnd, _seed):\n    trainer = get_trainer()\n    train_loader = get_train_loader()\n    val_loader = get_validate_loader()\n\n    modul = M(ex)\n\n    trainer.fit(\n        modul,\n        train_dataloaders=train_loader,\n        val_dataloaders=val_loader,\n    )\n\n    return {\"done\": True}\n\n\n@ex.command\ndef model_speed_test(_run, _config, _log, _rnd, _seed, speed_test_batch_size=100):\n    '''\n    Test training speed of a model\n    @param _run:\n    @param _config:\n    @param _log:\n    @param _rnd:\n    @param _seed:\n    @param speed_test_batch_size: the batch size during the test\n    @return:\n    '''\n\n    modul = M(ex)\n    modul = modul.cuda()\n    batch_size = speed_test_batch_size\n    print(f\"\\nBATCH SIZE : {batch_size}\\n\")\n    test_length = 100\n    print(f\"\\ntest_length : {test_length}\\n\")\n\n    x = torch.ones([batch_size, 1, 128, 998]).cuda()\n    target = torch.ones([batch_size, 527]).cuda()\n    # one passe\n    net = modul.net\n    # net(x)\n    scaler = torch.cuda.amp.GradScaler()\n    torch.backends.cudnn.benchmark = True\n    # net = torch.jit.trace(net,(x,))\n    optimizer = torch.optim.SGD(net.parameters(), lr=0.001)\n\n    
print(\"warmup\")\n    import time\n    torch.cuda.synchronize()\n    t1 = time.time()\n    for i in range(10):\n        with  torch.cuda.amp.autocast():\n            y_hat, embed = net(x)\n            loss = F.binary_cross_entropy_with_logits(y_hat, target, reduction=\"none\").mean()\n        scaler.scale(loss).backward()\n        scaler.step(optimizer)\n        scaler.update()\n    torch.cuda.synchronize()\n    t2 = time.time()\n    print('warmup done:', (t2 - t1))\n    torch.cuda.synchronize()\n    t1 = time.time()\n    print(\"testing speed\")\n\n    for i in range(test_length):\n        with  torch.cuda.amp.autocast():\n            y_hat, embed = net(x)\n            loss = F.binary_cross_entropy_with_logits(y_hat, target, reduction=\"none\").mean()\n        scaler.scale(loss).backward()\n        scaler.step(optimizer)\n        scaler.update()\n    torch.cuda.synchronize()\n    t2 = time.time()\n    print('test done:', (t2 - t1))\n    print(\"average speed: \", (test_length * batch_size) / (t2 - t1), \" specs/second\")\n\n\n@ex.command\ndef evaluate_only(_run, _config, _log, _rnd, _seed):\n    # force overriding the config, not logged = not recommended\n    trainer = get_trainer()\n    train_loader = get_train_loader()\n    val_loader = get_validate_loader()\n\n    modul = M(ex)\n    modul.val_dataloader = None\n    #trainer.val_dataloaders = None\n    print(f\"\\n\\nValidation len={len(val_loader)}\\n\")\n    res = trainer.validate(modul, dataloaders=val_loader)\n    print(\"\\n\\n Validtaion:\")\n    print(res)\n\n\n@ex.command\ndef test_loaders():\n    '''\n    get one sample from each loader for debbuging\n    @return:\n    '''\n    for i, b in enumerate(ex.datasets.training.get_iter()):\n        print(b)\n        break\n\n    for i, b in enumerate(ex.datasets.test.get_iter()):\n        print(b)\n        break\n\n\ndef set_default_json_pickle(obj):\n    if isinstance(obj, set):\n        return list(obj)\n    raise TypeError\n\n\n\ndef 
multiprocessing_run(rank, word_size):\n    print(\"rank \", rank, os.getpid())\n    print(\"word_size \", word_size)\n    os.environ['NODE_RANK'] = str(rank)\n    os.environ['CUDA_VISIBLE_DEVICES'] = os.environ['CUDA_VISIBLE_DEVICES'].split(\",\")[rank]\n    argv = sys.argv\n    if rank != 0:\n        print(f\"Unobserved {os.getpid()} with rank {rank}\")\n        argv = argv + [\"-u\"]  # only rank 0 is observed\n    if \"with\" not in argv:\n        argv = argv + [\"with\"]\n\n    argv = argv + [f\"trainer.num_nodes={word_size}\", f\"trainer.accelerator=ddp\"]\n    print(argv)\n\n    @ex.main\n    def default_command():\n        return main()\n\n    ex.run_commandline(argv)\n\n\nif __name__ == '__main__':\n    # set DDP=2 forks two processes to run on two GPUs\n    # the environment variable \"DDP\" define the number of processes to fork\n    # With two 2x 2080ti you can train the full model to .47 in around 24 hours\n    # you may need to set NCCL_P2P_DISABLE=1\n    word_size = os.environ.get(\"DDP\", None)\n    if word_size:\n        import random\n\n        word_size = int(word_size)\n        print(f\"\\n\\nDDP TRAINING WITH WORD_SIZE={word_size}\\n\\n\")\n        os.environ['MASTER_ADDR'] = '127.0.0.1'\n        os.environ['MASTER_PORT'] = f\"{9999 + random.randint(0, 9999)}\"  # plz no collisions\n        os.environ['PL_IN_DDP_SUBPROCESS'] = '1'\n\n        for rank in range(word_size):\n            pid = os.fork()\n            if pid == 0:\n                print(\"Child Forked \")\n                multiprocessing_run(rank, word_size)\n                exit(0)\n\n        pid, exit_code = os.wait()\n        print(pid, exit_code)\n        exit(0)\n\nprint(\"__main__ is running pid\", os.getpid(), \"in module main: \", __name__)\n\n\n@ex.automain\ndef default_command():\n    return main()\n"
  },
  {
    "path": "ex_fsd50k.py",
    "content": "import os\nimport sys\n\nimport torch\nfrom pytorch_lightning.callbacks import ModelCheckpoint\nfrom sacred.config_helpers import DynamicIngredient, CMD\nfrom torch.nn import functional as F\nimport numpy as np\n\nfrom ba3l.experiment import Experiment\nfrom ba3l.module import Ba3lModule\n\nfrom torch.utils.data import DataLoader\n\nfrom config_updates import add_configs\nfrom helpers.mixup import my_mixup\nfrom helpers.models_size import count_non_zero_params\nfrom helpers.ramp import exp_warmup_linear_down, cosine_cycle\nfrom helpers.workersinit import worker_init_fn\nfrom sklearn import metrics\nfrom pytorch_lightning import Trainer as plTrainer\nfrom pytorch_lightning.loggers import WandbLogger\n\n\nex = Experiment(\"fsd50k\")\n\n# capture the config of the trainer with the prefix \"trainer\", this allows to use sacred to update PL trainer config\nget_trainer = ex.command(plTrainer, prefix=\"trainer\")\n# capture the WandbLogger and prefix it with \"wandb\", this allows to use sacred to update WandbLogger config from the command line\nget_logger = ex.command(WandbLogger, prefix=\"wandb\")\n\n\n\n# Example call with all the default config:\n# python ex_audioset.py with trainer.precision=16 models.net.arch=passt_deit_bd_p16_384 -p -m mongodb_server:27000:audioset21_balanced -c \"PaSST base\"\n# with 2 gpus:\n# DDP=2 python ex_audioset.py with trainer.precision=16  models.net.arch=passt_deit_bd_p16_384 -p -m mongodb_server:27000:audioset21_balanced -c \"PaSST base 2 GPU\"\n\n# define datasets and loaders\nget_train_loader = ex.datasets.training.iter(DataLoader, static_args=dict(worker_init_fn=worker_init_fn), train=True, batch_size=12,\n                          num_workers=16, shuffle=True, dataset=CMD(\"/basedataset.get_training_set\"),\n                          )\n\nget_validate_loader = ex.datasets.valid.iter(DataLoader, static_args=dict(worker_init_fn=worker_init_fn),\n                                             validate=True, batch_size=10, 
num_workers=16,\n                                             dataset=CMD(\"/basedataset.get_valid_set\"))\n\nget_eval_loader = ex.datasets.eval.iter(DataLoader, static_args=dict(worker_init_fn=worker_init_fn),\n                                        validate=True, batch_size=10, num_workers=16,\n                                        dataset=CMD(\"/basedataset.get_eval_set\"))\n\n\n@ex.named_config\ndef variable_eval():\n    basedataset = dict(variable_eval=True)\n    datasets = dict(valid=dict(batch_size=1), eval=dict(batch_size=1))\n\n\n@ex.config\ndef default_conf():\n    cmd = \" \".join(sys.argv)  # command line arguments\n    saque_cmd = os.environ.get(\"SAQUE_CMD\", \"\").strip()\n    saque_id = os.environ.get(\"SAQUE_ID\", \"\").strip()\n    slurm_job_id = os.environ.get(\"SLURM_JOB_ID\", \"\").strip()\n    if os.environ.get(\"SLURM_ARRAY_JOB_ID\", False):\n        slurm_job_id = os.environ.get(\"SLURM_ARRAY_JOB_ID\", \"\").strip() + \"_\" + os.environ.get(\"SLURM_ARRAY_TASK_ID\",\n                                                                                               \"\").strip()\n    process_id = os.getpid()\n    models = {\n        \"net\": DynamicIngredient(\"models.passt.model_ing\",\n                                 n_classes=200, s_patchout_t=10, s_patchout_f=4),  # network config\n        \"mel\": DynamicIngredient(\"models.preprocess.model_ing\",\n                                 instance_cmd=\"AugmentMelSTFT\",\n                                 n_mels=128, sr=32000, win_length=800, hopsize=320, n_fft=1024, freqm=0,\n                                 timem=0,\n                                 htk=False, fmin=0.0, fmax=None, norm=1, fmin_aug_range=10,\n                                 fmax_aug_range=2000)\n    }\n    # set the default name for wandb logger\n    wandb = dict(project=\"passt_fsd50k\", log_model=True)\n    basedataset = DynamicIngredient(\"fsd50k.dataset.dataset\", wavmix=1)\n    trainer = dict(max_epochs=50, gpus=1, 
weights_summary='full', benchmark=True, num_sanity_val_steps=0,\n                   reload_dataloaders_every_epoch=True)\n    lr = 0.00001  # learning rate\n    use_mixup = True\n    mixup_alpha = 0.3\n\n\n# register extra possible configs\nadd_configs(ex)\n\n\n@ex.command\ndef get_scheduler_lambda(warm_up_len=5, ramp_down_start=10, ramp_down_len=10, last_lr_value=0.01,\n                         schedule_mode=\"exp_lin\"):\n    if schedule_mode == \"exp_lin\":\n        return exp_warmup_linear_down(warm_up_len, ramp_down_len, ramp_down_start, last_lr_value)\n    if schedule_mode == \"cos_cyc\":\n        return cosine_cycle(warm_up_len, ramp_down_start, last_lr_value)\n    raise RuntimeError(f\"schedule_mode={schedule_mode} Unknown for a lambda funtion.\")\n\n\n@ex.command\ndef get_lr_scheduler(optimizer, schedule_mode):\n    if schedule_mode in {\"exp_lin\", \"cos_cyc\"}:\n        return torch.optim.lr_scheduler.LambdaLR(optimizer, get_scheduler_lambda())\n    raise RuntimeError(f\"schedule_mode={schedule_mode} Unknown.\")\n\n\n@ex.command\ndef get_optimizer(params, lr, adamw=True, weight_decay=0.0001):\n    if adamw:\n        print(f\"\\nUsing adamw weight_decay={weight_decay}!\\n\")\n        return torch.optim.AdamW(params, lr=lr, weight_decay=weight_decay)\n    return torch.optim.Adam(params, lr=lr)\n\n\nclass M(Ba3lModule):\n    def __init__(self, experiment):\n        self.mel = None\n        self.da_net = None\n        super(M, self).__init__(experiment)\n\n        self.use_mixup = self.config.use_mixup or False\n        self.mixup_alpha = self.config.mixup_alpha\n\n        desc, sum_params, sum_non_zero = count_non_zero_params(self.net)\n        self.experiment.info[\"start_sum_params\"] = sum_params\n        self.experiment.info[\"start_sum_params_non_zero\"] = sum_non_zero\n\n        # in case we need embedings for the DA\n        self.net.return_embed = True\n        self.dyn_norm = self.config.dyn_norm\n        self.do_swa = False\n        
self.valid_names = [\"valid\", \"eval\"]\n        self.distributed_mode = self.config.trainer.num_nodes > 1\n\n    def forward(self, x):\n        return self.net(x)\n\n    def mel_forward(self, x):\n        old_shape = x.size()\n        x = x.reshape(-1, old_shape[2])\n        x = self.mel(x)\n        x = x.reshape(old_shape[0], old_shape[1], x.shape[1], x.shape[2])\n        if self.dyn_norm:\n            if not hasattr(self, \"tr_m\") or not hasattr(self, \"tr_std\"):\n                tr_m, tr_std = get_dynamic_norm(self)\n                self.register_buffer('tr_m', tr_m)\n                self.register_buffer('tr_std', tr_std)\n            x = (x - self.tr_m) / self.tr_std\n        return x\n\n    def training_step(self, batch, batch_idx):\n        # REQUIRED\n        x, f, y = batch\n        if self.mel:\n            x = self.mel_forward(x)\n\n        orig_x = x\n        batch_size = len(y)\n\n        rn_indices, lam = None, None\n        if self.use_mixup:\n            rn_indices, lam = my_mixup(batch_size, self.mixup_alpha)\n            lam = lam.to(x.device)\n            x = x * lam.reshape(batch_size, 1, 1, 1) + x[rn_indices] * (1. - lam.reshape(batch_size, 1, 1, 1))\n\n        y_hat, embed = self.forward(x)\n\n        if self.use_mixup:\n            y_mix = y * lam.reshape(batch_size, 1) + y[rn_indices] * (1. 
- lam.reshape(batch_size, 1))\n            samples_loss = F.binary_cross_entropy_with_logits(\n                y_hat, y_mix, reduction=\"none\")\n            loss = samples_loss.mean()\n            samples_loss = samples_loss.detach()\n        else:\n            samples_loss = F.binary_cross_entropy_with_logits(y_hat, y, reduction=\"none\")\n            loss = samples_loss.mean()\n            samples_loss = samples_loss.detach()\n\n        results = {\"loss\": loss, }\n\n        return results\n\n    def training_epoch_end(self, outputs):\n        avg_loss = torch.stack([x['loss'] for x in outputs]).mean()\n\n        logs = {'train.loss': avg_loss, 'step': self.current_epoch}\n\n        self.log_dict(logs, sync_dist=True)\n\n    def predict(self, batch, batch_idx: int, dataloader_idx: int = None):\n        x, f, y = batch\n        if self.mel:\n            x = self.mel_forward(x)\n\n        y_hat, _ = self.forward(x)\n        return f, y_hat\n\n    def validation_step(self, batch, batch_idx, dataloader_idx):\n        x, f, y = batch\n        if self.mel:\n            x = self.mel_forward(x)\n\n        results = {}\n        model_name = [(\"\", self.net)]\n        if self.do_swa:\n            model_name = model_name + [(\"swa_\", self.net_swa)]\n        for net_name, net in model_name:\n            y_hat, _ = net(x)\n            samples_loss = F.binary_cross_entropy_with_logits(y_hat, y)\n            loss = samples_loss.mean()\n            out = torch.sigmoid(y_hat.detach())\n            # self.log(\"validation.loss\", loss, prog_bar=True, on_epoch=True, on_step=False)\n            results = {**results, net_name + \"val_loss\": loss, net_name + \"out\": out, net_name + \"target\": y.detach()}\n        results = {k: v.cpu() for k, v in results.items()}\n        return results\n\n    def validation_epoch_end(self, outputs):\n        for idx, one_outputs in enumerate(outputs):\n            set_name = self.valid_names[idx] + \"_\"\n            model_name = [(\"\", 
self.net)]\n            if self.do_swa:\n                model_name = model_name + [(\"swa_\", self.net_swa)]\n            for net_name, net in model_name:\n                avg_loss = torch.stack([x[net_name + 'val_loss'] for x in one_outputs]).mean()\n                out = torch.cat([x[net_name + 'out'] for x in one_outputs], dim=0)\n                target = torch.cat([x[net_name + 'target'] for x in one_outputs], dim=0)\n                try:\n                    average_precision = metrics.average_precision_score(\n                        target.float().numpy(), out.float().numpy(), average=None)\n                except ValueError:\n                    average_precision = np.array([np.nan] * 200)\n                try:\n                    roc = metrics.roc_auc_score(target.numpy(), out.numpy(), average=None)\n                except ValueError:\n                    roc = np.array([np.nan] * 200)\n                logs = {set_name + net_name + 'val.loss': torch.as_tensor(avg_loss).cuda(),\n                        set_name + net_name + 'ap': torch.as_tensor(average_precision.mean()).cuda(),\n                        set_name + net_name + 'roc': torch.as_tensor(roc.mean()).cuda(),\n                        'step': torch.as_tensor(self.current_epoch).cuda()}\n                # torch.save(average_precision, f\"ap_perclass_{average_precision.mean()}.pt\")\n                # print(average_precision)\n                self.log_dict(logs, sync_dist=True)\n                if self.distributed_mode:\n                    allout = self.all_gather(out)\n                    alltarget = self.all_gather(target)\n\n                    average_precision = metrics.average_precision_score(\n                        alltarget.reshape(-1, alltarget.shape[-1]).cpu().numpy(),\n                        allout.reshape(-1, allout.shape[-1]).cpu().numpy(), average=None)\n                    if self.trainer.is_global_zero:\n                        logs = {set_name + net_name + \"allap\": 
torch.as_tensor(average_precision.mean()).cuda(),\n                                'step': torch.as_tensor(self.current_epoch).cuda()}\n                        self.log_dict(logs, sync_dist=False)\n                else:\n                    self.log_dict(\n                        {set_name + net_name + \"allap\": logs[set_name + net_name + 'ap'], 'step': logs['step']},\n                        sync_dist=True)\n\n    def configure_optimizers(self):\n        # REQUIRED\n        # can return multiple optimizers and learning_rate schedulers\n        # (LBFGS it is automatically supported, no need for closure function)\n        optimizer = get_optimizer(self.parameters())\n        # torch.optim.Adam(self.parameters(), lr=self.config.lr)\n        return {\n            'optimizer': optimizer,\n            'lr_scheduler': get_lr_scheduler(optimizer)\n        }\n\n    def configure_callbacks(self):\n        return get_extra_checkpoint_callback() + get_extra_swa_callback()\n\n\n@ex.command\ndef get_dynamic_norm(model, dyn_norm=False):\n    if not dyn_norm:\n        return None, None\n    raise RuntimeError('no dynamic norm supported yet.')\n\n\nmodel_checkpoint_callback = None\n\n\n@ex.command\ndef get_extra_checkpoint_callback(save_best=None):\n    if save_best is None:\n        return []\n    global model_checkpoint_callback\n    model_checkpoint_callback = ModelCheckpoint(monitor=\"allap\", verbose=True, save_top_k=save_best, mode='max',\n                                                every_n_val_epochs=1, every_n_train_steps=0)\n    return [model_checkpoint_callback]\n\n\n@ex.command\ndef get_extra_swa_callback(swa=True, swa_epoch_start=10,\n                           swa_freq=3):\n    if not swa:\n        return []\n    print(\"\\n Using swa!\\n\")\n    from helpers.swa_callback import StochasticWeightAveraging\n    return [StochasticWeightAveraging(swa_epoch_start=swa_epoch_start, swa_freq=swa_freq)]\n\n\n@ex.command\ndef main(_run, _config, _log, _rnd, _seed):\n    
trainer = get_trainer(logger=get_logger())\n    train_loader = get_train_loader()\n    val_loader = get_validate_loader()\n    eval_loader = get_eval_loader()\n\n    # eval_loader = get_eval_loader()\n\n    modul = M(ex)\n\n    trainer.fit(\n        modul,\n        train_dataloader=train_loader,\n        val_dataloaders=[val_loader, eval_loader],\n    )\n    ## evaluate best model on eval set\n    #trainer.val_dataloaders = None\n    modul.val_dataloaders = None\n\n    return {\"done\": True}\n\n\n@ex.command\ndef model_speed_test(_run, _config, _log, _rnd, _seed, speed_test_batch_size=100):\n    '''\n    Test training speed of a model\n    @param _run:\n    @param _config:\n    @param _log:\n    @param _rnd:\n    @param _seed:\n    @param speed_test_batch_size: the batch size during the test\n    @return:\n    '''\n\n    modul = M(ex)\n    modul = modul.cuda()\n    batch_size = speed_test_batch_size\n    print(f\"\\nBATCH SIZE : {batch_size}\\n\")\n    test_length = 100\n    print(f\"\\ntest_length : {test_length}\\n\")\n\n    x = torch.ones([batch_size, 1, 128, 998]).cuda()\n    target = torch.ones([batch_size, 527]).cuda()\n    # one passe\n    net = modul.net\n    # net(x)\n    scaler = torch.cuda.amp.GradScaler()\n    torch.backends.cudnn.benchmark = True\n    # net = torch.jit.trace(net,(x,))\n    optimizer = torch.optim.SGD(net.parameters(), lr=0.001)\n\n    print(\"warmup\")\n    import time\n    torch.cuda.synchronize()\n    t1 = time.time()\n    for i in range(10):\n        with  torch.cuda.amp.autocast():\n            y_hat, embed = net(x)\n            loss = F.binary_cross_entropy_with_logits(y_hat, target, reduction=\"none\").mean()\n        scaler.scale(loss).backward()\n        scaler.step(optimizer)\n        scaler.update()\n    torch.cuda.synchronize()\n    t2 = time.time()\n    print('warmup done:', (t2 - t1))\n    torch.cuda.synchronize()\n    t1 = time.time()\n    print(\"testing speed\")\n\n    for i in range(test_length):\n        with  
torch.cuda.amp.autocast():\n            y_hat, embed = net(x)\n            loss = F.binary_cross_entropy_with_logits(y_hat, target, reduction=\"none\").mean()\n        scaler.scale(loss).backward()\n        scaler.step(optimizer)\n        scaler.update()\n    torch.cuda.synchronize()\n    t2 = time.time()\n    print('test done:', (t2 - t1))\n    print(\"average speed: \", (test_length * batch_size) / (t2 - t1), \" specs/second\")\n\n\n@ex.command\ndef evaluate_only(_run, _config, _log, _rnd, _seed):\n    # force overriding the config, not logged = not recommended\n    trainer = get_trainer()\n    train_loader = get_train_loader()\n    val_loader = get_validate_loader()\n    modul = M(ex)\n    modul.val_dataloader = None\n    #trainer.val_dataloaders = None\n    print(f\"\\n\\nValidation len={len(val_loader)}\\n\")\n    res = trainer.validate(modul, dataloaders=val_loader)\n    print(\"\\n\\n Validtaion:\")\n    print(res)\n\n\n@ex.command\ndef test_loaders():\n    '''\n    get one sample from each loader for debbuging\n    @return:\n    '''\n    for i, b in enumerate(ex.datasets.training.get_iter()):\n        print(b)\n        break\n\n    for i, b in enumerate(ex.datasets.test.get_iter()):\n        print(b)\n        break\n\n\ndef set_default_json_pickle(obj):\n    if isinstance(obj, set):\n        return list(obj)\n    raise TypeError\n\n\n@ex.command\ndef preload_mp3(all_y=CMD(\"/basedataset.preload_mp3\")):\n    '''\n    read the dataset sequentially, useful if you have a network cache\n    @param all_y: the dataset preload command\n    @return:\n    '''\n    print(all_y.shape)\n\n\ndef multiprocessing_run(rank, word_size):\n    print(\"rank \", rank, os.getpid())\n    print(\"word_size \", word_size)\n    os.environ['NODE_RANK'] = str(rank)\n    os.environ['CUDA_VISIBLE_DEVICES'] = os.environ['CUDA_VISIBLE_DEVICES'].split(\",\")[rank]\n    argv = sys.argv\n    if rank != 0:\n        print(f\"Unobserved {os.getpid()} with rank {rank}\")\n        argv = argv + 
[\"-u\"]  # only rank 0 is observed\n    if \"with\" not in argv:\n        argv = argv + [\"with\"]\n\n    argv = argv + [f\"trainer.num_nodes={word_size}\", f\"trainer.accelerator=ddp\"]\n    print(argv)\n\n    @ex.main\n    def default_command():\n        return main()\n\n    ex.run_commandline(argv)\n\n\nif __name__ == '__main__':\n    # set DDP=2 forks two processes to run on two GPUs\n    # the environment variable \"DDP\" define the number of processes to fork\n    # With two 2x 2080ti you can train the full model to .47 in around 24 hours\n    # you may need to set NCCL_P2P_DISABLE=1\n    word_size = os.environ.get(\"DDP\", None)\n    if word_size:\n        import random\n\n        word_size = int(word_size)\n        print(f\"\\n\\nDDP TRAINING WITH WORD_SIZE={word_size}\\n\\n\")\n        os.environ['MASTER_ADDR'] = '127.0.0.1'\n        os.environ['MASTER_PORT'] = f\"{9999 + random.randint(0, 9999)}\"  # plz no collisions\n        os.environ['PL_IN_DDP_SUBPROCESS'] = '1'\n\n        for rank in range(word_size):\n            pid = os.fork()\n            if pid == 0:\n                print(\"Child Forked \")\n                multiprocessing_run(rank, word_size)\n                exit(0)\n\n        pid, exit_code = os.wait()\n        print(pid, exit_code)\n        exit(0)\n\nprint(\"__main__ is running pid\", os.getpid(), \"in module main: \", __name__)\n\n\n@ex.automain\ndef default_command():\n    return main()\n"
  },
  {
    "path": "ex_openmic.py",
    "content": "import os\nimport sys\n\nimport torch\nfrom pytorch_lightning.callbacks import ModelCheckpoint\nfrom sacred.config_helpers import DynamicIngredient, CMD\nfrom torch.nn import functional as F\nimport numpy as np\n\nfrom ba3l.experiment import Experiment\nfrom ba3l.module import Ba3lModule\n\nfrom torch.utils.data import DataLoader\n\nfrom config_updates import add_configs\nfrom helpers.mixup import my_mixup\nfrom helpers.models_size import count_non_zero_params\nfrom helpers.ramp import exp_warmup_linear_down, cosine_cycle\nfrom helpers.workersinit import worker_init_fn\nfrom sklearn import metrics\nfrom pytorch_lightning import Trainer as plTrainer\nfrom pytorch_lightning.loggers import WandbLogger\n\n\n\n\n\n\n\nex = Experiment(\"openmic\")\n\n# capture the config of the trainer with the prefix \"trainer\", this allows to use sacred to update PL trainer config\nget_trainer = ex.command(plTrainer, prefix=\"trainer\")\n# capture the WandbLogger and prefix it with \"wandb\", this allows to use sacred to update WandbLogger config from the command line\nget_logger = ex.command(WandbLogger, prefix=\"wandb\")\n\n\n# Example call with all the default config:\n# python ex_openmic.py with  trainer.precision=16  -p -m mongodb_server:27000:audioset21_balanced -c \"OpenMIC PaSST base\"\n# with 2 gpus:\n# DDP=2 python ex_openmic.py with  trainer.precision=16  -p -m mongodb_server:27000:audioset21_balanced -c \"OpenMIC PaSST base\"\n\n# define datasets and loaders\nget_train_loader = ex.datasets.training.iter(DataLoader, static_args=dict(worker_init_fn=worker_init_fn), train=True, batch_size=6,\n                          num_workers=16, shuffle=None, dataset=CMD(\"/basedataset.get_training_set\"),\n                          )\n\nget_validate_loader = ex.datasets.test.iter(DataLoader, static_args=dict(worker_init_fn=worker_init_fn),\n                                            validate=True, batch_size=20, num_workers=16,\n                                          
  dataset=CMD(\"/basedataset.get_test_set\"))\n\n\n@ex.config\ndef default_conf():\n    cmd = \" \".join(sys.argv)\n    saque_cmd = os.environ.get(\"SAQUE_CMD\", \"\").strip()\n    saque_id = os.environ.get(\"SAQUE_ID\", \"\").strip()\n    slurm_job_id = os.environ.get(\"SLURM_JOB_ID\", \"\").strip()\n    if os.environ.get(\"SLURM_ARRAY_JOB_ID\", False):\n        slurm_job_id = os.environ.get(\"SLURM_ARRAY_JOB_ID\", \"\").strip() + \"_\" + os.environ.get(\"SLURM_ARRAY_TASK_ID\",\n                                                                                               \"\").strip()\n    process_id = os.getpid()\n    models = {\n        \"net\": DynamicIngredient(\"models.passt.model_ing\", n_classes=20, s_patchout_t=40, s_patchout_f=4),\n        \"mel\": DynamicIngredient(\"models.preprocess.model_ing\",\n                                 instance_cmd=\"AugmentMelSTFT\",\n                                 n_mels=128, sr=32000, win_length=800, hopsize=320, n_fft=1024, freqm=48,\n                                 timem=192,\n                                 htk=False, fmin=0.0, fmax=None, norm=1, fmin_aug_range=10,\n                                 fmax_aug_range=2000)\n    }\n    wandb = dict(project=\"passt_openmic\", log_model=True)\n    basedataset = DynamicIngredient(\"openmic.dataset.dataset\", wavmix=1)\n    # set the default for the trainer\n    trainer = dict(max_epochs=10, gpus=1, weights_summary='full', benchmark=True, num_sanity_val_steps=0,\n                   reload_dataloaders_every_epoch=True)\n    lr = 0.00001\n    use_mixup = True\n    mixup_alpha = 0.3\n\n\n\n# register extra possible configs\nadd_configs(ex)\n\n\n@ex.command\ndef get_scheduler_lambda(warm_up_len=5, ramp_down_start=50, ramp_down_len=50, last_lr_value=0.01,\n                         schedule_mode=\"exp_lin\"):\n    if schedule_mode == \"exp_lin\":\n        return exp_warmup_linear_down(warm_up_len, ramp_down_len, ramp_down_start, last_lr_value)\n    if schedule_mode == 
\"cos_cyc\":\n        return cosine_cycle(warm_up_len, ramp_down_start, last_lr_value)\n    raise RuntimeError(f\"schedule_mode={schedule_mode} Unknown for a lambda funtion.\")\n\n\n@ex.command\ndef get_lr_scheduler(optimizer, schedule_mode):\n    if schedule_mode in {\"exp_lin\", \"cos_cyc\"}:\n        return torch.optim.lr_scheduler.LambdaLR(optimizer, get_scheduler_lambda())\n    raise RuntimeError(f\"schedule_mode={schedule_mode} Unknown.\")\n\n\n@ex.command\ndef get_optimizer(params, lr, adamw=True, weight_decay=0.0001):\n    if adamw:\n        print(f\"\\nUsing adamw weight_decay={weight_decay}!\\n\")\n        return torch.optim.AdamW(params, lr=lr, weight_decay=weight_decay)\n    return torch.optim.Adam(params, lr=lr)\n\n\nclass M(Ba3lModule):\n    def __init__(self, experiment):\n        self.mel = None\n        self.da_net = None\n        super(M, self).__init__(experiment)\n\n        self.use_mixup = self.config.use_mixup or False\n        self.mixup_alpha = self.config.mixup_alpha\n\n        desc, sum_params, sum_non_zero = count_non_zero_params(self.net)\n        self.experiment.info[\"start_sum_params\"] = sum_params\n        self.experiment.info[\"start_sum_params_non_zero\"] = sum_non_zero\n\n        # in case we need embedings for the DA\n        self.net.return_embed = True\n        self.dyn_norm = self.config.dyn_norm\n        self.do_swa = False\n\n        self.distributed_mode = self.config.trainer.num_nodes > 1\n        \n\n        \n\n    def forward(self, x):\n        return self.net(x)\n\n    def mel_forward(self, x):\n        old_shape = x.size()\n        x = x.reshape(-1, old_shape[2])\n        x = self.mel(x)\n        x = x.reshape(old_shape[0], old_shape[1], x.shape[1], x.shape[2])\n        if self.dyn_norm:\n            if not hasattr(self, \"tr_m\") or not hasattr(self, \"tr_std\"):\n                tr_m, tr_std = get_dynamic_norm(self)\n                self.register_buffer('tr_m', tr_m)\n                self.register_buffer('tr_std', 
tr_std)\n            x = (x - self.tr_m) / self.tr_std\n        return x\n\n    def training_step(self, batch, batch_idx):\n        # REQUIRED\n        x, f, y = batch\n        if self.mel:\n            x = self.mel_forward(x)\n\n        y_mask = y[:, 20:]\n        y = y[:, :20] > 0.5\n        y = y.float()\n\n        orig_x = x\n        batch_size = len(y)\n\n        rn_indices, lam = None, None\n        if self.use_mixup:\n            rn_indices, lam = my_mixup(batch_size, self.mixup_alpha)\n            lam = lam.to(x.device)\n            x = x * lam.reshape(batch_size, 1, 1, 1) + x[rn_indices] * (1. - lam.reshape(batch_size, 1, 1, 1))\n\n        y_hat, embed = self.forward(x)\n\n        if self.use_mixup:\n            y_mix = y * lam.reshape(batch_size, 1) + y[rn_indices] * (1. - lam.reshape(batch_size, 1))\n            samples_loss = F.binary_cross_entropy_with_logits(\n                y_hat, y_mix, reduction=\"none\")\n            y_mix_mask = ((y_mask > 0.5) | (y_mask[rn_indices] > 0.5)).float()\n            samples_loss = y_mask.float() * samples_loss\n            loss = samples_loss.mean()\n            samples_loss = samples_loss.detach()\n        else:\n            samples_loss = F.binary_cross_entropy_with_logits(y_hat, y, reduction=\"none\")\n            samples_loss = y_mask.float() * samples_loss\n            loss = samples_loss.mean()\n            samples_loss = samples_loss.detach()\n\n        results = {\"loss\": loss, }\n\n        return results\n\n    def training_epoch_end(self, outputs):\n        avg_loss = torch.stack([x['loss'] for x in outputs]).mean()\n\n        logs = {'train.loss': avg_loss, 'step': self.current_epoch}\n\n        self.log_dict(logs, sync_dist=True)\n\n    def predict(self, batch, batch_idx: int, dataloader_idx: int = None):\n        x, f, y = batch\n        if self.mel:\n            x = self.mel_forward(x)\n\n        y_hat, _ = self.forward(x)\n        return f, y_hat\n\n    def validation_step(self, batch, batch_idx):\n   
     x, f, y = batch\n        if self.mel:\n            x = self.mel_forward(x)\n        y_mask = y[:, 20:]\n        y = y[:, :20] > 0.5\n        y = y.float()\n        results = {}\n        model_name = [(\"\", self.net)]\n        if self.do_swa:\n            model_name = model_name + [(\"swa_\", self.net_swa)]\n        for net_name, net in model_name:\n            y_hat, _ = net(x)\n            samples_loss = F.binary_cross_entropy_with_logits(y_hat, y)\n            samples_loss = y_mask.float() * samples_loss\n            loss = samples_loss.mean()\n            out = torch.sigmoid(y_hat.detach())\n            # self.log(\"validation.loss\", loss, prog_bar=True, on_epoch=True, on_step=False)\n            results = {**results, net_name + \"val_loss\": loss, net_name + \"out\": out, net_name + \"target\": y.detach(),\n                       net_name + \"mask\": y_mask.detach()}\n        results = {k: v.cpu() for k, v in results.items()}\n        return results\n\n    def validation_epoch_end(self, outputs):\n        model_name = [(\"\", self.net)]\n        if self.do_swa:\n            model_name = model_name + [(\"swa_\", self.net_swa)]\n        for net_name, net in model_name:\n            avg_loss = torch.stack([x[net_name + 'val_loss'] for x in outputs]).mean()\n            out = torch.cat([x[net_name + 'out'] for x in outputs], dim=0)\n            target = torch.cat([x[net_name + 'target'] for x in outputs], dim=0)\n            mask = torch.cat([x[net_name + 'mask'] for x in outputs], dim=0)\n            try:\n                y_true = target.float().numpy()\n                y_pred = out.float().numpy()\n                y_mask = mask.float().numpy()\n                average_precision = np.array([metrics.average_precision_score(\n                    y_true[:, i], y_pred[:, i], sample_weight=y_mask[:, i]) for i in range(y_true.shape[1])])\n            except ValueError:\n                average_precision = np.array([np.nan] * y_true.shape[1])\n            
#torch.save(average_precision, f\"ap_openmic_perclass_{average_precision.mean()}.pt\")\n            try:\n                roc = np.array([metrics.roc_auc_score(\n                    y_true[:, i], y_pred[:, i], sample_weight=y_mask[:, i]) for i in range(y_true.shape[1])])\n            except ValueError:\n                roc = np.array([np.nan] * y_true.shape[1])\n            logs = {net_name + 'val.loss': torch.as_tensor(avg_loss).cuda(),\n                    net_name + 'ap': torch.as_tensor(average_precision.mean()).cuda(),\n                    net_name + 'roc': torch.as_tensor(roc.mean()).cuda(),\n                    'step': torch.as_tensor(self.current_epoch).cuda()}\n            self.log_dict(logs, sync_dist=True)\n            if self.distributed_mode:\n                allout = self.all_gather(out)\n                alltarget = self.all_gather(target)\n                all_mask = self.all_gather(mask)\n                y_true = alltarget.reshape(-1, alltarget.shape[-1]).cpu().float().numpy()\n                y_pred = allout.reshape(-1, alltarget.shape[-1]).cpu().float().numpy()\n                y_mask = all_mask.reshape(-1, alltarget.shape[-1]).cpu().float().numpy()\n                average_precision = np.array([metrics.average_precision_score(\n                    y_true[:, i], y_pred[:, i], sample_weight=y_mask[:, i]) for i in range(y_true.shape[1])])\n                if self.trainer.is_global_zero:\n                    logs = {net_name + \"allap\": torch.as_tensor(average_precision.mean()).cuda(),\n                            'step': torch.as_tensor(self.current_epoch).cuda()}\n                    self.log_dict(logs, sync_dist=False)\n            else:\n                self.log_dict({net_name + \"allap\": logs[net_name + 'ap'], 'step': logs['step']}, sync_dist=True)\n\n    def configure_optimizers(self):\n        # REQUIRED\n        # can return multiple optimizers and learning_rate schedulers\n        # (LBFGS it is automatically supported, no need for closure 
function)\n        optimizer = get_optimizer(self.parameters())\n        # torch.optim.Adam(self.parameters(), lr=self.config.lr)\n        return {\n            'optimizer': optimizer,\n            'lr_scheduler': get_lr_scheduler(optimizer)\n        }\n\n    def configure_callbacks(self):\n        return get_extra_checkpoint_callback() + get_extra_swa_callback()\n\n\n@ex.command\ndef get_dynamic_norm(model, dyn_norm=False):\n    if not dyn_norm:\n        return None, None\n    raise RuntimeError('no dynamic norm supported yet.')\n\n\n@ex.command\ndef get_extra_checkpoint_callback(save_last_n=None):\n    if save_last_n is None:\n        return []\n    return [ModelCheckpoint(monitor=\"step\", verbose=True, save_top_k=save_last_n, mode='max')]\n\n\n@ex.command\ndef get_extra_swa_callback(swa=True, swa_epoch_start=2,\n                           swa_freq=1):\n    if not swa:\n        return []\n    print(\"\\n Using swa!\\n\")\n    from helpers.swa_callback import StochasticWeightAveraging\n    return [StochasticWeightAveraging(swa_epoch_start=swa_epoch_start, swa_freq=swa_freq)]\n\n\n@ex.command\ndef main(_run, _config, _log, _rnd, _seed):\n    \n    trainer = get_trainer(logger=get_logger())\n    train_loader = get_train_loader()\n    val_loader = get_validate_loader()\n    \n    modul = M(ex)\n\n    trainer.fit(\n        modul,\n        train_dataloaders=train_loader,\n        val_dataloaders=val_loader,\n    )\n\n    return {\"done\": True}\n\n\n@ex.command\ndef model_speed_test(_run, _config, _log, _rnd, _seed, speed_test_batch_size=100):\n    '''\n    Test training speed of a model\n    @param _run:\n    @param _config:\n    @param _log:\n    @param _rnd:\n    @param _seed:\n    @param speed_test_batch_size: the batch size during the test\n    @return:\n    '''\n\n    modul = M(ex)\n    modul = modul.cuda()\n    batch_size = speed_test_batch_size\n    print(f\"\\nBATCH SIZE : {batch_size}\\n\")\n    test_length = 100\n    print(f\"\\ntest_length : 
{test_length}\\n\")\n\n    x = torch.ones([batch_size, 1, 128, 998]).cuda()\n    target = torch.ones([batch_size, 527]).cuda()\n    # one passe\n    net = modul.net\n    # net(x)\n    scaler = torch.cuda.amp.GradScaler()\n    torch.backends.cudnn.benchmark = True\n    # net = torch.jit.trace(net,(x,))\n    optimizer = torch.optim.SGD(net.parameters(), lr=0.001)\n\n    print(\"warmup\")\n    import time\n    torch.cuda.synchronize()\n    t1 = time.time()\n    for i in range(10):\n        with  torch.cuda.amp.autocast():\n            y_hat, embed = net(x)\n            loss = F.binary_cross_entropy_with_logits(y_hat, target, reduction=\"none\").mean()\n        scaler.scale(loss).backward()\n        scaler.step(optimizer)\n        scaler.update()\n    torch.cuda.synchronize()\n    t2 = time.time()\n    print('warmup done:', (t2 - t1))\n    torch.cuda.synchronize()\n    t1 = time.time()\n    print(\"testing speed\")\n\n    for i in range(test_length):\n        with  torch.cuda.amp.autocast():\n            y_hat, embed = net(x)\n            loss = F.binary_cross_entropy_with_logits(y_hat, target, reduction=\"none\").mean()\n        scaler.scale(loss).backward()\n        scaler.step(optimizer)\n        scaler.update()\n    torch.cuda.synchronize()\n    t2 = time.time()\n    print('test done:', (t2 - t1))\n    print(\"average speed: \", (test_length * batch_size) / (t2 - t1), \" specs/second\")\n\n\n@ex.command\ndef evaluate_only(_run, _config, _log, _rnd, _seed):\n    # force overriding the config, not logged = not recommended\n    trainer = get_trainer()\n    train_loader = get_train_loader()\n    val_loader = get_validate_loader()\n    modul = M(ex)\n    modul.val_dataloader = None\n    #trainer.val_dataloaders = None\n    print(f\"\\n\\nValidation len={len(val_loader)}\\n\")\n    res = trainer.validate(modul, dataloaders=val_loader)\n    print(\"\\n\\n Validtaion:\")\n    print(res)\n\n\n@ex.command\ndef test_loaders():\n    '''\n    get one sample from each loader for 
debbuging\n    @return:\n    '''\n    for i, b in enumerate(ex.datasets.training.get_iter()):\n        print(b)\n        break\n\n    for i, b in enumerate(ex.datasets.test.get_iter()):\n        print(b)\n        break\n\n\ndef set_default_json_pickle(obj):\n    if isinstance(obj, set):\n        return list(obj)\n    raise TypeError\n\n\n@ex.command\ndef preload_mp3(all_y=CMD(\"/basedataset.preload_mp3\")):\n    '''\n    read the dataset sequentially, useful if you have a network cache\n    @param all_y: the dataset preload command\n    @return:\n    '''\n    print(all_y.shape)\n\n\ndef multiprocessing_run(rank, word_size):\n    print(\"rank \", rank, os.getpid())\n    print(\"word_size \", word_size)\n    os.environ['NODE_RANK'] = str(rank)\n    os.environ['CUDA_VISIBLE_DEVICES'] = os.environ['CUDA_VISIBLE_DEVICES'].split(\",\")[rank]\n    argv = sys.argv\n    if rank != 0:\n        print(f\"Unobserved {os.getpid()} with rank {rank}\")\n        argv = argv + [\"-u\"]  # only rank 0 is observed\n    if \"with\" not in argv:\n        argv = argv + [\"with\"]\n\n    argv = argv + [f\"trainer.num_nodes={word_size}\", f\"trainer.accelerator=ddp\"]\n    print(argv)\n\n    @ex.main\n    def default_command():\n        return main()\n\n    ex.run_commandline(argv)\n\n\nif __name__ == '__main__':\n    # set DDP=2 forks two processes to run on two GPUs\n    # the environment variable \"DDP\" define the number of processes to fork\n    # With two 2x 2080ti you can train the full model to .47 in around 24 hours\n    # you may need to set NCCL_P2P_DISABLE=1\n    word_size = os.environ.get(\"DDP\", None)\n    if word_size:\n        import random\n\n        word_size = int(word_size)\n        print(f\"\\n\\nDDP TRAINING WITH WORD_SIZE={word_size}\\n\\n\")\n        os.environ['MASTER_ADDR'] = '127.0.0.1'\n        os.environ['MASTER_PORT'] = f\"{9999 + random.randint(0, 9999)}\"  # plz no collisions\n        os.environ['PL_IN_DDP_SUBPROCESS'] = '1'\n\n        for rank in 
range(word_size):\n            pid = os.fork()\n            if pid == 0:\n                print(\"Child Forked \")\n                multiprocessing_run(rank, word_size)\n                exit(0)\n\n        pid, exit_code = os.wait()\n        print(pid, exit_code)\n        exit(0)\n\nprint(\"__main__ is running pid\", os.getpid(), \"in module main: \", __name__)\n\n\n@ex.automain\ndef default_command():\n    return main()\n"
  },
  {
    "path": "fsd50k/README.md",
    "content": "# Experiments on FSD50K\n The FSD50K dataset ([Zenodo](https://zenodo.org/record/4060432))  consists of 51K audio clips annotated\nwith 200 sound event classes taken from the Audioset ontology. The dataset contains 100 hours of audio and is the\nsecond largest publicly available general purpose sound event\nrecognition dataset after Audioset. Furthermore, the FSD50K\nevaluation set is of high quality, with each evaluation label being double-checked and assessed by two to five independent annotators.\n\n# Setup\n1. Download the dataset from [Zenodo](https://zenodo.org/record/4060432) and unzip it.\n2. Convert wav files to mp3s:\n```shell\ncd fsd50k/prepare_scripts/\n\npython convert_to_mp3.py path/to/fsd50k\n ```\nThis will create a folder inside the FSD50K directory with the mp3 files.\n3. Pack the mp3 to HDF5 files:\n```shell\ncd fsd50k/prepare_scripts/\npython create_h5pymp3_dataset.py path/to/fsd50k\n ```\nNow you should have inside `../../audioset_hdf5s/mp3/` three new files: `FSD50K.eval_mp3.hdf`, `FSD50K.val_mp3.hdf`, `FSD50K.train_mp3.hdf`.\n\n\n# Running Experiments\n\nSimilar to the runs on Audioset, PaSST-S:\n\n```shell\n# Example call with all the default config:\npython ex_fsd50k.py with  trainer.precision=16  -p\n```\n\n```shell\n# Example call without overlap:\npython ex_fsd50k.py with  passt_s_swa_p16_s16_128_ap473 models.net.s_patchout_t=10  models.net.s_patchout_f=1 trainer.precision=16  -p\n```\n\n\n# Pre-trained models\n\nPre-trained models on FSD50K can be found [here](https://github.com/kkoutini/PaSST/releases/tag/v0.0.5). 
\n\nIn order to use the pre-trained models, for fine-tuning or inference, using minimal dependencies, refer to [PaSST-HEAR](https://github.com/kkoutini/passt_hear21), as an example after installing passt_hear21:\n\n```python\nfrom hear21passt.base import get_basic_model,get_model_passt\nimport torch\n# model wrapper, includes Melspectrogram and the default pre-trained transformer\nmodel = get_basic_model(mode=\"logits\")\n# replace the transformer with one that outputs 200 classes\nmodel.net = get_model_passt(arch=\"passt_s_swa_p16_128_ap476\",  n_classes=200)\n\n# load the pre-trained model state dict with mAP of .655 on FSD50K\nstate_dict = torch.load('/home/khaled/fsd50k-passt-s-f128-p16-s10-ap.655.pt')\n# load the weights into the transformer\nmodel.net.load_state_dict(state_dict)\n\n# example inference\nmodel.eval()\nmodel = model.cuda()\nwith torch.no_grad():\n    # audio_wave has the shape of [batch, seconds*32000] sampling rate is 32k\n    logits=model(audio_wave) \n```\n\n\nUsing the model with no patch overlap PaSST-S-N `fsd50k-passt-s-n-f128-p16-s16-ap.642.pt`:\n```python\n# replace the transformer with one that outputs 200 classes\nmodel.net = get_model_passt(arch=\"passt_s_p16_s16_128_ap468\", fstride=16,\n                                     tstride=16,  n_classes=200)\n\n# load the pre-trained model state dict with mAP of .642 on FSD50K with no patch overlap\nstate_dict = torch.load('/home/khaled/fsd50k-passt-s-n-f128-p16-s16-ap.642.pt')\n# load the weights into the transformer\nmodel.net.load_state_dict(state_dict)\n\n```"
  },
  {
    "path": "fsd50k/dataset.py",
    "content": "import io\nimport os\nimport pathlib\nimport random\n\nimport av\nimport librosa\nimport torchaudio\nfrom torch.utils.data import Dataset as TorchDataset, ConcatDataset, DistributedSampler, WeightedRandomSampler\n\nimport torch\nfrom ba3l.ingredients.datasets import Dataset\nfrom sacred.config import DynamicIngredient, CMD\nfrom scipy.signal import convolve\nimport numpy as np\nfrom helpers.audiodatasets import PreprocessDataset\nimport h5py\n\nLMODE = os.environ.get(\"LMODE\", False)\n# $TMPDIR\ndataset = Dataset('audiodataset')\n\n\n@dataset.config\ndef default_config():\n    name = 'audioset'  # dataset name\n    normalize = False  # normalize dataset\n    subsample = False  # subsample squares from the dataset\n    roll = True  # apply roll augmentation\n    fold = 1\n    base_dir = \"audioset_hdf5s/\"  # base directory of the dataset, change it or make a link\n    if LMODE:\n        base_dir = \"/system/user/publicdata/CP/audioset/audioset_hdf5s/\"\n\n    balanced_train_hdf5 = base_dir + \"mp3/FSD50K.train_mp3.hdf\"\n    valid_hdf5 = base_dir + \"mp3/FSD50K.val_mp3.hdf\"\n    eval_hdf5 = base_dir + \"mp3/FSD50K.eval_mp3.hdf\"\n    if LMODE:\n        balanced_train_hdf5 = balanced_train_hdf5.replace(base_dir, os.environ.get(\"TMPDIR\", base_dir) + \"/\")\n        eval_hdf5 = eval_hdf5.replace(base_dir, os.environ.get(\"TMPDIR\", base_dir) + \"/\")\n        valid_hdf5 = valid_hdf5.replace(base_dir, os.environ.get(\"TMPDIR\", base_dir) + \"/\")\n    ir_path = base_dir + \"irs/\"\n    num_of_classes = 200\n\n\nif LMODE:\n    @dataset.config\n    def LMODE_default_config():\n        cache_root_path = \"/system/user/publicdata/CP/DCASE/cached_datasets/\"\n\n\ndef decode_mp3(mp3_arr):\n    \"\"\"\n    decodes an array if uint8 representing an mp3 file\n    :rtype: np.array\n    \"\"\"\n    container = av.open(io.BytesIO(mp3_arr.tobytes()))\n    stream = next(s for s in container.streams if s.type == 'audio')\n    # print(stream)\n    a = []\n    for 
i, packet in enumerate(container.demux(stream)):\n        for frame in packet.decode():\n            a.append(frame.to_ndarray().reshape(-1))\n    waveform = np.concatenate(a)\n    if waveform.dtype != 'float32':\n        raise RuntimeError(\"Unexpected wave type\")\n    return waveform\n\n\ndef pad_or_truncate(x, audio_length):\n    \"\"\"Pad all audio to specific length.\"\"\"\n    if audio_length is None:\n        # audio_length not specified don't do anything.\n        return x\n    if len(x) <= audio_length:\n        return np.concatenate((x, np.zeros(audio_length - len(x), dtype=np.float32)), axis=0)\n    else:\n        offset = torch.randint(0, len(x) - audio_length + 1, (1,)).item()\n        return x[offset:offset + audio_length]\n\n\nirs_arr = None\n\n\n@dataset.command\ndef get_ir_sample(ir_path, _run, ir_augment, cut_irs_offset=None):\n    if not ir_augment:\n        return\n    global irs_arr\n    if irs_arr is None:\n        all_paths = [path for path in pathlib.Path(os.path.expanduser(ir_path)).rglob('*.wav')]\n        all_paths = sorted(all_paths)\n        if cut_irs_offset is not None:\n            all_paths = all_paths[cut_irs_offset:cut_irs_offset + 10]\n        all_paths_name = [str(p).rsplit(\"/\", 1)[-1] for p in all_paths]\n        print(\"will use these IRs:\")\n        for i in range(len(all_paths_name)):\n            print(i, \": \", all_paths_name[i])\n        _run.info[\"ir_devices\"] = all_paths_name\n        irs_arr = [librosa.load(p, sr=32000)[0] for p in all_paths]\n    return irs_arr[int(np.random.randint(0, len(irs_arr)))]\n\n\n@dataset.command\ndef pydub_augment(waveform, gain_augment=7, ir_augment=0):\n    if ir_augment and torch.rand(1) < ir_augment:\n        ir = get_ir_sample()\n        waveform = convolve(waveform, ir, 'full')\n    if gain_augment:\n        gain = torch.randint(gain_augment * 2, (1,)).item() - gain_augment\n        amp = 10 ** (gain / 20)\n        waveform = waveform * amp\n    return waveform\n\n\nclass 
MixupDataset(TorchDataset):\n    \"\"\" Mixing Up wave forms\n    \"\"\"\n\n    def __init__(self, dataset, beta=2, rate=0.5):\n        self.beta = beta\n        self.rate = rate\n        self.dataset = dataset\n        print(f\"Mixing up waveforms from dataset of len {len(dataset)}\")\n\n    def __getitem__(self, index):\n        if torch.rand(1) < self.rate:\n            x1, f1, y1 = self.dataset[index]\n            idx2 = torch.randint(len(self.dataset), (1,)).item()\n            x2, f2, y2 = self.dataset[idx2]\n            l = np.random.beta(self.beta, self.beta)\n            l = max(l, 1. - l)\n            x1 = x1 - x1.mean()\n            x2 = x2 - x2.mean()\n            x = (x1 * l + x2 * (1. - l))\n            x = x - x.mean()\n            return x, f1, (y1 * l + y2 * (1. - l))\n        return self.dataset[index]\n\n    def __len__(self):\n        return len(self.dataset)\n\n\nclass AudioSetDataset(TorchDataset):\n    def __init__(self, hdf5_file, sample_rate=32000, classes_num=200, clip_length=10, augment=False, in_mem=False):\n        \"\"\"\n        Reads the mp3 bytes from HDF file decodes using av and returns a fixed length audio wav\n        \"\"\"\n        self.sample_rate = sample_rate\n        self.hdf5_file = hdf5_file\n        if in_mem:\n            print(\"\\nPreloading in memory\\n\")\n            with open(hdf5_file, 'rb') as f:\n                self.hdf5_file = io.BytesIO(f.read())\n        with h5py.File(hdf5_file, 'r') as f:\n            self.length = len(f['audio_name'])\n            print(f\"Dataset from {hdf5_file} with length {self.length}.\")\n        self.dataset_file = None  # lazy init\n        self.clip_length = clip_length\n        if clip_length is not None:\n            self.clip_length = clip_length * sample_rate\n        self.classes_num = classes_num\n        self.augment = augment\n        if augment:\n            print(f\"Will agument data from {hdf5_file}\")\n\n    def open_hdf5(self):\n        self.dataset_file = 
h5py.File(self.hdf5_file, 'r')\n\n    def __len__(self):\n        return self.length\n\n    def __del__(self):\n        if self.dataset_file is not None:\n            self.dataset_file.close()\n            self.dataset_file = None\n\n    def __getitem__(self, index):\n        \"\"\"Load waveform and target of an audio clip.\n\n        Args:\n          meta: {\n            'hdf5_path': str,\n            'index_in_hdf5': int}\n        Returns:\n          data_dict: {\n            'audio_name': str,\n            'waveform': (clip_samples,),\n            'target': (classes_num,)}\n        \"\"\"\n        if self.dataset_file is None:\n            self.open_hdf5()\n\n        audio_name = self.dataset_file['audio_name'][index].decode()\n        waveform = decode_mp3(self.dataset_file['mp3'][index])\n        if self.augment:\n            waveform = pydub_augment(waveform)\n        waveform = pad_or_truncate(waveform, self.clip_length)\n        waveform = self.resample(waveform)\n        target = self.dataset_file['target'][index]\n        target = np.unpackbits(target, axis=-1,\n                               count=self.classes_num).astype(np.float32)\n        return waveform.reshape(1, -1), audio_name, target\n\n    def resample(self, waveform):\n        \"\"\"Resample.\n        Args:\n          waveform: (clip_samples,)\n        Returns:\n          (resampled_clip_samples,)\n        \"\"\"\n        if self.sample_rate == 32000:\n            return waveform\n        elif self.sample_rate == 16000:\n            return waveform[0:: 2]\n        elif self.sample_rate == 8000:\n            return waveform[0:: 4]\n        else:\n            raise Exception('Incorrect sample rate!')\n\n\n@dataset.command\ndef get_base_training_set(balanced_train_hdf5, clip_length=10):\n    ds = AudioSetDataset(balanced_train_hdf5, augment=True, clip_length=clip_length)\n    return ds\n\n\n@dataset.command\ndef preload_mp3(balanced_train_hdf5, unbalanced_train_hdf5, num_of_classes):\n    # 
Preload mp3 sequential from disk, OS will cache the chunks in memory.\n    # Useful if the hdf file is on a NFS mount, saving the random access.\n    for hdf5_file in [balanced_train_hdf5, unbalanced_train_hdf5]:\n        print(f\"\\n \\n will now preload {hdf5_file} \\n\\n \")\n        with h5py.File(hdf5_file, 'r') as dataset_file:\n            target = dataset_file['mp3'][:]\n            print(len(target))\n            print(f\"\\n \\n done with  {hdf5_file} \\n\\n \")\n    return target[1000]\n\n\n@dataset.command\ndef get_ft_weighted_sampler(samples_weights=CMD(\".get_ft_cls_balanced_sample_weights\"),\n                            epoch_len=100000, sampler_replace=False):\n    num_nodes = int(os.environ.get('num_nodes', 1))\n    ddp = int(os.environ.get('DDP', 1))\n    num_nodes = max(ddp, num_nodes)\n    print(\"num_nodes= \", num_nodes)\n    rank = int(os.environ.get('NODE_RANK', 0))\n    return DistributedSamplerWrapper(sampler=WeightedRandomSampler(samples_weights,\n                                                                   num_samples=epoch_len, replacement=sampler_replace),\n                                     dataset=range(epoch_len),\n                                     num_replicas=num_nodes,\n                                     rank=rank,\n                                     )\n\n\n@dataset.command\ndef get_base_eval_set(eval_hdf5, variable_eval=None):\n    if variable_eval:\n        print(\"Variable length eval!!\")\n        ds = AudioSetDataset(eval_hdf5, clip_length=None)\n    else:\n        ds = AudioSetDataset(eval_hdf5)\n    return ds\n\n\n@dataset.command\ndef get_base_valid_set(valid_hdf5, variable_eval=None):\n    if variable_eval:\n        print(\"Variable length valid_set !!\")\n        ds = AudioSetDataset(valid_hdf5, clip_length=None)\n    else:\n        ds = AudioSetDataset(valid_hdf5)\n    return ds\n\n\n@dataset.command(prefix='roll_conf')\ndef get_roll_func(axis=1, shift=None, shift_range=50):\n    
print(\"rolling...\")\n\n    def roll_func(b):\n        x, i, y = b\n        x = torch.as_tensor(x)\n        sf = shift\n        if shift is None:\n            sf = int(np.random.random_integers(-shift_range, shift_range))\n        global FirstTime\n\n        return x.roll(sf, axis), i, y\n\n    return roll_func\n\n\n@dataset.command\ndef get_training_set(normalize, roll, wavmix=False):\n    ds = get_base_training_set()\n    get_ir_sample()\n    if normalize:\n        print(\"normalized train!\")\n        fill_norms()\n        ds = PreprocessDataset(ds, norm_func)\n    if roll:\n        ds = PreprocessDataset(ds, get_roll_func())\n    if wavmix:\n        ds = MixupDataset(ds)\n\n    return ds\n\n\n@dataset.command\ndef get_valid_set(normalize):\n    ds = get_base_valid_set()\n    if normalize:\n        print(\"normalized test!\")\n        fill_norms()\n        ds = PreprocessDataset(ds, norm_func)\n    return ds\n\n\n@dataset.command\ndef get_eval_set(normalize):\n    ds = get_base_eval_set()\n    if normalize:\n        print(\"normalized test!\")\n        fill_norms()\n        ds = PreprocessDataset(ds, norm_func)\n    return ds\n\n\n@dataset.command\ndef print_conf(_config):\n    print(\"Config of \", dataset.path, id(dataset))\n    print(_config)\n    print()\n\n\nclass DistributedSamplerWrapper(DistributedSampler):\n    def __init__(\n            self, sampler, dataset,\n            num_replicas=None,\n            rank=None,\n            shuffle: bool = True):\n        super(DistributedSamplerWrapper, self).__init__(\n            dataset, num_replicas, rank, shuffle)\n        # source: @awaelchli https://github.com/PyTorchLightning/pytorch-lightning/issues/3238\n        self.sampler = sampler\n\n    def __iter__(self):\n        if self.sampler.generator is None:\n            self.sampler.generator = torch.Generator()\n        self.sampler.generator.manual_seed(self.seed + self.epoch)\n        indices = list(self.sampler)\n        if self.epoch == 0:\n           
 print(f\"\\n DistributedSamplerWrapper :  {indices[:10]} \\n\\n\")\n        indices = indices[self.rank:self.total_size:self.num_replicas]\n        return iter(indices)\n\n\nif __name__ == \"__main__\":\n    from sacred import Experiment\n\n    ex = Experiment(\"test_dataset\", ingredients=[dataset])\n\n\n    @ex.automain\n    def default_command():\n        ex.current_run.get_command_function(\"print_config\")()\n        get_base_training_set()\n        ds = get_test_set()\n        print(ds[0])\n        ds = get_training_set()\n        print(ds[0])\n        print(\"get_base_training_set\", len(get_base_training_set()))\n        print(\"get_base_test_set\", len(get_base_test_set()))\n        print(\"get_training_set\", len(get_training_set()))\n        print(\"get_test_set\", len(get_test_set()))\n"
  },
  {
    "path": "fsd50k/prepare_scripts/convert_to_mp3.py",
    "content": "import multiprocessing\nimport glob\nimport os\nimport sys\n\nif len(sys.argv) > 1:\n    FSD50K_base = sys.argv[1] # the path to the FSD50K base as downloaded from Zenodo.\nelse:\n    FSD50K_base = \"/home/khaled/shared/FSD50K/\"  # the path to the FSD50K base as downloaded from Zenodo.\n    print(\"Pass the path to FSD50K: python convert_to_mp3.py path/to/fsd50k\")\n\n\n\noutputp = FSD50K_base + \"/mp3/\"  # the path to the output mp3.\n\nall_num = 0\n\n\ndef process_folder(fol=\"balanced_train_segments\"):\n    print(\"now working on \", fol)\n    os.makedirs(outputp + fol, exist_ok=True)\n    all_files = list(glob.glob(FSD50K_base + fol + \"/*.wav\"))\n    print(f\"it has {len(all_files)}\")\n    global all_num\n    all_num = len(all_files)\n    cmds = [(i, file, outputp + fol + \"/\" + os.path.basename(file)[:-3]) for i, file in enumerate(all_files)]\n    print(cmds[0])\n    with multiprocessing.Pool(processes=20) as pool:\n        pool.starmap(process_one, cmds)\n\n\ndef process_one(i, f1, f2):\n    if i % 100 == 0:\n        print(f\"{i}/{all_num} \\t\", f1)\n    os.system(f\"ffmpeg  -hide_banner -nostats -loglevel error -n -i {f1} -codec:a mp3 -ar 32000 {f2}mp3\")\n\n\nprint(\"We will convert the following folders to mp3: \")\nfolders = ['FSD50K.eval_audio', 'FSD50K.dev_audio']\n\nprint(folders)\n\nfor fol in folders:\n    process_folder(fol)\n"
  },
  {
    "path": "fsd50k/prepare_scripts/create_h5pymp3_dataset.py",
    "content": "# %%\nimport sys\n\nimport h5py\nimport pandas as pd\nimport numpy as np\nimport csv\nimport os\n\n# %%\nfrom numpy import dtype\n\nif len(sys.argv) > 1:\n    FSD50K_base = sys.argv[1] # the path to of FSD50K base as downloaded from zalando.\nelse:\n    FSD50K_base = \"/home/khaled/shared/FSD50K/\"  # the path to of FSD50K base as downloaded from zalando.\n    print(\"Pass the path to FSD50K: python convert_to_mp3.py path/to/fsd50k\")\n\nbase_dir = \"../../audioset_hdf5s/\" # the path to store hdf file.\n\n\n\n####\nbalanced_csv = FSD50K_base + \"FSD50K.ground_truth/dev.csv\"\neval_csv = FSD50K_base + \"FSD50K.ground_truth/eval.csv\"\nclass_idx_csv = FSD50K_base + \"FSD50K.ground_truth/vocabulary.csv\"\nmp3_path = \"/home/khaled/shared/FSD50K/mp3/\"\n\n# %%\n\ndf = pd.read_csv(class_idx_csv, header=None, index_col=0)\nclasses_list = list(df[1].values)\nassert sorted(classes_list) == classes_list\nid_to_ix = {id: i for i, id in enumerate(classes_list)}\nix_to_id = {i: id for i, id in enumerate(classes_list)}\n\n# %%\n\n# Load labels\ndf = pd.read_csv(balanced_csv)\n\ntrain = df[df.split == \"train\"]\nval = df[df.split == \"val\"]\n\neval = pd.read_csv(eval_csv)\n\n\n# %%\ndef get_labels(df):\n    y = np.zeros((len(df), 200), dtype=np.int32)\n\n    for i, target in enumerate(df.labels.values):\n        for t in target.split(\",\"):\n            y[i, id_to_ix[t]] = 1\n    return df.fname.values, y\n\n\n# %%\n\nfor set_name, df, prefix in [(\"train\", train, \"FSD50K.dev_audio/\"), (\"val\", val, \"FSD50K.dev_audio/\"),\n                             (\"eval\", eval, \"FSD50K.eval_audio/\")]:\n    print(\"now working on \", set_name, prefix, \"len=\", len(df))\n    # files, y = torch.load(read_file+\".pth\")\n    files, y = get_labels(df)\n    y = np.packbits(y, axis=-1)\n    packed_len = y.shape[1]\n    print(files[0], \"classes: \", packed_len, y.dtype)\n    available_size = len(files)\n    dt = h5py.vlen_dtype(np.dtype('uint8'))\n    save_file = 
\"FSD50K.\" + set_name\n    if os.path.isfile(base_dir + \"mp3/\" + save_file + \"_mp3.hdf\"):\n        print(base_dir + \"mp3/\" + save_file + \"_mp3.hdf\", \"exists!\\n\\n\\n contiue\")\n        continue\n    with h5py.File(base_dir + \"mp3/\" + save_file + \"_mp3.hdf\", 'w') as hf:\n        audio_name = hf.create_dataset('audio_name', shape=((available_size,)), dtype='S20')\n        waveform = hf.create_dataset('mp3', shape=((available_size,)), dtype=dt)\n        target = hf.create_dataset('target', shape=((available_size, packed_len)), dtype=y.dtype)\n        for i, file in enumerate(files):\n            if i % 1000 == 0:\n                print(f\"{i}/{available_size}\")\n            f = f\"{file}.mp3\"\n            a = np.fromfile(mp3_path + prefix + f, dtype='uint8')\n            audio_name[i] = f\n            waveform[i] = a\n            target[i] = y[i]\n\n    print(a.shape)\n    print(\"Done!\", prefix)\n"
  },
  {
    "path": "helpers/audiodatasets.py",
    "content": "import hashlib\nimport os\nimport time\n\nimport torch\nfrom torch.utils.data import Dataset\n\nfrom os.path import expanduser\n\nimport logging\n\ndef h6(w):\n    return hashlib.md5(w.encode('utf-8')).hexdigest()[:6]\n\nclass AudioPreprocessDataset(Dataset):\n    \"\"\"A bases preprocessing dataset representing a Dataset of files that are loaded and preprossessed on the fly.\n\n    Access elements via __getitem__ to return: preprocessor(x),sample_id,label\n\n    supporting integer indexing in range from 0 to len(self) exclusive.\n    \"\"\"\n\n    def __init__(self, files, labels, label_encoder, base_dir, preprocessor, return_tensor=True, ordered_ids=None):\n        self.files = files\n        if ordered_ids is None:\n            ordered_ids = files\n        else:\n            print(\"AudioPreprocessDataset: ordered_ids is not None using it instead of files !!!\")\n        self.ordered_ids = ordered_ids\n        self.labels = labels\n        self.label_encoder = label_encoder\n        self.base_dir = base_dir\n        self.preprocessor = preprocessor\n        self.return_tensor = return_tensor\n\n    def __getitem__(self, index):\n        x = self.preprocessor(self.base_dir + self.files[index])\n        if self.return_tensor and not isinstance(x, torch.Tensor):\n            x = torch.from_numpy(x)\n        return x, self.ordered_ids[index], self.labels[index]\n\n    def get_ordered_ids(self):\n        return self.ordered_ids\n\n    def get_ordered_labels(self):\n        return self.labels\n\n    def __len__(self):\n        return len(self.ordered_ids)\n\nclass ObjectCacher:\n    def __init__(self, get_obj_func, dataset_name, obj_name=\"\",\n                 cache_path=\"~/shared/kofta_cached_datasets/\", verbose=True):\n        self.dataset_name = dataset_name\n        self.obj_name = obj_name\n        cache_path = expanduser(cache_path)\n        self.cache_path = os.path.join(cache_path, dataset_name)\n        try:\n            startTime = 
time.time()\n            xpath = self.get_obj_cache_path()\n\n            if verbose:\n                logging.info(\n                    \"attempting to load x from cache at \" + xpath + \"...\")\n            self.obj = torch.load(xpath)\n\n            if verbose:\n                endTime = time.time()\n                logging.info(\n                    \"loaded \" + xpath + \" from cache in %s \" % (endTime - startTime))\n        except IOError:\n            if verbose:\n                logging.info(\n                    \"Invalid cache \" + xpath + \" , recomputing\")\n            self.obj = get_obj_func()\n            saveStartTime = time.time()\n            dirpath=os.path.dirname(xpath)\n            try:\n                original_umask = os.umask(0)\n                os.makedirs(dirpath, exist_ok=True)\n            finally:\n                os.umask(original_umask)\n            torch.save(self.obj, xpath)\n            if verbose:\n                endTime = time.time()\n                logging.info(\n                    \"loaded \" + obj_name + \" in %s, and cached in %s, total %s seconds \" % (\n                        (saveStartTime - startTime),\n                        (endTime - saveStartTime), (endTime - startTime)))\n\n    def get_obj_cache_path(self):\n        return os.path.join(self.cache_path, self.obj_name + \"_obj.pt\")\n\n    def get(self):\n        return self.obj\n\n\n\nclass PreprocessDataset(Dataset):\n    \"\"\"A bases preprocessing dataset representing a preprocessing step of a Dataset preprossessed on the fly.\n\n\n    supporting integer indexing in range from 0 to len(self) exclusive.\n    \"\"\"\n\n    def __init__(self, dataset, preprocessor):\n        self.dataset = dataset\n        if not callable(preprocessor):\n            print(\"preprocessor: \", preprocessor)\n            raise ValueError('preprocessor should be callable')\n        self.preprocessor = preprocessor\n    def __getitem__(self, index):\n        return 
self.preprocessor(self.dataset[index])\n    def __len__(self):\n        return len(self.dataset)\n\n\n\nclass FilesCachedDataset(Dataset):\n    def __init__(self, get_dataset_func, dataset_name, x_name=\"\",\n                 cache_path=\"~/shared/kofta_cached_datasets/\",\n                 ):\n        \"\"\"\n            Cached the dataset in small torch.save files (1 file per sample).\n            The dataset is suitable for SSDs being used bcache from a slow harddrive with a small\n        @param get_dataset_func: fuction gets called if the file cache is invalid\n        @param dataset_name: the folder containing the dataset\n        @param x_name: tag for the version\n        @param cache_path: cache_path\n        \"\"\"\n        self.dataset = None\n\n        def getDataset():\n            if self.dataset == None:\n                self.dataset = get_dataset_func()\n            return self.dataset\n\n        self.get_dataset_func = getDataset\n        self.x_name = x_name\n        cache_path = expanduser(cache_path)\n        self.cache_path = os.path.join(cache_path, dataset_name, \"files_cache\", self.x_name)\n        try:\n            original_umask = os.umask(0)\n            os.makedirs(self.cache_path, exist_ok=True)\n        finally:\n            os.umask(original_umask)\n\n    def __getitem__(self, index):\n        cpath = os.path.join(self.cache_path, str(index) + \".pt\")\n        try:\n            return torch.load(cpath)\n        except FileNotFoundError:\n            tup = self.get_dataset_func()[index]\n            torch.save(tup, cpath)\n            return tup\n\n    def get_ordered_labels(self):\n        return self.get_dataset_func().get_ordered_labels()\n\n    def get_ordered_ids(self):\n        return self.get_dataset_func().get_ordered_ids()\n\n    def get_xcache_path(self):\n        return os.path.join(self.cache_path, self.x_name + \"_x.pt\")\n\n    def get_ycache_path(self):\n        return os.path.join(self.cache_path, self.y_name + 
\"_y.pt\")\n\n    def get_sidcache_path(self):\n        return os.path.join(self.cache_path, self.y_name + \"_sid.pt\")\n\n    def __len__(self):\n        return len(self.get_dataset_func())\n\n\n\nclass SelectionDataset(Dataset):\n    \"\"\"A dataset that selects a subsample from a dataset based on a set of sample ids.\n\n\n        supporting integer indexing in range from 0 to len(self) exclusive.\n    \"\"\"\n\n    def __init__(self, dataset, sample_ids):\n        self.available_indexes = []\n        self.dataset = dataset\n        self.reselect(sample_ids)\n        self.sample_ids = sample_ids\n\n    def reselect(self, sample_ids):\n        reverse_dict = dict([(sid, i) for i, sid in enumerate(self.dataset.get_ordered_ids())])\n        self.available_indexes = [reverse_dict[sid] for sid in sample_ids]\n\n    def get_ordered_ids(self):\n        return self.sample_ids\n\n    def get_ordered_labels(self):\n        labels=self.dataset.get_ordered_labels()\n        return [labels[i] for i in self.available_indexes]\n        #raise NotImplementedError(\"Maybe reconsider caching only a selection Dataset. why not select after cache?\")\n\n    def __getitem__(self, index):\n        return self.dataset[self.available_indexes[index]]\n\n    def __len__(self):\n        return len(self.available_indexes)\n\nclass SimpleSelectionDataset(Dataset):\n    \"\"\"A dataset that selects a subsample from a dataset based on a set of sample ids.\n\n\n        supporting integer indexing in range from 0 to len(self) exclusive.\n    \"\"\"\n\n    def __init__(self, dataset, available_indexes ):\n        self.available_indexes = available_indexes\n        self.dataset = dataset\n\n    def __getitem__(self, index):\n        return self.dataset[self.available_indexes[index]]\n\n    def __len__(self):\n        return len(self.available_indexes)\n\n"
  },
  {
    "path": "helpers/mixup.py",
    "content": "\nimport numpy as np\nimport torch\n\ndef my_mixup(size, alpha):\n    rn_indices = torch.randperm(size)\n    lambd = np.random.beta(alpha, alpha, size).astype(np.float32)\n    lambd = np.concatenate([lambd[:, None], 1 - lambd[:, None]], 1).max(1)\n    lam = torch.FloatTensor(lambd)\n    # data = data * lam + data2 * (1 - lam)\n    # targets = targets * lam + targets2 * (1 - lam)\n    return rn_indices, lam\n\n"
  },
  {
    "path": "helpers/models_size.py",
    "content": "\n\n\n\n\n\ndef count_non_zero_params(model):\n    sum_params = 0\n    sum_non_zero = 0\n    desc = \"\"\n\n    def calc_params(model):\n        nonlocal desc, sum_params, sum_non_zero\n        skip = \"\"\n        if \"batchnorm\" in type(model).__name__.lower():\n             for k,p in [(\"running_mean\", model.running_mean), (\"running_var\", model.running_var)]:\n                 nonzero = p[p != 0].numel()\n                 total = p.numel()\n                 desc += f\"type {type(model).__name__}, {k},  {total}, {nonzero}, {p.dtype}, {skip} \" + \"\\n\"\n                 if skip != \"skip\":\n                     sum_params += total\n                     sum_non_zero += nonzero\n        for k, p in model.named_parameters(recurse=False):\n            nonzero = p[p != 0].numel()\n            total = p.numel()\n            desc += f\"type {type(model).__name__}, {k},  {total}, {nonzero}, {p.dtype}, {skip} \" + \"\\n\"\n            if skip != \"skip\":\n                sum_params += total\n                sum_non_zero += nonzero\n\n    model.apply(calc_params)\n    return desc, sum_params, sum_non_zero\n\n"
  },
  {
    "path": "helpers/ramp.py",
    "content": "import numpy as np\nfrom ba3l.ingredients.ingredient import Ingredient\n\n\n# credit: https://github.com/iBelieveCJM/Tricks-of-Semi-supervisedDeepLeanring-Pytorch/blob/master/utils/ramps.py\n\n\ndef pseudo_rampup(T1, T2):\n    def warpper(epoch):\n        if epoch > T1:\n            alpha = (epoch - T1) / (T2 - T1)\n            if epoch > T2:\n                alpha = 1.0\n        else:\n            alpha = 0.0\n        return alpha\n\n    return warpper\n\n\ndef exp_rampup(rampup_length):\n    \"\"\"Exponential rampup from https://arxiv.org/abs/1610.02242\"\"\"\n    def warpper(epoch):\n        if epoch < rampup_length:\n            epoch = np.clip(epoch, 0.5, rampup_length)\n            phase = 1.0 - epoch / rampup_length\n            return float(np.exp(-5.0 * phase * phase))\n        else:\n            return 1.0\n    return warpper\n\n\ndef linear_rampup(rampup_length):\n    \"\"\"Linear rampup\"\"\"\n\n    def warpper(epoch):\n        if epoch < rampup_length:\n            return epoch / rampup_length\n        else:\n            return 1.0\n\n    return warpper\n\n\ndef linear_rampdown(rampdown_length, start=0, last_value=0):\n    \"\"\"Linear rampup -(start)- (rampdown_length) \\ _(for the rest)  \"\"\"\n    def warpper(epoch):\n        if epoch <= start:\n            return 1.\n        elif epoch - start < rampdown_length:\n            return last_value + (1. 
- last_value) * (rampdown_length - epoch + start) / rampdown_length\n        else:\n            return last_value\n    return warpper\n\n\ndef exp_rampdown(rampdown_length, num_epochs):\n    \"\"\"Exponential rampdown from https://arxiv.org/abs/1610.02242\"\"\"\n\n    def warpper(epoch):\n        if epoch >= (num_epochs - rampdown_length):\n            ep = .5 * (epoch - (num_epochs - rampdown_length))\n            return float(np.exp(-(ep * ep) / rampdown_length))\n        else:\n            return 1.0\n\n    return warpper\n\n\ndef cosine_rampdown(rampdown_length, num_epochs):\n    \"\"\"Cosine rampdown from https://arxiv.org/abs/1608.03983\"\"\"\n\n    def warpper(epoch):\n        if epoch >= (num_epochs - rampdown_length):\n            ep = .5 * (epoch - (num_epochs - rampdown_length))\n            return float(.5 * (np.cos(np.pi * ep / rampdown_length) + 1))\n        else:\n            return 1.0\n\n    return warpper\n\n\ndef exp_warmup(rampup_length, rampdown_length, num_epochs):\n    rampup = exp_rampup(rampup_length)\n    rampdown = exp_rampdown(rampdown_length, num_epochs)\n\n    def warpper(epoch):\n        return rampup(epoch) * rampdown(epoch)\n\n    return warpper\n\n\ndef exp_warmup_linear_down(warmup, rampdown_length, start_rampdown, last_value):\n    rampup = exp_rampup(warmup)\n    rampdown = linear_rampdown(rampdown_length, start_rampdown, last_value)\n    def warpper(epoch):\n        return rampup(epoch) * rampdown(epoch)\n    return warpper\n\n\ndef test_warmup():\n    warmup = exp_warmup(20, 100, 150)\n    for ep in range(500):\n        print(warmup(ep))\n\n\ndef test_warmupl():\n    warmup = exp_warmup_linear_down(20, 100, 50, 0.001)\n    for ep in range(500):\n        print(warmup(ep))\n\n\ndef cosine_cycle(cycle_len=20,ramp_down_start=100,last_lr_value=0.01):\n    \"\"\"Cosine rampdown from https://arxiv.org/abs/1608.03983\"\"\"\n    ramp_down_start = cycle_len+ (ramp_down_start-1)//cycle_len*(cycle_len)\n    print(\"adjusted 
ramp_down_start:\",ramp_down_start)\n    def warpper(epoch):\n        ep =  (epoch+cycle_len//2.)/(1.*cycle_len)\n        if epoch>ramp_down_start:\n            return last_lr_value\n        return float(last_lr_value + (1.-last_lr_value)* .5 * (np.cos(2.*np.pi * ep) + 1))\n    return warpper\n\n\nif __name__ == '__main__':\n    test= exp_warmup_linear_down(20, 100, 50, 150)\n    for i in range(250):\n        print(test(i))"
  },
  {
    "path": "helpers/swa_callback.py",
    "content": "# Adapted from PyTorch Lightning so that it only does the averaging\n# Copyright The PyTorch Lightning team.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\nr\"\"\"\nStochastic Weight Averaging Callback\n^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^\n\"\"\"\nfrom copy import deepcopy\nfrom typing import Callable, Optional, Union\n\nimport torch\nfrom torch import nn\n\nimport pytorch_lightning as pl\nfrom pytorch_lightning.callbacks import Callback\nfrom pytorch_lightning.utilities.exceptions import MisconfigurationException\n\n\nfrom torch.optim.swa_utils import SWALR\n\n_AVG_FN = Callable[[torch.Tensor, torch.Tensor, torch.LongTensor], torch.FloatTensor]\n\n\nclass StochasticWeightAveraging(Callback):\n\n    def __init__(\n        self,\n        swa_epoch_start: Union[int, float] = 0.8,\n        swa_freq: Union[int, float] = 3,\n        swa_lrs: Optional[Union[float, list]] = None,\n        annealing_epochs: int = 10,\n        annealing_strategy: str = \"cos\",\n        avg_fn: Optional[_AVG_FN] = None,\n        device: Optional[Union[torch.device, str]] = None,\n    ):\n        r\"\"\"\n\n        Implements the Stochastic Weight Averaging (SWA) Callback to average a model.\n\n        Stochastic Weight Averaging was proposed in ``Averaging Weights Leads to\n        Wider Optima and Better Generalization`` by Pavel Izmailov, Dmitrii\n        Podoprikhin, Timur Garipov, Dmitry Vetrov and Andrew Gordon Wilson\n        (UAI 2018).\n\n        This 
documentation is highly inspired by PyTorch's work on SWA.\n        The callback arguments follow the scheme defined in PyTorch's ``swa_utils`` package.\n\n        For a SWA explanation, please take a look\n        `here <https://pytorch.org/blog/pytorch-1.6-now-includes-stochastic-weight-averaging>`_.\n\n        .. warning:: ``StochasticWeightAveraging`` is in beta and subject to change.\n\n        .. warning:: ``StochasticWeightAveraging`` is currently not supported for multiple optimizers/schedulers.\n\n        SWA can easily be activated directly from the Trainer as follow:\n\n        .. code-block:: python\n\n            Trainer(stochastic_weight_avg=True)\n\n        Arguments:\n\n            swa_epoch_start: If provided as int, the procedure will start from\n                the ``swa_epoch_start``-th epoch. If provided as float between 0 and 1,\n                the procedure will start from ``int(swa_epoch_start * max_epochs)`` epoch\n\n            swa_lrs: the learning rate value for all param groups together or separately for each group.\n\n            annealing_epochs: number of epochs in the annealing phase (default: 10)\n\n            annealing_strategy: Specifies the annealing strategy (default: \"cos\"):\n\n                - ``\"cos\"``. 
For cosine annealing.\n                - ``\"linear\"`` For linear annealing\n\n            avg_fn: the averaging function used to update the parameters;\n                the function must take in the current value of the\n                :class:`AveragedModel` parameter, the current value of :attr:`model`\n                parameter and the number of models already averaged; if None,\n                equally weighted average is used (default: ``None``)\n\n            device: if provided, the averaged model will be stored on the ``device``.\n                When None is provided, it will infer the `device` from ``pl_module``.\n                (default: ``\"cpu\"``)\n\n        \"\"\"\n\n        err_msg = \"swa_epoch_start should be a >0 integer or a float between 0 and 1.\"\n        if isinstance(swa_epoch_start, int) and swa_epoch_start < 1:\n            raise MisconfigurationException(err_msg)\n        if isinstance(swa_epoch_start, float) and not (0 <= swa_epoch_start <= 1):\n            raise MisconfigurationException(err_msg)\n\n        wrong_type = not isinstance(swa_lrs, (float, list))\n        wrong_float = isinstance(swa_lrs, float) and swa_lrs <= 0\n        wrong_list = isinstance(swa_lrs, list) and not all(lr > 0 and isinstance(lr, float) for lr in swa_lrs)\n        if (swa_lrs is not None and (wrong_type or wrong_float or wrong_list)):\n            raise MisconfigurationException(\"The `swa_lrs` should be a positive float or a list of positive float.\")\n\n        if avg_fn is not None and not isinstance(avg_fn, Callable):\n            raise MisconfigurationException(\"The `avg_fn` should be callable.\")\n\n        if device is not None and not isinstance(device, (torch.device, str)):\n            raise MisconfigurationException(f\"device is expected to be a torch.device or a str. 
Found {device}\")\n        self.swa_freq = swa_freq\n        self._swa_epoch_start = swa_epoch_start\n        self._swa_lrs = swa_lrs\n        self._annealing_epochs = annealing_epochs\n        self._annealing_strategy = annealing_strategy\n        self._avg_fn = avg_fn or self.avg_fn\n        self._device = device\n        self._model_contains_batch_norm = None\n        self._average_model = None\n\n    @property\n    def swa_start(self) -> int:\n        return max(self._swa_epoch_start - 1, 0)  # 0-based\n\n    @property\n    def swa_end(self) -> int:\n        return self._max_epochs - 1  # 0-based\n\n    @staticmethod\n    def pl_module_contains_batch_norm(pl_module: 'pl.LightningModule'):\n        return any(isinstance(module, nn.modules.batchnorm._BatchNorm) for module in pl_module.modules())\n\n    def setup(self, trainer: 'pl.Trainer', pl_module: 'pl.LightningModule', stage:str):\n        # copy the model before moving it to accelerator device.\n        self._average_model = deepcopy(pl_module.net)\n\n    def on_fit_start(self, trainer: 'pl.Trainer', pl_module: 'pl.LightningModule'):\n\n        if len(trainer.optimizers) != 1:\n            raise MisconfigurationException(\"SWA currently works with 1 `optimizer`.\")\n\n        if len(trainer.lr_scheduler_configs) > 1:\n            raise MisconfigurationException(\"SWA currently not supported for more than 1 `lr_scheduler`.\")\n\n        if isinstance(self._swa_epoch_start, float):\n            self._swa_epoch_start = int(trainer.max_epochs * self._swa_epoch_start)\n\n        self._model_contains_batch_norm = self.pl_module_contains_batch_norm(pl_module)\n\n        self._max_epochs = trainer.max_epochs\n        if self._model_contains_batch_norm:\n            print(\"\\n\\n_model_contains_batch_norm\\n\\n\")\n            # virtually increase max_epochs to perform batch norm update on latest epoch.\n            trainer.max_epochs += 1\n\n    def on_train_epoch_start(self, trainer: 'pl.Trainer', pl_module: 
'pl.LightningModule'):\n        if trainer.current_epoch == self.swa_start:\n            print(f\"\\n\\n SWA START at {trainer.current_epoch}\\n\\n\")\n            # move average model to request device.\n            self._average_model = self._average_model.to(self._device or pl_module.device)\n\n            optimizers = trainer.optimizers\n\n            for param_group in optimizers[0].param_groups:\n                if self._swa_lrs is None:\n                    initial_lr = param_group[\"lr\"]\n\n                elif isinstance(self._swa_lrs, float):\n                    initial_lr = self._swa_lrs\n\n                else:\n                    initial_lr = self._swa_lrs[0]\n\n                param_group[\"initial_lr\"] = initial_lr\n\n            self._swa_lrs = initial_lr\n\n            self._swa_scheduler = SWALR(\n                optimizers[0],\n                swa_lr=initial_lr,\n                anneal_epochs=self._annealing_epochs,\n                anneal_strategy=self._annealing_strategy,\n                last_epoch=trainer.max_epochs if self._annealing_strategy == \"cos\" else -1\n            )\n\n\n            self.n_averaged = torch.tensor(0, dtype=torch.long, device=pl_module.device)\n            pl_module.net_swa = self._average_model\n        if (self.swa_start <= trainer.current_epoch <= self.swa_end) and ((trainer.current_epoch - self.swa_start) % self.swa_freq == 0):\n            self.update_parameters(self._average_model, pl_module.net, self.n_averaged, self.avg_fn)\n            pl_module.net_swa = self._average_model\n\n\n\n    def on_train_epoch_end(self, trainer: 'pl.Trainer',pl_module: 'pl.LightningModule', *args):\n        trainer.fit_loop._skip_backward = False\n\n\n    def on_train_end(self, trainer: 'pl.Trainer', pl_module: 'pl.LightningModule'):\n        pass\n\n\n    def on_validation_epoch_start(self, trainer, pl_module) -> None:\n        \"\"\"Called when the val epoch begins.\"\"\"\n        if (self.swa_start <= trainer.current_epoch 
<= self.swa_end) and ((trainer.current_epoch - self.swa_start) % self.swa_freq == 0):\n            pl_module.do_swa= True\n        else:\n            pl_module.do_swa = False\n\n\n    @staticmethod\n    def transfer_weights(src_pl_module: 'pl.LightningModule', dst_pl_module: 'pl.LightningModule'):\n        for src_param, dst_param in zip(src_pl_module.parameters(), dst_pl_module.parameters()):\n            dst_param.detach().copy_(src_param.to(dst_param.device))\n\n    def reset_batch_norm_and_save_state(self, pl_module):\n        \"\"\"\n        Adapted from https://github.com/pytorch/pytorch/blob/v1.7.1/torch/optim/swa_utils.py#L140-L154\n        \"\"\"\n        self.momenta = {}\n        for module in pl_module.modules():\n            if not isinstance(module, nn.modules.batchnorm._BatchNorm):\n                continue\n            module.running_mean = torch.zeros_like(\n                module.running_mean, device=pl_module.device, dtype=module.running_mean.dtype\n            )\n            module.running_var = torch.ones_like(\n                module.running_var, device=pl_module.device, dtype=module.running_var.dtype\n            )\n            self.momenta[module] = module.momentum\n            module.momentum = None\n            module.num_batches_tracked *= 0\n\n    def reset_momenta(self):\n        \"\"\"\n        Adapted from https://github.com/pytorch/pytorch/blob/v1.7.1/torch/optim/swa_utils.py#L164-L165\n        \"\"\"\n        for bn_module in self.momenta.keys():\n            bn_module.momentum = self.momenta[bn_module]\n\n    @staticmethod\n    def update_parameters(\n        average_model, model, n_averaged: torch.LongTensor, avg_fn: _AVG_FN\n    ):\n        \"\"\"\n        Adapted from https://github.com/pytorch/pytorch/blob/v1.7.1/torch/optim/swa_utils.py#L104-L112\n        \"\"\"\n        for p_swa, p_model in zip(average_model.parameters(), model.parameters()):\n            device = p_swa.device\n            p_swa_ = p_swa.detach()\n           
 p_model_ = p_model.detach().to(device)\n            src = p_model_ if n_averaged == 0 else avg_fn(p_swa_, p_model_, n_averaged.to(device))\n            p_swa_.copy_(src)\n        n_averaged += 1\n\n    @staticmethod\n    def avg_fn(\n        averaged_model_parameter: torch.Tensor, model_parameter: torch.Tensor, num_averaged: torch.LongTensor\n    ) -> torch.FloatTensor:\n        \"\"\"\n        Adapted from https://github.com/pytorch/pytorch/blob/v1.7.1/torch/optim/swa_utils.py#L95-L97\n        \"\"\"\n        return averaged_model_parameter + (model_parameter - averaged_model_parameter) / (num_averaged + 1)\n"
  },
  {
    "path": "helpers/swa_legacy.py",
    "content": "# Adapted from PyTorch Lightning so that it only does the averaging\n# Copyright The PyTorch Lightning team.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\nr\"\"\"\nStochastic Weight Averaging Callback\n^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^\n\"\"\"\nfrom copy import deepcopy\nfrom typing import Callable, Optional, Union\n\nimport torch\nfrom torch import nn\n\nimport pytorch_lightning as pl\nfrom pytorch_lightning.callbacks.base import Callback\nfrom pytorch_lightning.trainer.optimizers import _get_default_scheduler_config\nfrom pytorch_lightning.utilities import  rank_zero_warn\nfrom pytorch_lightning.utilities.exceptions import MisconfigurationException\n\nfrom torch.optim.swa_utils import SWALR\n\n_AVG_FN = Callable[[torch.Tensor, torch.Tensor,\n                    torch.LongTensor], torch.FloatTensor]\n\n\nclass StochasticWeightAveraging(Callback):\n\n    def __init__(\n        self,\n        swa_epoch_start: Union[int, float] = 0.8,\n        swa_freq: Union[int, float] = 3,\n        swa_lrs: Optional[Union[float, list]] = None,\n        annealing_epochs: int = 10,\n        annealing_strategy: str = \"cos\",\n        avg_fn: Optional[_AVG_FN] = None,\n        device: Optional[Union[torch.device, str]] = None,\n    ):\n        r\"\"\"\n\n        Implements the Stochastic Weight Averaging (SWA) Callback to average a model.\n\n        Stochastic Weight Averaging was proposed in ``Averaging Weights Leads to\n        Wider Optima and Better 
Generalization`` by Pavel Izmailov, Dmitrii\n        Podoprikhin, Timur Garipov, Dmitry Vetrov and Andrew Gordon Wilson\n        (UAI 2018).\n\n        This documentation is highly inspired by PyTorch's work on SWA.\n        The callback arguments follow the scheme defined in PyTorch's ``swa_utils`` package.\n\n        For a SWA explanation, please take a look\n        `here <https://pytorch.org/blog/pytorch-1.6-now-includes-stochastic-weight-averaging>`_.\n\n        .. warning:: ``StochasticWeightAveraging`` is in beta and subject to change.\n\n        .. warning:: ``StochasticWeightAveraging`` is currently not supported for multiple optimizers/schedulers.\n\n        SWA can easily be activated directly from the Trainer as follow:\n\n        .. code-block:: python\n\n            Trainer(stochastic_weight_avg=True)\n\n        Arguments:\n\n            swa_epoch_start: If provided as int, the procedure will start from\n                the ``swa_epoch_start``-th epoch. If provided as float between 0 and 1,\n                the procedure will start from ``int(swa_epoch_start * max_epochs)`` epoch\n\n            swa_lrs: the learning rate value for all param groups together or separately for each group.\n\n            annealing_epochs: number of epochs in the annealing phase (default: 10)\n\n            annealing_strategy: Specifies the annealing strategy (default: \"cos\"):\n\n                - ``\"cos\"``. 
For cosine annealing.\n                - ``\"linear\"`` For linear annealing\n\n            avg_fn: the averaging function used to update the parameters;\n                the function must take in the current value of the\n                :class:`AveragedModel` parameter, the current value of :attr:`model`\n                parameter and the number of models already averaged; if None,\n                equally weighted average is used (default: ``None``)\n\n            device: if provided, the averaged model will be stored on the ``device``.\n                When None is provided, it will infer the `device` from ``pl_module``.\n                (default: ``\"cpu\"``)\n\n        \"\"\"\n\n        err_msg = \"swa_epoch_start should be a >0 integer or a float between 0 and 1.\"\n        if isinstance(swa_epoch_start, int) and swa_epoch_start < 1:\n            raise MisconfigurationException(err_msg)\n        if isinstance(swa_epoch_start, float) and not (0 <= swa_epoch_start <= 1):\n            raise MisconfigurationException(err_msg)\n\n        wrong_type = not isinstance(swa_lrs, (float, list))\n        wrong_float = isinstance(swa_lrs, float) and swa_lrs <= 0\n        wrong_list = isinstance(swa_lrs, list) and not all(\n            lr > 0 and isinstance(lr, float) for lr in swa_lrs)\n        if (swa_lrs is not None and (wrong_type or wrong_float or wrong_list)):\n            raise MisconfigurationException(\n                \"The `swa_lrs` should be a positive float or a list of positive float.\")\n\n        if avg_fn is not None and not isinstance(avg_fn, Callable):\n            raise MisconfigurationException(\"The `avg_fn` should be callable.\")\n\n        if device is not None and not isinstance(device, (torch.device, str)):\n            raise MisconfigurationException(\n                f\"device is expected to be a torch.device or a str. 
Found {device}\")\n        self.swa_freq = swa_freq\n        self._swa_epoch_start = swa_epoch_start\n        self._swa_lrs = swa_lrs\n        self._annealing_epochs = annealing_epochs\n        self._annealing_strategy = annealing_strategy\n        self._avg_fn = avg_fn or self.avg_fn\n        self._device = device\n        self._model_contains_batch_norm = None\n        self._average_model = None\n\n    @property\n    def swa_start(self) -> int:\n        return max(self._swa_epoch_start - 1, 0)  # 0-based\n\n    @property\n    def swa_end(self) -> int:\n        return self._max_epochs - 1  # 0-based\n\n    @staticmethod\n    def pl_module_contains_batch_norm(pl_module: 'pl.LightningModule'):\n        return any(isinstance(module, nn.modules.batchnorm._BatchNorm) for module in pl_module.modules())\n\n    def on_before_accelerator_backend_setup(self, trainer: 'pl.Trainer', pl_module: 'pl.LightningModule'):\n        # copy the model before moving it to accelerator device.\n        self._average_model = deepcopy(pl_module.net)\n\n    def on_fit_start(self, trainer: 'pl.Trainer', pl_module: 'pl.LightningModule'):\n        optimizers = trainer.optimizers\n        lr_schedulers = trainer.lr_schedulers\n\n        if len(optimizers) != 1:\n            raise MisconfigurationException(\n                \"SWA currently works with 1 `optimizer`.\")\n\n        if len(lr_schedulers) > 1:\n            raise MisconfigurationException(\n                \"SWA currently not supported for more than 1 `lr_scheduler`.\")\n\n        if isinstance(self._swa_epoch_start, float):\n            self._swa_epoch_start = int(\n                trainer.max_epochs * self._swa_epoch_start)\n\n        self._model_contains_batch_norm = self.pl_module_contains_batch_norm(\n            pl_module)\n\n        self._max_epochs = trainer.max_epochs\n        if self._model_contains_batch_norm:\n            print(\"\\n\\n_model_contains_batch_norm\\n\\n\")\n            # virtually increase max_epochs to 
perform batch norm update on latest epoch.\n            trainer.max_epochs += 1\n\n    def on_train_epoch_start(self, trainer: 'pl.Trainer', pl_module: 'pl.LightningModule'):\n        if trainer.current_epoch == self.swa_start:\n            print(f\"\\n\\n SWA START at {trainer.current_epoch}\\n\\n\")\n            # move average model to request device.\n            self._average_model = self._average_model.to(\n                self._device or pl_module.device)\n\n            optimizers = trainer.optimizers\n\n            for param_group in optimizers[0].param_groups:\n                if self._swa_lrs is None:\n                    initial_lr = param_group[\"lr\"]\n\n                elif isinstance(self._swa_lrs, float):\n                    initial_lr = self._swa_lrs\n\n                else:\n                    initial_lr = self._swa_lrs[0]\n\n                param_group[\"initial_lr\"] = initial_lr\n\n            self._swa_lrs = initial_lr\n\n            self._swa_scheduler = SWALR(\n                optimizers[0],\n                swa_lr=initial_lr,\n                anneal_epochs=self._annealing_epochs,\n                anneal_strategy=self._annealing_strategy,\n                last_epoch=trainer.max_epochs if self._annealing_strategy == \"cos\" else -1\n            )\n\n            self.n_averaged = torch.tensor(\n                0, dtype=torch.long, device=pl_module.device)\n            pl_module.net_swa = self._average_model\n        if (self.swa_start <= trainer.current_epoch <= self.swa_end) and ((trainer.current_epoch - self.swa_start) % self.swa_freq == 0):\n            self.update_parameters(self._average_model,\n                                   pl_module.net, self.n_averaged, self.avg_fn)\n            pl_module.net_swa = self._average_model\n\n    def on_train_epoch_end(self, trainer: 'pl.Trainer', pl_module: 'pl.LightningModule', *args):\n        trainer.train_loop._skip_backward = False\n\n    def on_train_end(self, trainer: 'pl.Trainer', pl_module: 
'pl.LightningModule'):\n        pass\n\n    def on_validation_epoch_start(self, trainer, pl_module) -> None:\n        \"\"\"Called when the val epoch begins.\"\"\"\n        if (self.swa_start <= trainer.current_epoch <= self.swa_end) and ((trainer.current_epoch - self.swa_start) % self.swa_freq == 0):\n            pl_module.do_swa = True\n        else:\n            pl_module.do_swa = False\n\n    @staticmethod\n    def transfer_weights(src_pl_module: 'pl.LightningModule', dst_pl_module: 'pl.LightningModule'):\n        for src_param, dst_param in zip(src_pl_module.parameters(), dst_pl_module.parameters()):\n            dst_param.detach().copy_(src_param.to(dst_param.device))\n\n    def reset_batch_norm_and_save_state(self, pl_module):\n        \"\"\"\n        Adapted from https://github.com/pytorch/pytorch/blob/v1.7.1/torch/optim/swa_utils.py#L140-L154\n        \"\"\"\n        self.momenta = {}\n        for module in pl_module.modules():\n            if not isinstance(module, nn.modules.batchnorm._BatchNorm):\n                continue\n            module.running_mean = torch.zeros_like(\n                module.running_mean, device=pl_module.device, dtype=module.running_mean.dtype\n            )\n            module.running_var = torch.ones_like(\n                module.running_var, device=pl_module.device, dtype=module.running_var.dtype\n            )\n            self.momenta[module] = module.momentum\n            module.momentum = None\n            module.num_batches_tracked *= 0\n\n    def reset_momenta(self):\n        \"\"\"\n        Adapted from https://github.com/pytorch/pytorch/blob/v1.7.1/torch/optim/swa_utils.py#L164-L165\n        \"\"\"\n        for bn_module in self.momenta.keys():\n            bn_module.momentum = self.momenta[bn_module]\n\n    @staticmethod\n    def update_parameters(\n        average_model, model, n_averaged: torch.LongTensor, avg_fn: _AVG_FN\n    ):\n        \"\"\"\n        Adapted from 
https://github.com/pytorch/pytorch/blob/v1.7.1/torch/optim/swa_utils.py#L104-L112\n        \"\"\"\n        for p_swa, p_model in zip(average_model.parameters(), model.parameters()):\n            device = p_swa.device\n            p_swa_ = p_swa.detach()\n            p_model_ = p_model.detach().to(device)\n            src = p_model_ if n_averaged == 0 else avg_fn(\n                p_swa_, p_model_, n_averaged.to(device))\n            p_swa_.copy_(src)\n        n_averaged += 1\n\n    @staticmethod\n    def avg_fn(\n        averaged_model_parameter: torch.Tensor, model_parameter: torch.Tensor, num_averaged: torch.LongTensor\n    ) -> torch.FloatTensor:\n        \"\"\"\n        Adapted from https://github.com/pytorch/pytorch/blob/v1.7.1/torch/optim/swa_utils.py#L95-L97\n        \"\"\"\n        return averaged_model_parameter + (model_parameter - averaged_model_parameter) / (num_averaged + 1)\n"
  },
  {
    "path": "helpers/workersinit.py",
    "content": "import torch\nimport numpy as np\nimport random\n\n\ndef worker_init_fn(x):\n    seed = (torch.initial_seed() + x * 1000) % 2 ** 31  # problem with nearly seeded randoms\n\n    np.random.seed(seed)\n    random.seed(seed)\n    torch.manual_seed(seed)\n    return\n"
  },
  {
    "path": "models/helpers/vit_helpers.py",
    "content": "\"\"\"\nAdapted from https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py\nCredit to @leo19941227  for remove timm dependencies here : https://github.com/s3prl/passt_hear21/blob/48a0dc1b824641ca59884ced53f5b86053fed141/hear21passt/models/helpers/vit_helpers.py\n\n\"\"\"\nimport math\nimport logging\nimport warnings\nfrom copy import deepcopy\n\nimport torch\nfrom torch import nn\ntry:\n    from timm.models._hub import download_cached_file\nexcept ModuleNotFoundError:\n    from timm.models.hub import download_cached_file\n\n# Global variables for rarely used pretrained checkpoint download progress and hash check.\n# Use set_pretrained_download_progress / set_pretrained_check_hash functions to toggle.\n_DOWNLOAD_PROGRESS = True\n_CHECK_HASH = False\n\n\n_logger = logging.getLogger(__name__)\n\n\ndef adapt_input_conv(in_chans, conv_weight):\n    conv_type = conv_weight.dtype\n    conv_weight = (\n        conv_weight.float()\n    )  # Some weights are in torch.half, ensure it's float for sum on CPU\n    O, I, J, K = conv_weight.shape\n    if in_chans == 1:\n        if I > 3:\n            assert conv_weight.shape[1] % 3 == 0\n            # For models with space2depth stems\n            conv_weight = conv_weight.reshape(O, I // 3, 3, J, K)\n            conv_weight = conv_weight.sum(dim=2, keepdim=False)\n        else:\n            conv_weight = conv_weight.sum(dim=1, keepdim=True)\n    elif in_chans != 3:\n        if I != 3:\n            raise NotImplementedError(\"Weight format not supported by conversion.\")\n        else:\n            # NOTE this strategy should be better than random init, but there could be other combinations of\n            # the original RGB input layer weights that'd work better for specific cases.\n            repeat = int(math.ceil(in_chans / 3))\n            conv_weight = conv_weight.repeat(1, repeat, 1, 1)[:, :in_chans, :, :]\n            conv_weight *= 3 / float(in_chans)\n    
conv_weight = conv_weight.to(conv_type)\n    return conv_weight\n\n\ndef load_pretrained(\n    model,\n    default_cfg=None,\n    num_classes=1000,\n    in_chans=3,\n    filter_fn=None,\n    strict=True,\n    progress=False,\n):\n    \"\"\"Load pretrained checkpoint\n\n    Args:\n        model (nn.Module) : PyTorch model module\n        default_cfg (Optional[Dict]): default configuration for pretrained weights / target dataset\n        num_classes (int): num_classes for model\n        in_chans (int): in_chans for model\n        filter_fn (Optional[Callable]): state_dict filter fn for load (takes state_dict, model as args)\n        strict (bool): strict load of checkpoint\n        progress (bool): enable progress bar for weight download\n\n    \"\"\"\n    default_cfg = default_cfg or getattr(model, \"default_cfg\", None) or {}\n    pretrained_url = default_cfg.get(\"url\", None)\n\n    if not pretrained_url:\n        _logger.warning(\n            \"No pretrained weights exist for this model. 
Using random initialization.\"\n        )\n        return\n\n    _logger.info(f\"Loading pretrained weights from url ({pretrained_url})\")\n    pretrained_loc = download_cached_file(\n            pretrained_url,\n            check_hash=_CHECK_HASH,\n            progress=_DOWNLOAD_PROGRESS,\n        )\n\n    state_dict = torch.load(pretrained_loc, map_location=\"cpu\")\n\n    if filter_fn is not None:\n        # for backwards compat with filter fn that take one arg, try one first, the two\n        try:\n            state_dict = filter_fn(state_dict)\n        except TypeError:\n            state_dict = filter_fn(state_dict, model)\n\n    input_convs = default_cfg.get(\"first_conv\", None)\n    if input_convs is not None and in_chans != 3:\n        if isinstance(input_convs, str):\n            input_convs = (input_convs,)\n        for input_conv_name in input_convs:\n            weight_name = input_conv_name + \".weight\"\n            try:\n                state_dict[weight_name] = adapt_input_conv(\n                    in_chans, state_dict[weight_name]\n                )\n                _logger.info(\n                    f\"Converted input conv {input_conv_name} pretrained weights from 3 to {in_chans} channel(s)\"\n                )\n            except NotImplementedError as e:\n                del state_dict[weight_name]\n                strict = False\n                _logger.warning(\n                    f\"Unable to convert pretrained {input_conv_name} weights, using random init for this layer.\"\n                )\n\n    classifiers = default_cfg.get(\"classifier\", None)\n    label_offset = default_cfg.get(\"label_offset\", 0)\n    if classifiers is not None:\n        if isinstance(classifiers, str):\n            classifiers = (classifiers,)\n        if num_classes != default_cfg[\"num_classes\"]:\n            for classifier_name in classifiers:\n                # completely discard fully connected if model num_classes doesn't match pretrained weights\n        
        del state_dict[classifier_name + \".weight\"]\n                del state_dict[classifier_name + \".bias\"]\n            strict = False\n        elif label_offset > 0:\n            for classifier_name in classifiers:\n                # special case for pretrained weights with an extra background class in pretrained weights\n                classifier_weight = state_dict[classifier_name + \".weight\"]\n                state_dict[classifier_name + \".weight\"] = classifier_weight[\n                    label_offset:\n                ]\n                classifier_bias = state_dict[classifier_name + \".bias\"]\n                state_dict[classifier_name + \".bias\"] = classifier_bias[label_offset:]\n\n    model.load_state_dict(state_dict, strict=strict)\n\n\ndef overlay_external_default_cfg(default_cfg, kwargs):\n    \"\"\"Overlay 'external_default_cfg' in kwargs on top of default_cfg arg.\"\"\"\n    external_default_cfg = kwargs.pop(\"external_default_cfg\", None)\n    if external_default_cfg:\n        default_cfg.pop(\"url\", None)  # url should come from external cfg\n        default_cfg.pop(\"hf_hub\", None)  # hf hub id should come from external cfg\n        default_cfg.update(external_default_cfg)\n\n\ndef filter_kwargs(kwargs, names):\n    if not kwargs or not names:\n        return\n    for n in names:\n        kwargs.pop(n, None)\n\n\ndef set_default_kwargs(kwargs, names, default_cfg):\n    for n in names:\n        # for legacy reasons, model __init__args uses img_size + in_chans as separate args while\n        # default_cfg has one input_size=(C, H ,W) entry\n        if n == \"img_size\":\n            input_size = default_cfg.get(\"input_size\", None)\n            if input_size is not None:\n                assert len(input_size) == 3\n                kwargs.setdefault(n, input_size[-2:])\n        elif n == \"in_chans\":\n            input_size = default_cfg.get(\"input_size\", None)\n            if input_size is not None:\n                assert 
len(input_size) == 3\n                kwargs.setdefault(n, input_size[0])\n        else:\n            default_val = default_cfg.get(n, None)\n            if default_val is not None:\n                kwargs.setdefault(n, default_cfg[n])\n\n\ndef update_default_cfg_and_kwargs(default_cfg, kwargs, kwargs_filter):\n    \"\"\"Update the default_cfg and kwargs before passing to model\n\n    FIXME this sequence of overlay default_cfg, set default kwargs, filter kwargs\n    could/should be replaced by an improved configuration mechanism\n\n    Args:\n        default_cfg: input default_cfg (updated in-place)\n        kwargs: keyword args passed to model build fn (updated in-place)\n        kwargs_filter: keyword arg keys that must be removed before model __init__\n    \"\"\"\n    # Overlay default cfg values from `external_default_cfg` if it exists in kwargs\n    overlay_external_default_cfg(default_cfg, kwargs)\n    # Set model __init__ args that can be determined by default_cfg (if not already passed as kwargs)\n    default_kwarg_names = (\"num_classes\", \"global_pool\", \"in_chans\")\n    if default_cfg.get(\"fixed_input_size\", False):\n        # if fixed_input_size exists and is True, model takes an img_size arg that fixes its input size\n        default_kwarg_names += (\"img_size\",)\n    set_default_kwargs(kwargs, names=default_kwarg_names, default_cfg=default_cfg)\n    # Filter keyword args for task specific model variants (some 'features only' models, etc.)\n    filter_kwargs(kwargs, names=kwargs_filter)\n\n\ndef drop_path(x, drop_prob: float = 0.0, training: bool = False):\n    \"\"\"Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).\n\n    This is the same as the DropConnect impl I created for EfficientNet, etc networks, however,\n    the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper...\n    See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 
... I've opted for\n    changing the layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use\n    'survival rate' as the argument.\n\n    \"\"\"\n    if drop_prob == 0.0 or not training:\n        return x\n    keep_prob = 1 - drop_prob\n    shape = (x.shape[0],) + (1,) * (\n        x.ndim - 1\n    )  # work with diff dim tensors, not just 2D ConvNets\n    random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device)\n    random_tensor.floor_()  # binarize\n    output = x.div(keep_prob) * random_tensor\n    return output\n\n\nclass DropPath(nn.Module):\n    \"\"\"Drop paths (Stochastic Depth) per sample  (when applied in main path of residual blocks).\"\"\"\n\n    def __init__(self, drop_prob=None):\n        super(DropPath, self).__init__()\n        self.drop_prob = drop_prob\n\n    def forward(self, x):\n        return drop_path(x, self.drop_prob, self.training)\n\n\nfrom torch.nn.init import _calculate_fan_in_and_fan_out\n\n\ndef _no_grad_trunc_normal_(tensor, mean, std, a, b):\n    # Cut & paste from PyTorch official master until it's in a few official releases - RW\n    # Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf\n    def norm_cdf(x):\n        # Computes standard normal cumulative distribution function\n        return (1.0 + math.erf(x / math.sqrt(2.0))) / 2.0\n\n    if (mean < a - 2 * std) or (mean > b + 2 * std):\n        warnings.warn(\n            \"mean is more than 2 std from [a, b] in nn.init.trunc_normal_. 
\"\n            \"The distribution of values may be incorrect.\",\n            stacklevel=2,\n        )\n\n    with torch.no_grad():\n        # Values are generated by using a truncated uniform distribution and\n        # then using the inverse CDF for the normal distribution.\n        # Get upper and lower cdf values\n        l = norm_cdf((a - mean) / std)\n        u = norm_cdf((b - mean) / std)\n\n        # Uniformly fill tensor with values from [l, u], then translate to\n        # [2l-1, 2u-1].\n        tensor.uniform_(2 * l - 1, 2 * u - 1)\n\n        # Use inverse cdf transform for normal distribution to get truncated\n        # standard normal\n        tensor.erfinv_()\n\n        # Transform to proper mean, std\n        tensor.mul_(std * math.sqrt(2.0))\n        tensor.add_(mean)\n\n        # Clamp to ensure it's in the proper range\n        tensor.clamp_(min=a, max=b)\n        return tensor\n\n\ndef trunc_normal_(tensor, mean=0.0, std=1.0, a=-2.0, b=2.0):\n    r\"\"\"Fills the input Tensor with values drawn from a truncated\n    normal distribution. The values are effectively drawn from the\n    normal distribution :math:`\\mathcal{N}(\\text{mean}, \\text{std}^2)`\n    with values outside :math:`[a, b]` redrawn until they are within\n    the bounds. 
The method used for generating the random values works\n    best when :math:`a \\leq \\text{mean} \\leq b`.\n    Args:\n        tensor: an n-dimensional `torch.Tensor`\n        mean: the mean of the normal distribution\n        std: the standard deviation of the normal distribution\n        a: the minimum cutoff value\n        b: the maximum cutoff value\n    Examples:\n        >>> w = torch.empty(3, 5)\n        >>> nn.init.trunc_normal_(w)\n    \"\"\"\n    return _no_grad_trunc_normal_(tensor, mean, std, a, b)\n\n\ndef variance_scaling_(tensor, scale=1.0, mode=\"fan_in\", distribution=\"normal\"):\n    fan_in, fan_out = _calculate_fan_in_and_fan_out(tensor)\n    if mode == \"fan_in\":\n        denom = fan_in\n    elif mode == \"fan_out\":\n        denom = fan_out\n    elif mode == \"fan_avg\":\n        denom = (fan_in + fan_out) / 2\n\n    variance = scale / denom\n\n    if distribution == \"truncated_normal\":\n        # constant is stddev of standard normal truncated to (-2, 2)\n        trunc_normal_(tensor, std=math.sqrt(variance) / 0.87962566103423978)\n    elif distribution == \"normal\":\n        tensor.normal_(std=math.sqrt(variance))\n    elif distribution == \"uniform\":\n        bound = math.sqrt(3 * variance)\n        tensor.uniform_(-bound, bound)\n    else:\n        raise ValueError(f\"invalid distribution {distribution}\")\n\n\ndef lecun_normal_(tensor):\n    variance_scaling_(tensor, mode=\"fan_in\", distribution=\"truncated_normal\")\n\n\ndef build_model_with_cfg(\n    model_cls,\n    variant: str,\n    pretrained: bool,\n    default_cfg: dict,\n    model_cfg=None,\n    feature_cfg=None,\n    pretrained_strict: bool = True,\n    pretrained_filter_fn=None,\n    pretrained_custom_load=False,\n    kwargs_filter=None,\n    **kwargs,\n):\n    \"\"\"Build model with specified default_cfg and optional model_cfg\n\n    This helper fn aids in the construction of a model including:\n      * handling default_cfg and associated pretained weight loading\n      
* passing through optional model_cfg for models with config based arch spec\n      * features_only model adaptation\n      * pruning config / model adaptation\n\n    Args:\n        model_cls (nn.Module): model class\n        variant (str): model variant name\n        pretrained (bool): load pretrained weights\n        default_cfg (dict): model's default pretrained/task config\n        model_cfg (Optional[Dict]): model's architecture config\n        feature_cfg (Optional[Dict]: feature extraction adapter config\n        pretrained_strict (bool): load pretrained weights strictly\n        pretrained_filter_fn (Optional[Callable]): filter callable for pretrained weights\n        pretrained_custom_load (bool): use custom load fn, to load numpy or other non PyTorch weights\n        kwargs_filter (Optional[Tuple]): kwargs to filter before passing to model\n        **kwargs: model args passed through to model __init__\n    \"\"\"\n    pruned = kwargs.pop(\"pruned\", False)\n    features = False\n    feature_cfg = feature_cfg or {}\n    default_cfg = deepcopy(default_cfg) if default_cfg else {}\n    update_default_cfg_and_kwargs(default_cfg, kwargs, kwargs_filter)\n    default_cfg.setdefault(\"architecture\", variant)\n\n    # Setup for feature extraction wrapper done at end of this fn\n    if kwargs.pop(\"features_only\", False):\n        features = True\n        feature_cfg.setdefault(\"out_indices\", (0, 1, 2, 3, 4))\n        if \"out_indices\" in kwargs:\n            feature_cfg[\"out_indices\"] = kwargs.pop(\"out_indices\")\n\n    # Build the model\n    model = (\n        model_cls(**kwargs) if model_cfg is None else model_cls(cfg=model_cfg, **kwargs)\n    )\n    model.default_cfg = default_cfg\n\n    # For classification models, check class attr, then kwargs, then default to 1k, otherwise 0 for feats\n    num_classes_pretrained = (\n        0\n        if features\n        else getattr(model, \"num_classes\", kwargs.get(\"num_classes\", 1000))\n    )\n    if 
pretrained:\n        assert not pretrained_custom_load, \"URL should not contain npz for PASST models\"\n        load_pretrained(\n            model,\n            num_classes=num_classes_pretrained,\n            in_chans=kwargs.get(\"in_chans\", 3),\n            filter_fn=pretrained_filter_fn,\n            strict=pretrained_strict,\n        )\n    return model"
  },
  {
    "path": "models/passt.py",
    "content": "\"\"\"\nMost of this code comes from the timm  library.\nWe tried to disentangle from the timm library version.\n\nAdapted from https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py\n\n\"\"\"\nimport math\nimport logging\nimport warnings\nfrom functools import partial\nimport collections\nfrom collections import OrderedDict\nfrom copy import deepcopy\nfrom itertools import repeat\n\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\n\nfrom .helpers.vit_helpers import update_default_cfg_and_kwargs, DropPath, trunc_normal_, build_model_with_cfg\n\n_logger = logging.getLogger()\n\n# From PyTorch internals\ndef _ntuple(n):\n    def parse(x):\n        if isinstance(x, collections.abc.Iterable) and not isinstance(x, str):\n            return tuple(x)\n        return tuple(repeat(x, n))\n    return parse\n\nto_2tuple = _ntuple(2)\n\n\nIMAGENET_DEFAULT_MEAN = (0.485, 0.456, 0.406)\nIMAGENET_DEFAULT_STD = (0.229, 0.224, 0.225)\nIMAGENET_INCEPTION_MEAN = (0.5, 0.5, 0.5)\nIMAGENET_INCEPTION_STD = (0.5, 0.5, 0.5)\n\n\ndef _cfg(url='', **kwargs):\n    return {\n        'url': url,\n        'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': None,\n        'crop_pct': .9, 'interpolation': 'bicubic', 'fixed_input_size': True,\n        'mean': IMAGENET_INCEPTION_MEAN, 'std': IMAGENET_INCEPTION_STD,\n        'first_conv': 'patch_embed.proj', 'classifier': 'head',\n        **kwargs\n    }\n\n\ndefault_cfgs = {\n    # patch models (weights from official Google JAX impl)\n    'vit_tiny_patch16_224': _cfg(\n        url='https://storage.googleapis.com/vit_models/augreg/'\n            'Ti_16-i21k-300ep-lr_0.001-aug_none-wd_0.03-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.03-res_224.npz'),\n    'vit_tiny_patch16_384': _cfg(\n        url='https://storage.googleapis.com/vit_models/augreg/'\n            
'Ti_16-i21k-300ep-lr_0.001-aug_none-wd_0.03-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.03-res_384.npz',\n        input_size=(3, 384, 384), crop_pct=1.0),\n    'vit_small_patch32_224': _cfg(\n        url='https://storage.googleapis.com/vit_models/augreg/'\n            'S_32-i21k-300ep-lr_0.001-aug_light1-wd_0.03-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.03-res_224.npz'),\n    'vit_small_patch32_384': _cfg(\n        url='https://storage.googleapis.com/vit_models/augreg/'\n            'S_32-i21k-300ep-lr_0.001-aug_light1-wd_0.03-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.03-res_384.npz',\n        input_size=(3, 384, 384), crop_pct=1.0),\n    'vit_small_patch16_224': _cfg(\n        url='https://storage.googleapis.com/vit_models/augreg/'\n            'S_16-i21k-300ep-lr_0.001-aug_light1-wd_0.03-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.03-res_224.npz'),\n    'vit_small_patch16_384': _cfg(\n        url='https://storage.googleapis.com/vit_models/augreg/'\n            'S_16-i21k-300ep-lr_0.001-aug_light1-wd_0.03-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.03-res_384.npz',\n        input_size=(3, 384, 384), crop_pct=1.0),\n    'vit_base_patch32_224': _cfg(\n        url='https://storage.googleapis.com/vit_models/augreg/'\n            'B_32-i21k-300ep-lr_0.001-aug_medium1-wd_0.03-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.03-res_224.npz'),\n    'vit_base_patch32_384': _cfg(\n        url='https://storage.googleapis.com/vit_models/augreg/'\n            'B_32-i21k-300ep-lr_0.001-aug_light1-wd_0.1-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.03-res_384.npz',\n        input_size=(3, 384, 384), crop_pct=1.0),\n    'vit_base_patch16_224': _cfg(\n        url='https://storage.googleapis.com/vit_models/augreg/'\n            'B_16-i21k-300ep-lr_0.001-aug_medium1-wd_0.1-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.01-res_224.npz'),\n    'vit_base_patch16_384': _cfg(\n        url='https://storage.googleapis.com/vit_models/augreg/'\n            
'B_16-i21k-300ep-lr_0.001-aug_medium1-wd_0.1-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.01-res_384.npz',\n        input_size=(3, 384, 384), crop_pct=1.0),\n    'vit_large_patch32_224': _cfg(\n        url='',  # no official model weights for this combo, only for in21k\n    ),\n    'vit_large_patch32_384': _cfg(\n        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_large_p32_384-9b920ba8.pth',\n        input_size=(3, 384, 384), crop_pct=1.0),\n    'vit_large_patch16_224': _cfg(\n        url='https://storage.googleapis.com/vit_models/augreg/'\n            'L_16-i21k-300ep-lr_0.001-aug_medium1-wd_0.1-do_0.1-sd_0.1--imagenet2012-steps_20k-lr_0.01-res_224.npz'),\n    'vit_large_patch16_384': _cfg(\n        url='https://storage.googleapis.com/vit_models/augreg/'\n            'L_16-i21k-300ep-lr_0.001-aug_medium1-wd_0.1-do_0.1-sd_0.1--imagenet2012-steps_20k-lr_0.01-res_384.npz',\n        input_size=(3, 384, 384), crop_pct=1.0),\n\n    # patch models, imagenet21k (weights from official Google JAX impl)\n    'vit_tiny_patch16_224_in21k': _cfg(\n        url='https://storage.googleapis.com/vit_models/augreg/Ti_16-i21k-300ep-lr_0.001-aug_none-wd_0.03-do_0.0-sd_0.0.npz',\n        num_classes=21843),\n    'vit_small_patch32_224_in21k': _cfg(\n        url='https://storage.googleapis.com/vit_models/augreg/S_32-i21k-300ep-lr_0.001-aug_light1-wd_0.03-do_0.0-sd_0.0.npz',\n        num_classes=21843),\n    'vit_small_patch16_224_in21k': _cfg(\n        url='https://storage.googleapis.com/vit_models/augreg/S_16-i21k-300ep-lr_0.001-aug_light1-wd_0.03-do_0.0-sd_0.0.npz',\n        num_classes=21843),\n    'vit_base_patch32_224_in21k': _cfg(\n        url='https://storage.googleapis.com/vit_models/augreg/B_32-i21k-300ep-lr_0.001-aug_medium1-wd_0.03-do_0.0-sd_0.0.npz',\n        num_classes=21843),\n    'vit_base_patch16_224_in21k': _cfg(\n        
url='https://storage.googleapis.com/vit_models/augreg/B_16-i21k-300ep-lr_0.001-aug_medium1-wd_0.1-do_0.0-sd_0.0.npz',\n        num_classes=21843),\n    'vit_large_patch32_224_in21k': _cfg(\n        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_large_patch32_224_in21k-9046d2e7.pth',\n        num_classes=21843),\n    'vit_large_patch16_224_in21k': _cfg(\n        url='https://storage.googleapis.com/vit_models/augreg/L_16-i21k-300ep-lr_0.001-aug_medium1-wd_0.1-do_0.1-sd_0.1.npz',\n        num_classes=21843),\n    'vit_huge_patch14_224_in21k': _cfg(\n        url='https://storage.googleapis.com/vit_models/imagenet21k/ViT-H_14.npz',\n        hf_hub='timm/vit_huge_patch14_224_in21k',\n        num_classes=21843),\n\n    # SAM trained models (https://arxiv.org/abs/2106.01548)\n    'vit_base_patch32_sam_224': _cfg(\n        url='https://storage.googleapis.com/vit_models/sam/ViT-B_32.npz'),\n    'vit_base_patch16_sam_224': _cfg(\n        url='https://storage.googleapis.com/vit_models/sam/ViT-B_16.npz'),\n\n    # deit models (FB weights)\n    'deit_tiny_patch16_224': _cfg(\n        url='https://dl.fbaipublicfiles.com/deit/deit_tiny_patch16_224-a1311bcf.pth',\n        mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD),\n    'deit_small_patch16_224': _cfg(\n        url='https://dl.fbaipublicfiles.com/deit/deit_small_patch16_224-cd65a155.pth',\n        mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD),\n    'deit_base_patch16_224': _cfg(\n        url='https://dl.fbaipublicfiles.com/deit/deit_base_patch16_224-b5f2ef4d.pth',\n        mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD),\n    'deit_base_patch16_384': _cfg(\n        url='https://dl.fbaipublicfiles.com/deit/deit_base_patch16_384-8de9b5d1.pth',\n        mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, input_size=(3, 384, 384), crop_pct=1.0),\n    'deit_tiny_distilled_patch16_224': _cfg(\n        
url='https://dl.fbaipublicfiles.com/deit/deit_tiny_distilled_patch16_224-b40b3cf7.pth',\n        mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, classifier=('head', 'head_dist')),\n    'deit_small_distilled_patch16_224': _cfg(\n        url='https://dl.fbaipublicfiles.com/deit/deit_small_distilled_patch16_224-649709d9.pth',\n        mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, classifier=('head', 'head_dist')),\n    'deit_base_distilled_patch16_224': _cfg(\n        url='https://dl.fbaipublicfiles.com/deit/deit_base_distilled_patch16_224-df68dfff.pth',\n        mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, classifier=('head', 'head_dist')),\n    'deit_base_distilled_patch16_384': _cfg(\n        url='https://dl.fbaipublicfiles.com/deit/deit_base_distilled_patch16_384-d0272ac0.pth',\n        mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, input_size=(3, 384, 384), crop_pct=1.0,\n        classifier=('head', 'head_dist')),\n\n    # ViT ImageNet-21K-P pretraining by MILL\n    'vit_base_patch16_224_miil_in21k': _cfg(\n        url='https://miil-public-eu.oss-eu-central-1.aliyuncs.com/model-zoo/ImageNet_21K_P/models/timm/vit_base_patch16_224_in21k_miil.pth',\n        mean=(0, 0, 0), std=(1, 1, 1), crop_pct=0.875, interpolation='bilinear', num_classes=11221,\n    ),\n    'vit_base_patch16_224_miil': _cfg(\n        url='https://miil-public-eu.oss-eu-central-1.aliyuncs.com/model-zoo/ImageNet_21K_P/models/timm'\n            '/vit_base_patch16_224_1k_miil_84_4.pth',\n        mean=(0, 0, 0), std=(1, 1, 1), crop_pct=0.875, interpolation='bilinear',\n    ),\n    # PaSST\n    'passt_s_swa_p16_128_ap476': _cfg(\n        url='https://github.com/kkoutini/PaSST/releases/download/v0.0.1-audioset/passt-s-f128-p16-s10-ap.476-swa.pt',\n        mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, input_size=(1, 128, 998), crop_pct=1.0,\n        classifier=('head.1', 'head_dist'), num_classes=527),\n    'passt_s_kd_p16_128_ap486': _cfg(\n        
url='https://github.com/kkoutini/PaSST/releases/download/v.0.0.9/passt-s-kd-ap.486.pt',\n        mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, input_size=(1, 128, 998), crop_pct=1.0,\n        classifier=('head.1', 'head_dist'), num_classes=527),\n    'passt_l_kd_p16_128_ap47': _cfg(\n        url='https://github.com/kkoutini/PaSST/releases/download/v.0.0.10/passt-l-kd-ap.47.pt',\n        mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, input_size=(1, 128, 998), crop_pct=1.0,\n        classifier=('head.1', 'head_dist'), num_classes=527),\n    'passt_s_swa_p16_128_ap4761': _cfg(\n        url='https://github.com/kkoutini/PaSST/releases/download/v0.0.2-audioset/passt-s-f128-p16-s10-ap.4761-swa.pt',\n        mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, input_size=(1, 128, 998), crop_pct=1.0,\n        classifier=('head.1', 'head_dist'), num_classes=527),\n    'passt_s_p16_128_ap472': _cfg(\n        url='https://github.com/kkoutini/PaSST/releases/download/v0.0.2-audioset/passt-s-f128-p16-s10-ap.472.pt',\n        mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, input_size=(1, 128, 998), crop_pct=1.0,\n        classifier=('head.1', 'head_dist'), num_classes=527),\n    'passt_s_p16_s16_128_ap468': _cfg(\n        url='https://github.com/kkoutini/PaSST/releases/download/v0.0.2-audioset/passt-s-f128-p16-s16-ap.468.pt',\n        mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, input_size=(1, 128, 998), crop_pct=1.0,\n        classifier=('head.1', 'head_dist'), num_classes=527),\n    'passt_s_swa_p16_s16_128_ap473': _cfg(\n        url='https://github.com/kkoutini/PaSST/releases/download/v0.0.2-audioset/passt-s-f128-p16-s16-ap.473-swa.pt',\n        mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, input_size=(1, 128, 998), crop_pct=1.0,\n        classifier=('head.1', 'head_dist'), num_classes=527),\n    'passt_s_swa_p16_s14_128_ap471': _cfg(\n        
url='https://github.com/kkoutini/PaSST/releases/download/v0.0.2-audioset/passt-s-f128-p16-s14-ap.471-swa.pt',\n        mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, input_size=(1, 128, 998), crop_pct=1.0,\n        classifier=('head.1', 'head_dist'), num_classes=527),\n    'passt_s_p16_s14_128_ap469': _cfg(\n        url='https://github.com/kkoutini/PaSST/releases/download/v0.0.2-audioset/passt-s-f128-p16-s14-ap.469.pt',\n        mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, input_size=(1, 128, 998), crop_pct=1.0,\n        classifier=('head.1', 'head_dist'), num_classes=527),\n    'passt_s_swa_p16_s12_128_ap473': _cfg(\n        url='https://github.com/kkoutini/PaSST/releases/download/v0.0.2-audioset/passt-s-f128-p16-s12-ap.473-swa.pt',\n        mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, input_size=(1, 128, 998), crop_pct=1.0,\n        classifier=('head.1', 'head_dist'), num_classes=527),\n    'passt_s_p16_s12_128_ap470': _cfg(\n        url='https://github.com/kkoutini/PaSST/releases/download/v0.0.2-audioset/passt-s-f128-p16-s12-ap.470.pt',\n        mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, input_size=(1, 128, 998), crop_pct=1.0,\n        classifier=('head.1', 'head_dist'), num_classes=527),\n    'passt_s_swa_f128_stfthop100_p16_s10_ap473': _cfg(\n        url='https://github.com/kkoutini/PaSST/releases/download/v0.0.3-audioset/passt-s-f128-stfthop100-p16-s10-ap.473-swa.pt',\n        mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, input_size=(1, 128, 3200), crop_pct=1.0,\n        classifier=('head.1', 'head_dist'), num_classes=527),\n    'passt_s_swa_f128_stfthop160_p16_s10_ap473': _cfg(\n        url='https://github.com/kkoutini/PaSST/releases/download/v0.0.3-audioset/passt-s-f128-stfthop160-p16-s10-ap.473-swa.pt',\n        mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, input_size=(1, 128, 2000), crop_pct=1.0,\n        classifier=('head.1', 'head_dist'), num_classes=527),\n    'passt-s-f128-20sec-p16-s10-ap474-swa': 
_cfg(\n        url='https://github.com/kkoutini/PaSST/releases/download/v0.0.5/passt-s-f128-20sec-p16-s10-ap.474-swa.pt',\n        mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, input_size=(1, 128, 2000), crop_pct=1.0,\n        classifier=('head.1', 'head_dist'), num_classes=527),\n    'passt-s-f128-30sec-p16-s10-ap473-swa': _cfg(\n        url='https://github.com/kkoutini/PaSST/releases/download/v0.0.5/passt-s-f128-30sec-p16-s10-ap.473-swa.pt',\n        mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, input_size=(1, 128, 3000), crop_pct=1.0,\n        classifier=('head.1', 'head_dist'), num_classes=527),\n    'openmic2008_passt_u_f128_p16_s10_ap85_swa': _cfg(\n        url='https://github.com/kkoutini/PaSST/releases/download/v0.0.4-openmic/openmic2008.passt-u-f128-p16-s10-ap.85-swa.pt',\n        mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, input_size=(1, 128, 3200), crop_pct=1.0,\n        classifier=('head.1', 'head_dist'), num_classes=20),\n    'openmic2008_passt_u_f128_p16_s10_ap85  ': _cfg(\n        url='https://github.com/kkoutini/PaSST/releases/download/v0.0.4-openmic/openmic2008.passt-u-f128-p16-s10-ap.85.pt',\n        mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, input_size=(1, 128, 2000), crop_pct=1.0,\n        classifier=('head.1', 'head_dist'), num_classes=20),\n}\n\n\ndef adapt_input_conv(in_chans, conv_weight):\n    conv_type = conv_weight.dtype\n    conv_weight = conv_weight.float()  # Some weights are in torch.half, ensure it's float for sum on CPU\n    O, I, J, K = conv_weight.shape\n    if in_chans == 1:\n        if I > 3:\n            assert conv_weight.shape[1] % 3 == 0\n            # For models with space2depth stems\n            conv_weight = conv_weight.reshape(O, I // 3, 3, J, K)\n            conv_weight = conv_weight.sum(dim=2, keepdim=False)\n        else:\n            conv_weight = conv_weight.sum(dim=1, keepdim=True)\n    elif in_chans != 3:\n        if I != 3:\n            raise NotImplementedError('Weight 
format not supported by conversion.')\n        else:\n            # NOTE this strategy should be better than random init, but there could be other combinations of\n            # the original RGB input layer weights that'd work better for specific cases.\n            repeat = int(math.ceil(in_chans / 3))\n            conv_weight = conv_weight.repeat(1, repeat, 1, 1)[:, :in_chans, :, :]\n            conv_weight *= (3 / float(in_chans))\n    conv_weight = conv_weight.to(conv_type)\n    return conv_weight\n\n\nclass Mlp(nn.Module):\n    \"\"\" MLP as used in Vision Transformer, MLP-Mixer and related networks\n    \"\"\"\n\n    def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):\n        super().__init__()\n        out_features = out_features or in_features\n        hidden_features = hidden_features or in_features\n        self.fc1 = nn.Linear(in_features, hidden_features)\n        self.act = act_layer()\n        self.fc2 = nn.Linear(hidden_features, out_features)\n        self.drop = nn.Dropout(drop)\n\n    def forward(self, x):\n        x = self.fc1(x)\n        x = self.act(x)\n        x = self.drop(x)\n        x = self.fc2(x)\n        x = self.drop(x)\n        return x\n\n\nfirst_RUN = True\n\nPLUS1_TRICK = False\n\n\nclass PatchEmbed(nn.Module):\n    \"\"\" 2D Image to Patch Embedding\n    \"\"\"\n\n    def __init__(self, img_size=224, patch_size=16, stride=16, in_chans=3, embed_dim=768, norm_layer=None,\n                 flatten=True):\n        super().__init__()\n        img_size = to_2tuple(img_size)\n        patch_size = to_2tuple(patch_size)\n        stride = to_2tuple(stride)\n        self.img_size = img_size\n        self.patch_size = patch_size\n        self.stride = stride\n        self.grid_size = (img_size[0] // stride[0], img_size[1] // stride[1])\n        self.num_patches = self.grid_size[0] * self.grid_size[1]\n        self.flatten = flatten\n        self.embed_dim = embed_dim\n        self.proj = 
nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=stride)\n        self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()\n\n    def forward(self, x):\n        B, C, H, W = x.shape\n        if not (H == self.img_size[0] and W == self.img_size[1]):\n            warnings.warn(f\"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]}).\")\n        # to do maybe replace weights\n        x = self.proj(x)\n        if self.flatten:\n            x = x.flatten(2).transpose(1, 2)  # BCHW -> BNC\n        x = self.norm(x)\n        if first_RUN: print(\"self.norm(x)\", x.size())\n        return x\n\n\nclass Attention(nn.Module):\n    def __init__(self, dim, num_heads=8, qkv_bias=False, attn_drop=0., proj_drop=0.):\n        super().__init__()\n        self.num_heads = num_heads\n        head_dim = dim // num_heads\n        self.scale = head_dim ** -0.5\n\n        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)\n        self.attn_drop = nn.Dropout(attn_drop)\n        self.proj = nn.Linear(dim, dim)\n        self.proj_drop = nn.Dropout(proj_drop)\n\n    def forward(self, x):\n        B, N, C = x.shape\n        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)\n        q, k, v = qkv[0], qkv[1], qkv[2]  # make torchscript happy (cannot use tensor as tuple)\n\n        attn = (q @ k.transpose(-2, -1)) * self.scale\n        if PLUS1_TRICK:\n            # +1 trick\n            attn = torch.cat([attn, torch.zeros(attn.shape[:-1]+(1,), dtype=attn.dtype, device=attn.device)], dim=-1)\n        attn = attn.softmax(dim=-1)\n        if PLUS1_TRICK:\n            # +1 trick\n            attn = attn[...,:-1]\n        attn = self.attn_drop(attn)\n\n        x = (attn @ v).transpose(1, 2).reshape(B, N, C)\n        x = self.proj(x)\n        x = self.proj_drop(x)\n        return x\n\n\nclass Block(nn.Module):\n\n    def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, drop=0., 
attn_drop=0.,\n                 drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm):\n        super().__init__()\n        self.norm1 = norm_layer(dim)\n        self.attn = Attention(dim, num_heads=num_heads, qkv_bias=qkv_bias, attn_drop=attn_drop, proj_drop=drop)\n        # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here\n        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()\n        self.norm2 = norm_layer(dim)\n        mlp_hidden_dim = int(dim * mlp_ratio)\n        self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)\n\n    def forward(self, x):\n        x = x + self.drop_path(self.attn(self.norm1(x)))\n        x = x + self.drop_path(self.mlp(self.norm2(x)))\n        return x\n\n\nclass PaSST(nn.Module):\n    \"\"\"\n\n    Based on the implementation of Vision Transformer in timm library.\n     Take a look at the get_model function, adapting the weights of pretrained imagenet models.\n\n    \"\"\"\n\n    def __init__(self, u_patchout=0, s_patchout_t=0, s_patchout_f=0, img_size=(128, 998), patch_size=16, stride=16,\n                 in_chans=1, num_classes=527, embed_dim=768, depth=12,\n                 num_heads=12, mlp_ratio=4., qkv_bias=True, representation_size=None, distilled=False,\n                 drop_rate=0., attn_drop_rate=0., drop_path_rate=0., embed_layer=PatchEmbed, norm_layer=None,\n                 act_layer=None, weight_init=''):\n        \"\"\"\n        Args:\n            u_patchout: Unstructured Patchout integer, number of items to be removed from the final sequence\n            s_patchout_t: structured Patchout time integer, number of columns to be removed from the patches grid\n            s_patchout_f: structured Patchout Frequency integer, number of rows to be removed from the patches grid\n            img_size (int, tuple): input image size\n            patch_size (int, tuple): patch size\n            in_chans (int): number of 
input channels\n            num_classes (int): number of classes for classification head\n            embed_dim (int): embedding dimension\n            depth (int): depth of transformer\n            num_heads (int): number of attention heads\n            mlp_ratio (int): ratio of mlp hidden dim to embedding dim\n            qkv_bias (bool): enable bias for qkv if True\n            representation_size (Optional[int]): enable and set representation layer (pre-logits) to this value if set\n            distilled (bool): model includes a distillation token and head as in DeiT models\n            drop_rate (float): dropout rate\n            attn_drop_rate (float): attention dropout rate\n            drop_path_rate (float): stochastic depth rate\n            embed_layer (nn.Module): patch embedding layer\n            norm_layer: (nn.Module): normalization layer\n            weight_init: (str): weight init scheme\n        \"\"\"\n        super().__init__()\n        self.num_classes = num_classes\n        self.u_patchout = u_patchout\n        self.s_patchout_t = s_patchout_t\n        self.s_patchout_f = s_patchout_f\n        self.num_features = self.embed_dim = embed_dim  # num_features for consistency with other models\n        self.num_tokens = 2 if distilled else 1\n        norm_layer = norm_layer or partial(nn.LayerNorm, eps=1e-6)\n        act_layer = act_layer or nn.GELU\n\n        self.patch_embed = embed_layer(\n            img_size=img_size, patch_size=patch_size, stride=stride, in_chans=in_chans, embed_dim=embed_dim,\n            flatten=False)\n        num_patches = self.patch_embed.num_patches\n\n        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))\n        self.dist_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) if distilled else None\n        # PaSST\n        # refer to https://arxiv.org/abs/2110.05069 Section 2\n        self.new_pos_embed = nn.Parameter(torch.zeros(1, self.num_tokens, embed_dim))  # for C and D tokens\n        
self.freq_new_pos_embed = nn.Parameter(torch.zeros(1, embed_dim, self.patch_embed.grid_size[0], 1))  # | f\n        self.time_new_pos_embed = nn.Parameter(torch.zeros(1, embed_dim, 1, self.patch_embed.grid_size[1]))  # __ t\n        ####\n        self.pos_drop = nn.Dropout(p=drop_rate)\n\n        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)]  # stochastic depth decay rule\n        self.blocks = nn.Sequential(*[\n            Block(\n                dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, drop=drop_rate,\n                attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer, act_layer=act_layer)\n            for i in range(depth)])\n        self.norm = norm_layer(embed_dim)\n\n        # Representation layer\n        if representation_size and not distilled:\n            self.num_features = representation_size\n            self.pre_logits = nn.Sequential(OrderedDict([\n                ('fc', nn.Linear(embed_dim, representation_size)),\n                ('act', nn.Tanh())\n            ]))\n        else:\n            self.pre_logits = nn.Identity()\n\n        # Classifier head(s)\n        self.head = nn.Sequential(nn.LayerNorm(self.num_features),\n                                  nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity())\n        self.head_dist = None\n        if distilled:\n            self.head_dist = nn.Linear(self.embed_dim, self.num_classes) if num_classes > 0 else nn.Identity()\n\n        self.init_weights(weight_init)\n\n    def init_weights(self, mode=''):\n        assert mode in ('jax', 'jax_nlhb', 'nlhb', '')\n        head_bias = -math.log(self.num_classes) if 'nlhb' in mode else 0.\n        trunc_normal_(self.new_pos_embed, std=.02)\n        trunc_normal_(self.freq_new_pos_embed, std=.02)\n        trunc_normal_(self.time_new_pos_embed, std=.02)\n        if self.dist_token is not None:\n            trunc_normal_(self.dist_token, std=.02)\n        if 
mode.startswith('jax'):\n            # leave cls token as zeros to match jax impl\n            raise RuntimeError(\"Not supported yet\")\n        else:\n            trunc_normal_(self.cls_token, std=.02)\n            self.apply(_init_vit_weights)\n\n    def _init_weights(self, m):\n        # this fn left here for compat with downstream users\n        _init_vit_weights(m)\n\n    @torch.jit.ignore\n    def no_weight_decay(self):\n        return {'new_pos_embed', 'freq_new_pos_embed', 'time_new_pos_embed', 'cls_token', 'dist_token'}\n\n    def get_classifier(self):\n        if self.dist_token is None:\n            return self.head\n        else:\n            return self.head, self.head_dist\n\n    def reset_classifier(self, num_classes, global_pool=''):\n        self.num_classes = num_classes\n        self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity()\n        if self.num_tokens == 2:\n            self.head_dist = nn.Linear(self.embed_dim, self.num_classes) if num_classes > 0 else nn.Identity()\n\n    def forward_features(self, x):\n        global first_RUN  # not jit friendly? 
use trace instead\n        x = self.patch_embed(x)  # [b, e, f, t]\n        B_dim, E_dim, F_dim, T_dim = x.shape  # slow\n        if first_RUN: print(\" patch_embed : \", x.shape)\n        # Adding Time/Freq information\n        if first_RUN: print(\" self.time_new_pos_embed.shape\", self.time_new_pos_embed.shape)\n        time_new_pos_embed = self.time_new_pos_embed\n        if x.shape[-1] < time_new_pos_embed.shape[-1]:\n            if self.training:\n                toffset = torch.randint(1 + time_new_pos_embed.shape[-1] - x.shape[-1], (1,)).item()\n                if first_RUN: print(f\" CUT with randomoffset={toffset} time_new_pos_embed.shape\",\n                                    time_new_pos_embed.shape)\n                time_new_pos_embed = time_new_pos_embed[:, :, :, toffset:toffset + x.shape[-1]]\n            else:\n                time_new_pos_embed = time_new_pos_embed[:, :, :, :x.shape[-1]]\n            if first_RUN: print(\" CUT time_new_pos_embed.shape\", time_new_pos_embed.shape)\n        else:\n            warnings.warn(\n                f\"the patches shape:{x.shape} are larger than the expected time encodings {time_new_pos_embed.shape}, x will be cut\")\n            x = x[:, :, :, :time_new_pos_embed.shape[-1]]\n        x = x + time_new_pos_embed\n        if first_RUN: print(\" self.freq_new_pos_embed.shape\", self.freq_new_pos_embed.shape)\n        x = x + self.freq_new_pos_embed\n\n        # Structured Patchout https://arxiv.org/abs/2110.05069 Section 2.2\n        if self.training and self.s_patchout_t:\n            if first_RUN: print(f\"X Before time Patchout of {self.s_patchout_t} \", x.size())\n            # ([1, 768, 1, 82])\n            random_indices = torch.randperm(T_dim)[:T_dim - self.s_patchout_t].sort().values\n            x = x[:, :, :, random_indices]\n            if first_RUN: print(\"X after time Patchout\", x.size())\n        if self.training and self.s_patchout_f:\n            if first_RUN: print(f\"X Before Freq Patchout of 
{self.s_patchout_f} \", x.size())\n            # [1, 768, 12, 1]\n            random_indices = torch.randperm(F_dim)[:F_dim - self.s_patchout_f].sort().values\n            x = x[:, :, random_indices, :]\n            if first_RUN: print(\" \\n X after freq Patchout: \", x.size())\n        ###\n        # Flatten the sequence\n        x = x.flatten(2).transpose(1, 2)\n        # Unstructured Patchout\n        if first_RUN: print(\"X flattened\", x.size())\n        if self.training and self.u_patchout:\n            seq_len = x.shape[1]\n            random_indices = torch.randperm(seq_len)[:seq_len - self.u_patchout].sort().values\n            x = x[:, random_indices, :]\n            if first_RUN: print(\"X After Unstructured Patchout\", x.size())\n        ####\n        # Add the C/D tokens\n        if first_RUN: print(\" self.new_pos_embed.shape\", self.new_pos_embed.shape)\n        cls_tokens = self.cls_token.expand(B_dim, -1, -1) + self.new_pos_embed[:, :1, :]\n        if first_RUN: print(\" self.cls_tokens.shape\", cls_tokens.shape)\n        if self.dist_token is None:\n            x = torch.cat((cls_tokens, x), dim=1)\n        else:\n            dist_token = self.dist_token.expand(B_dim, -1, -1) + self.new_pos_embed[:, 1:, :]\n            if first_RUN: print(\" self.dist_token.shape\", dist_token.shape)\n            x = torch.cat((cls_tokens, dist_token, x), dim=1)\n\n        if first_RUN: print(\" final sequence x\", x.shape)\n        x = self.pos_drop(x)\n        x = self.blocks(x)\n        if first_RUN: print(f\" after {len(self.blocks)} atten blocks x\", x.shape)\n        x = self.norm(x)\n        if self.dist_token is None:\n            return self.pre_logits(x[:, 0])\n        else:\n            return x[:, 0], x[:, 1]\n\n    def forward(self, x):\n        global first_RUN\n        if first_RUN: print(\"x\", x.size())\n\n        x = self.forward_features(x)\n\n        if self.head_dist is not None:\n            features = (x[0] + x[1]) / 2\n            if 
first_RUN: print(\"forward_features\", features.size())\n            x = self.head(features)\n            if first_RUN: print(\"head\", x.size())\n            first_RUN = False\n            return x, features\n        else:\n            features = x\n            if first_RUN: print(\"forward_features\", features.size())\n            x = self.head(x)\n        if first_RUN: print(\"head\", x.size())\n        first_RUN = False\n        return x, features\n\n\ndef _init_vit_weights(module: nn.Module, name: str = '', head_bias: float = 0., jax_impl: bool = False):\n    \"\"\" ViT weight initialization\n    * When called without n, head_bias, jax_impl args it will behave exactly the same\n      as my original init for compatibility with prev hparam / downstream use cases (ie DeiT).\n    * When called w/ valid n (module name) and jax_impl=True, will (hopefully) match JAX impl\n    \"\"\"\n    if isinstance(module, nn.Linear):\n        if name.startswith('head'):\n            nn.init.zeros_(module.weight)\n            nn.init.constant_(module.bias, head_bias)\n        elif name.startswith('pre_logits'):\n            lecun_normal_(module.weight)\n            nn.init.zeros_(module.bias)\n        else:\n            if jax_impl:\n                nn.init.xavier_uniform_(module.weight)\n                if module.bias is not None:\n                    if 'mlp' in name:\n                        nn.init.normal_(module.bias, std=1e-6)\n                    else:\n                        nn.init.zeros_(module.bias)\n            else:\n                trunc_normal_(module.weight, std=.02)\n                if module.bias is not None:\n                    nn.init.zeros_(module.bias)\n    elif jax_impl and isinstance(module, nn.Conv2d):\n        # NOTE conv was left to pytorch default in my original init\n        lecun_normal_(module.weight)\n        if module.bias is not None:\n            nn.init.zeros_(module.bias)\n    elif isinstance(module, (nn.LayerNorm, nn.GroupNorm, 
nn.BatchNorm2d)):\n        nn.init.zeros_(module.bias)\n        nn.init.ones_(module.weight)\n\n\ndef resize_pos_embed(posemb, posemb_new, num_tokens=1, gs_new=(), mode='bicubic'):\n    # Rescale the grid of position embeddings when loading from state_dict. Adapted from\n    # https://github.com/google-research/vision_transformer/blob/00883dd691c63a6830751563748663526e811cee/vit_jax/checkpoint.py#L224\n    _logger.info('Resized position embedding: %s to %s with %s cls/dis tokens', posemb.shape, posemb_new.shape,\n                 num_tokens)\n    ntok_new = posemb_new.shape[1]\n    if num_tokens:\n        posemb_tok, posemb_grid = posemb[:, :num_tokens], posemb[0, num_tokens:]\n        ntok_new -= num_tokens\n    else:\n        posemb_tok, posemb_grid = posemb[:, :0], posemb[0]\n    gs_old = int(math.sqrt(len(posemb_grid)))\n    if not len(gs_new):  # backwards compatibility\n        gs_new = [int(math.sqrt(ntok_new))] * 2\n    assert len(gs_new) >= 2\n    _logger.info('Position embedding grid-size from %s to %s', [gs_old, gs_old], gs_new)\n    posemb_grid = posemb_grid.reshape(1, gs_old, gs_old, -1).permute(0, 3, 1, 2)\n    posemb_grid = F.interpolate(posemb_grid, size=gs_new, mode=mode, align_corners=False)\n    posemb_grid = posemb_grid.permute(0, 2, 3, 1).reshape(1, gs_new[0] * gs_new[1], -1)\n    posemb = torch.cat([posemb_tok, posemb_grid], dim=1)\n    return posemb\n\n\ndef adapt_image_pos_embed_to_passt(posemb, num_tokens=1, gs_new=(), mode='bicubic'):\n    # Rescale the grid of position embeddings when loading from state_dict. 
Adapted from\n    # https://github.com/google-research/vision_transformer/blob/00883dd691c63a6830751563748663526e811cee/vit_jax/checkpoint.py#L224\n    _logger.info('Resized position embedding: %s to %s with %s cls/dis tokens', posemb.shape, gs_new,\n                 num_tokens)\n    if num_tokens:\n        posemb_tok, posemb_grid = posemb[:, :num_tokens], posemb[0, num_tokens:]\n    else:\n        posemb_tok, posemb_grid = posemb[:, :0], posemb[0]\n    gs_old = int(math.sqrt(len(posemb_grid)))\n\n    assert len(gs_new) >= 2\n    _logger.info('Position embedding grid-size from %s to %s', [gs_old, gs_old], gs_new)\n    posemb_grid = posemb_grid.reshape(1, gs_old, gs_old, -1).permute(0, 3, 1, 2)\n    posemb_grid = F.interpolate(posemb_grid, size=gs_new, mode=mode, align_corners=False)\n    freq_new_pos_embed = posemb_grid.mean(dim=3, keepdim=True)\n    time_new_pos_embed = posemb_grid.mean(dim=2, keepdim=True)\n    _logger.info('New Position cls/dstl embedding %s', posemb_tok.shape)\n    _logger.info('New FREQ Position embedding %s', freq_new_pos_embed.shape)\n    _logger.info('New TIME Position embedding %s', time_new_pos_embed.shape)\n    return posemb_tok, freq_new_pos_embed, time_new_pos_embed\n\n\ndef checkpoint_filter_fn(state_dict, model):\n    \"\"\" convert patch embedding weight from manual patchify + linear proj to conv\"\"\"\n    out_dict = {}\n    if 'model' in state_dict:\n        # For deit models\n        state_dict = state_dict['model']\n    state_dict = {k: v for k, v in state_dict.items()}\n    if \"time_new_pos_embed\" not in state_dict:\n        # we are working with ImageNet model\n        _logger.info(\"Adapting pos embedding from ImageNet pretrained model to PaSST.\")\n        v = state_dict.pop(\"pos_embed\")\n        new_pos_embed, freq_new_pos_embed, time_new_pos_embed = adapt_image_pos_embed_to_passt(\n            v, getattr(model, 'num_tokens', 1), model.patch_embed.grid_size)\n        state_dict[\"new_pos_embed\"] = new_pos_embed\n       
 state_dict[\"freq_new_pos_embed\"] = freq_new_pos_embed\n        state_dict[\"time_new_pos_embed\"] = time_new_pos_embed\n\n    for k, v in state_dict.items():\n        if 'patch_embed.proj.weight' in k and len(v.shape) < 4:\n            # For old models that I trained prior to conv based patchification\n            O, I, H, W = model.patch_embed.proj.weight.shape\n            v = v.reshape(O, -1, H, W)\n        elif k == 'pos_embed' and v.shape != model.pos_embed.shape:\n            # this should never occur\n            v = resize_pos_embed(\n                v, model.pos_embed, getattr(model, 'num_tokens', 1), model.patch_embed.grid_size)\n        out_dict[k] = v\n    return out_dict\n\n\ndef _create_vision_transformer(variant, pretrained=False, default_cfg=None, **kwargs):\n    default_cfg = default_cfg or default_cfgs[variant]\n    if kwargs.get('features_only', None):\n        raise RuntimeError('features_only not implemented for Vision Transformer models.')\n\n    # NOTE this extra code to support handling of repr size for in21k pretrained models\n    default_num_classes = default_cfg['num_classes']\n    num_classes = kwargs.get('num_classes', default_num_classes)\n    repr_size = kwargs.pop('representation_size', None)\n    if repr_size is not None and num_classes != default_num_classes:\n        # Remove representation layer if fine-tuning. This may not always be the desired action,\n        # but I feel better than doing nothing by default for fine-tuning. 
Perhaps a better interface?\n        _logger.warning(\"Removing representation layer for fine-tuning.\")\n        repr_size = None\n\n    model = build_model_with_cfg(\n        PaSST, variant, pretrained,\n        default_cfg=default_cfg,\n        representation_size=repr_size,\n        pretrained_filter_fn=checkpoint_filter_fn,\n        pretrained_custom_load='npz' in default_cfg['url'],\n        **kwargs)\n    return model\n\n\ndef vit_huge_patch14_224_in21k(pretrained=False, **kwargs):\n    \"\"\" ViT-Huge model (ViT-H/14) from original paper (https://arxiv.org/abs/2010.11929).\n    ImageNet-21k weights @ 224x224, source https://github.com/google-research/vision_transformer.\n    NOTE: this model has a representation layer but the 21k classifier head is zero'd out in original weights\n    \"\"\"\n    model_kwargs = dict(\n        patch_size=14, embed_dim=1280, depth=32, num_heads=16, representation_size=1280, **kwargs)\n    model = _create_vision_transformer('vit_huge_patch14_224_in21k', pretrained=pretrained, **model_kwargs)\n    return model\n\n\ndef deit_base_distilled_patch16_384(pretrained=False, **kwargs):\n    \"\"\" DeiT-base distilled model @ 384x384 from paper (https://arxiv.org/abs/2012.12877).\n    ImageNet-1k weights from https://github.com/facebookresearch/deit.\n    \"\"\"\n    print(\"\\n\\n Loading DEIT BASE 384\\n\\n\")\n    model_kwargs = dict(patch_size=16, embed_dim=768, depth=12, num_heads=12, **kwargs)\n    model = _create_vision_transformer(\n        'deit_base_distilled_patch16_384', pretrained=pretrained, distilled=True, **model_kwargs)\n    return model\n\n\ndef passt_s_swa_p16_128_ap476(pretrained=False, **kwargs):\n    \"\"\" PaSST pre-trained on AudioSet\n    \"\"\"\n    print(\"\\n\\n Loading PaSST pre-trained on AudioSet Patch 16 stride 10 structured patchout mAP=476 SWA \\n\\n\")\n    model_kwargs = dict(patch_size=16, embed_dim=768, depth=12, num_heads=12, **kwargs)\n    if model_kwargs.get(\"stride\") != (10, 10):\n        
warnings.warn(\n            f\"This model was pre-trained with strides {(10, 10)}, but now you set (fstride,tstride) to {model_kwargs.get('stride')}.\")\n    model = _create_vision_transformer(\n        'passt_s_swa_p16_128_ap476', pretrained=pretrained, distilled=True, **model_kwargs)\n    return model\n\n\ndef passt_s_kd_p16_128_ap486(pretrained=False, **kwargs):\n    \"\"\" PaSST pre-trained on AudioSet\n    \"\"\"\n    print(\"\\n\\n Loading PaSST pre-trained on AudioSet (with KD) Patch 16 stride 10 structured patchout mAP=486 \\n\\n\")\n    model_kwargs = dict(patch_size=16, embed_dim=768, depth=12, num_heads=12, **kwargs)\n    if model_kwargs.get(\"stride\") != (10, 10):\n        warnings.warn(\n            f\"This model was pre-trained with strides {(10, 10)}, but now you set (fstride,tstride) to {model_kwargs.get('stride')}.\")\n    model = _create_vision_transformer(\n        'passt_s_kd_p16_128_ap486', pretrained=pretrained, distilled=True, **model_kwargs)\n    return model\n\n\ndef passt_l_kd_p16_128_ap47(pretrained=False, **kwargs):\n    \"\"\" PaSST pre-trained on AudioSet\n    \"\"\"\n    print(\"\\n\\n Loading PaSST-L (light, reduced depth=7) pre-trained on AudioSet (with KD) Patch 16 stride 10 structured patchout mAP=4708 \\n\\n\")\n    model_kwargs = dict(patch_size=16, embed_dim=768,\n                        depth=7, num_heads=12, **kwargs)\n    if model_kwargs.get(\"stride\") != (10, 10):\n        warnings.warn(\n            f\"This model was pre-trained with strides {(10, 10)}, but now you set (fstride,tstride) to {model_kwargs.get('stride')}.\")\n    model = _create_vision_transformer(\n        'passt_l_kd_p16_128_ap47', pretrained=pretrained, distilled=True, **model_kwargs)\n    return model\n\ndef passt_s_swa_p16_128_ap4761(pretrained=False, **kwargs):\n    \"\"\" PaSST pre-trained on AudioSet\n    \"\"\"\n    print(\"\\n\\n Loading PaSST pre-trained on AudioSet Patch 16 stride 10 structured patchout mAP=4763 SWA \\n\\n\")\n    model_kwargs = 
dict(patch_size=16, embed_dim=768, depth=12, num_heads=12, **kwargs)\n    if model_kwargs.get(\"stride\") != (10, 10):\n        warnings.warn(\n            f\"This model was pre-trained with strides {(10, 10)}, but now you set (fstride,tstride) to {model_kwargs.get('stride')}.\")\n    model = _create_vision_transformer(\n        'passt_s_swa_p16_128_ap4761', pretrained=pretrained, distilled=True, **model_kwargs)\n    return model\n\n\ndef passt_s_p16_128_ap472(pretrained=False, **kwargs):\n    \"\"\" PaSST pre-trained on AudioSet\n    \"\"\"\n    print(\"\\n\\n Loading PaSST pre-trained on AudioSet Patch 16 stride 10 structured patchout mAP=472 \\n\\n\")\n    model_kwargs = dict(patch_size=16, embed_dim=768, depth=12, num_heads=12, **kwargs)\n    if model_kwargs.get(\"stride\") != (10, 10):\n        warnings.warn(\n            f\"This model was pre-trained with strides {(10, 10)}, but now you set (fstride,tstride) to {model_kwargs.get('stride')}.\")\n    model = _create_vision_transformer(\n        'passt_s_p16_128_ap472', pretrained=pretrained, distilled=True, **model_kwargs)\n    return model\n\n\ndef passt_s_p16_s12_128_ap470(pretrained=False, **kwargs):\n    \"\"\" PaSST pre-trained on AudioSet\n    \"\"\"\n    print(\"\\n\\n Loading PaSST pre-trained on AudioSet Patch 16 stride 12 structured patchout mAP=472 \\n\\n\")\n    model_kwargs = dict(patch_size=16, embed_dim=768, depth=12, num_heads=12, **kwargs)\n    if model_kwargs.get(\"stride\") != (12, 12):\n        warnings.warn(\n            f\"This model was pre-trained with strides {(12, 12)}, but now you set (fstride,tstride) to {model_kwargs.get('stride')}.\")\n    model = _create_vision_transformer(\n        'passt_s_p16_s12_128_ap470', pretrained=pretrained, distilled=True, **model_kwargs)\n    return model\n\n\ndef passt_s_f128_20sec_p16_s10_ap474_swa(pretrained=False, **kwargs):\n    print(\"\\n\\n Loading PASST TRAINED ON AUDISET with 20 Second time encodings, with STFT hop of 160 \\n\\n\")\n    
model_kwargs = dict(patch_size=16, embed_dim=768, depth=12, num_heads=12, **kwargs)\n    model = _create_vision_transformer(\n        'passt-s-f128-20sec-p16-s10-ap474-swa', pretrained=pretrained, distilled=True, **model_kwargs)\n    return model\n\n\ndef passt_s_f128_30sec_p16_s10_ap473_swa(pretrained=False, **kwargs):\n    print(\"\\n\\n Loading PASST TRAINED ON AUDISET with 30 Second time encodings, with STFT hop of 160 \\n\\n\")\n    model_kwargs = dict(patch_size=16, embed_dim=768, depth=12, num_heads=12, **kwargs)\n    model = _create_vision_transformer(\n        'passt-s-f128-30sec-p16-s10-ap473-swa', pretrained=pretrained, distilled=True, **model_kwargs)\n    return model\n\n\ndef passt_s_swa_p16_s12_128_ap473(pretrained=False, **kwargs):\n    \"\"\" PaSST pre-trained on AudioSet\n    \"\"\"\n    print(\"\\n\\n Loading PaSST pre-trained on AudioSet Patch 16 stride 12 structured patchout mAP=472 \\n\\n\")\n    model_kwargs = dict(patch_size=16, embed_dim=768, depth=12, num_heads=12, **kwargs)\n    if model_kwargs.get(\"stride\") != (12, 12):\n        warnings.warn(\n            f\"This model was pre-trained with strides {(12, 12)}, but now you set (fstride,tstride) to {model_kwargs.get('stride')}.\")\n    model = _create_vision_transformer(\n        'passt_s_swa_p16_s12_128_ap473', pretrained=pretrained, distilled=True, **model_kwargs)\n    return model\n\n\ndef passt_s_p16_s14_128_ap469(pretrained=False, **kwargs):\n    \"\"\" PaSST pre-trained on AudioSet\n    \"\"\"\n    print(\"\\n\\n Loading PaSST pre-trained on AudioSet Patch 16 stride 14 structured patchout mAP=472 \\n\\n\")\n    model_kwargs = dict(patch_size=16, embed_dim=768, depth=12, num_heads=12, **kwargs)\n    if model_kwargs.get(\"stride\") != (14, 14):\n        warnings.warn(\n            f\"This model was pre-trained with strides {(14, 14)}, but now you set (fstride,tstride) to {model_kwargs.get('stride')}.\")\n    model = _create_vision_transformer(\n        'passt_s_p16_s14_128_ap469', 
pretrained=pretrained, distilled=True, **model_kwargs)\n    return model\n\n\ndef passt_s_swa_p16_s14_128_ap471(pretrained=False, **kwargs):\n    \"\"\" PaSST pre-trained on AudioSet\n    \"\"\"\n    print(\"\\n\\n Loading PaSST pre-trained on AudioSet Patch 16 stride 14 structured patchout mAP=472 \\n\\n\")\n    model_kwargs = dict(patch_size=16, embed_dim=768, depth=12, num_heads=12, **kwargs)\n    if model_kwargs.get(\"stride\") != (14, 14):\n        warnings.warn(\n            f\"This model was pre-trained with strides {(14, 14)}, but now you set (fstride,tstride) to {model_kwargs.get('stride')}.\")\n    model = _create_vision_transformer(\n        'passt_s_swa_p16_s14_128_ap471', pretrained=pretrained, distilled=True, **model_kwargs)\n    return model\n\n\ndef passt_s_swa_p16_s16_128_ap473(pretrained=False, **kwargs):\n    \"\"\" PaSST pre-trained on AudioSet\n    \"\"\"\n    print(\"\\n\\n Loading PaSST pre-trained on AudioSet Patch 16 stride 16 structured patchout mAP=472 \\n\\n\")\n    model_kwargs = dict(patch_size=16, embed_dim=768, depth=12, num_heads=12, **kwargs)\n    if model_kwargs.get(\"stride\") != (16, 16):\n        warnings.warn(\n            f\"This model was pre-trained with strides {(16, 16)}, but now you set (fstride,tstride) to {model_kwargs.get('stride')}.\")\n    model = _create_vision_transformer(\n        'passt_s_swa_p16_s16_128_ap473', pretrained=pretrained, distilled=True, **model_kwargs)\n    return model\n\n\ndef passt_s_p16_s16_128_ap468(pretrained=False, **kwargs):\n    \"\"\" PaSST pre-trained on AudioSet\n    \"\"\"\n    print(\"\\n\\n Loading PaSST pre-trained on AudioSet Patch 16 stride 16 structured patchout mAP=472 \\n\\n\")\n    model_kwargs = dict(patch_size=16, embed_dim=768, depth=12, num_heads=12, **kwargs)\n    if model_kwargs.get(\"stride\") != (16, 16):\n        warnings.warn(\n            f\"This model was pre-trained with strides {(16, 16)}, but now you set (fstride,tstride) to {model_kwargs.get('stride')}.\")\n    
model = _create_vision_transformer(\n        'passt_s_p16_s16_128_ap468', pretrained=pretrained, distilled=True, **model_kwargs)\n    return model\n\n\nfrom ba3l.ingredients.ingredient import Ingredient\n\nmodel_ing = Ingredient(\"passt\")\n\nmodel_ing.add_config(instance_cmd=\"get_model\")\n\n\n@model_ing.command\ndef fix_embedding_layer(model, embed=\"default\"):\n    if embed == \"default\":\n        return model\n    if embed == \"overlap\":\n        model.patch_embed = PatchEmbedAdaptiveMean(replace=model.patch_embed)\n    if embed == \"am_keepconv\":\n        model.patch_embed = PatchEmbedAdaptiveMeanKeepConv(replace=model.patch_embed)\n    return model\n\n@model_ing.command\ndef lighten_model(model, cut_depth=0):\n    if cut_depth == 0:\n        return model\n    if cut_depth:\n        if cut_depth < 0:\n            print(f\"\\n Reducing model depth by removing every  {-cut_depth} layer \\n\\n\")\n        else:\n            print(f\"\\n Reducing model depth by {cut_depth} \\n\\n\")\n            if len(model.blocks) < cut_depth + 2:\n                raise ValueError(f\"Cut depth a VIT with {len(model.blocks)} \"\n                                 f\"layers should be between 1 and {len(model.blocks) - 2}\")\n        print(f\"\\n Before Cutting it was  {len(model.blocks)} \\n\\n\")\n\n        old_blocks = list(model.blocks.children())\n        if cut_depth < 0:\n            print(f\"cut_depth={cut_depth}\")\n            old_blocks = [old_blocks[0]] + old_blocks[1:-1:-cut_depth] + [old_blocks[-1]]\n        else:\n            old_blocks = [old_blocks[0]] + old_blocks[cut_depth + 1:]\n        model.blocks = nn.Sequential(*old_blocks)\n        print(f\"\\n Atfer Cutting it is  {len(model.blocks)} \\n\\n\")\n    return model\n\n\n@model_ing.command\ndef get_model(arch=\"passt_s_kd_p16_128_ap486\", pretrained=True, n_classes=527, in_channels=1, fstride=10,\n              tstride=10,\n              input_fdim=128, input_tdim=998, u_patchout=0, s_patchout_t=0, 
s_patchout_f=0,\n              ):\n    \"\"\"\n    :param arch: Base ViT or Deit architecture\n    :param pretrained: use pretrained model on imagenet\n    :param n_classes: number of classes\n    :param in_channels: number of input channels: 1 for mono\n    :param fstride: the patches stride over frequency.\n    :param tstride: the patches stride over time.\n    :param input_fdim: the expected input frequency bins.\n    :param input_tdim: the expected input time bins.\n    :param u_patchout: number of input patches to drop in Unstructured Patchout as defined in https://arxiv.org/abs/2110.05069\n    :param s_patchout_t: number of input time frames to drop Structured Patchout as defined in https://arxiv.org/abs/2110.05069\n    :param s_patchout_f:  number of input frequency bins to drop Structured Patchout as defined in https://arxiv.org/abs/2110.05069\n    :param audioset_pretrain: use pretrained models on Audioset.\n    :return:\n\n    \"\"\"\n    model_func = None\n    input_size = (input_fdim, input_tdim)\n    stride = (fstride, tstride)\n    if arch == \"passt_deit_bd_p16_384\":  # base deit\n        model_func = deit_base_distilled_patch16_384\n    elif arch == \"passt_s_kd_p16_128_ap486\":  # pretrained\n        model_func = passt_s_kd_p16_128_ap486\n    elif arch == \"passt_l_kd_p16_128_ap47\":  # pretrained passt-L\n        model_func = passt_l_kd_p16_128_ap47\n    elif arch == \"passt_s_swa_p16_128_ap476\":  # pretrained\n        model_func = passt_s_swa_p16_128_ap476\n    elif arch == \"passt_s_swa_p16_128_ap4761\":\n        model_func = passt_s_swa_p16_128_ap4761\n    elif arch == \"passt_s_p16_128_ap472\":\n        model_func = passt_s_p16_128_ap472\n    elif arch == \"passt_s_p16_s16_128_ap468\":\n        model_func = passt_s_p16_s16_128_ap468\n    elif arch == \"passt_s_swa_p16_s16_128_ap473\":\n        model_func = passt_s_swa_p16_s16_128_ap473\n    elif arch == \"passt_s_swa_p16_s14_128_ap471\":\n        model_func = passt_s_swa_p16_s14_128_ap471\n  
  elif arch == \"passt_s_p16_s14_128_ap469\":\n        model_func = passt_s_p16_s14_128_ap469\n    elif arch == \"passt_s_swa_p16_s12_128_ap473\":\n        model_func = passt_s_swa_p16_s12_128_ap473\n    elif arch == \"passt_s_p16_s12_128_ap470\":\n        model_func = passt_s_p16_s12_128_ap470\n    elif arch == \"passt_s_f128_20sec_p16_s10_ap474\":\n        model_func = passt_s_f128_20sec_p16_s10_ap474_swa\n    elif arch == \"passt_s_f128_30sec_p16_s10_ap473\":\n        model_func = passt_s_f128_30sec_p16_s10_ap473_swa\n\n    if model_func is None:\n        raise RuntimeError(f\"Unknown model {arch}\")\n    model = model_func(pretrained=pretrained, num_classes=n_classes, in_chans=in_channels,\n                       img_size=input_size, stride=stride, u_patchout=u_patchout,\n                       s_patchout_t=s_patchout_t, s_patchout_f=s_patchout_f)\n    model = fix_embedding_layer(model)\n    model = lighten_model(model)\n    print(model)\n    return model\n\n\nclass EnsembelerModel(nn.Module):\n    def __init__(self, models):\n        super(EnsembelerModel, self).__init__()\n        self.models = nn.ModuleList(models)\n\n    def forward(self, x):\n        # ModuleList can act as an iterable, or be indexed using ints\n        all_out = None\n        for i, m in enumerate(self.models):\n            out, _ = m(x)\n            if all_out is None:\n                all_out = out\n            else:\n                all_out = out + all_out\n        all_out = all_out / len(self.models)\n        return all_out, all_out\n\n\n@model_ing.command\ndef get_ensemble_model(arch_list=[]):\n    # arch_list = [(passt_s_swa_p16_128_ap476,fstride,tstride)]\n    models_list = [get_model(arch=arch, fstride=fstride, tstride=tstride) for arch, fstride, tstride in arch_list]\n    model = EnsembelerModel(models_list)\n    print(model)\n    return model\n"
  },
  {
    "path": "models/preprocess.py",
    "content": "import torch.nn as nn\nimport torchaudio\nfrom torch.nn.functional import conv1d, conv2d\n\nimport torch\n\n\nfrom ba3l.ingredients.ingredient import Ingredient\n\nmodel_ing = Ingredient(\"spectrograms\")\n\nsz_float = 4  # size of a float\nepsilon = 10e-8  # fudge factor for normalization\n\n\n\n\n@model_ing.command\nclass AugmentMelSTFT(nn.Module):\n    def __init__(self, n_mels=128, sr=32000, win_length=800, hopsize=320, n_fft=1024, freqm=48, timem=192,\n                 htk=False, fmin=0.0, fmax=None, norm=1, fmin_aug_range=1, fmax_aug_range=1000):\n        torch.nn.Module.__init__(self)\n        # adapted from: https://github.com/CPJKU/kagglebirds2020/commit/70f8308b39011b09d41eb0f4ace5aa7d2b0e806e\n        # Similar config to the spectrograms used in AST: https://github.com/YuanGongND/ast\n\n        self.win_length = win_length\n        self.n_mels = n_mels\n        self.n_fft = n_fft\n        self.sr = sr\n        self.htk = htk\n        self.fmin = fmin\n        if fmax is None:\n            fmax = sr // 2 - fmax_aug_range // 2\n            print(f\"Warning: FMAX is None setting to {fmax} \")\n        self.fmax = fmax\n        self.norm = norm\n        self.hopsize = hopsize\n        self.register_buffer('window',\n                             torch.hann_window(win_length, periodic=False),\n                             persistent=False)\n        assert fmin_aug_range >= 1, f\"fmin_aug_range={fmin_aug_range} should be >=1; 1 means no augmentation\"\n        assert fmin_aug_range >= 1, f\"fmax_aug_range={fmax_aug_range} should be >=1; 1 means no augmentation\"\n        self.fmin_aug_range = fmin_aug_range\n        self.fmax_aug_range = fmax_aug_range\n\n        self.register_buffer(\"preemphasis_coefficient\", torch.as_tensor([[[-.97, 1]]]), persistent=False)\n        if freqm == 0:\n            self.freqm = torch.nn.Identity()\n        else:\n            self.freqm = torchaudio.transforms.FrequencyMasking(freqm, iid_masks=True)\n        if 
timem == 0:\n            self.timem = torch.nn.Identity()\n        else:\n            self.timem = torchaudio.transforms.TimeMasking(timem, iid_masks=True)\n\n\n    def forward(self, x):\n\n        x = nn.functional.conv1d(x.unsqueeze(1), self.preemphasis_coefficient).squeeze(1)\n        x = torch.stft(x, self.n_fft, hop_length=self.hopsize, win_length=self.win_length,\n                       center=True, normalized=False, window=self.window, return_complex=False)\n        x = (x ** 2).sum(dim=-1)  # power mag\n        fmin = self.fmin + torch.randint(self.fmin_aug_range, (1,)).item()\n        fmax = self.fmax + self.fmax_aug_range // 2 - torch.randint(self.fmax_aug_range, (1,)).item()\n        # don't augment eval data\n        if not self.training:\n            fmin = self.fmin\n            fmax = self.fmax\n\n\n        mel_basis, _ = torchaudio.compliance.kaldi.get_mel_banks(self.n_mels,  self.n_fft, self.sr,\n                                        fmin, fmax, vtln_low=100.0, vtln_high=-500., vtln_warp_factor=1.0)\n        mel_basis = torch.as_tensor(torch.nn.functional.pad(mel_basis, (0, 1), mode='constant', value=0),\n                                    device=x.device)\n        with torch.cuda.amp.autocast(enabled=False):\n            melspec = torch.matmul(mel_basis, x)\n\n        melspec = (melspec + 0.00001).log()\n\n        if self.training:\n            melspec = self.freqm(melspec)\n            melspec = self.timem(melspec)\n\n        melspec = (melspec + 4.5) / 5.  # fast normalization\n\n        return melspec\n\n    def extra_repr(self):\n        return 'winsize={}, hopsize={}'.format(self.win_length,\n                                               self.hopsize\n                                               )\n\n"
  },
  {
    "path": "openmic/README.md",
    "content": "# Experiments on OpenMIC-2018\n\n[OpenMIC-2018](https://github.com/cosmir/openmic-2018) ([zenodo](https://zenodo.org/record/1432913#.W6dPeJNKjOR)) is a dataset for polyphonic instrument identification.\n\n\n\n## Preparing the dataset\nUse `openmic/prepare_scripts/download_preprocess.py` to download and pack the dataset:\n```shell\ncd openmic/prepare_scripts/\npython download_preprocess.py\n```\nWhen the script completes, you should have two files inside `audioset_hdf5s/mp3/`\n `openmic_train.csv_mp3.hdf` and `openmic_test.csv_mp3.hdf`\nThese files contain the mp3s of the dataset and the labels.\n\n\n\n## Fine-tuning pretrained PaSST on OpenMIC-2018\n\nSimilar to audioset you can use:\n```shell\n# Example call with all the default config:\npython ex_openmic.py with  trainer.precision=16  -p \n```\n\n```shell\n# with 2 gpus:\nDDP=2 python ex_openmic.py with  trainer.precision=16  -p \n```\n"
  },
  {
    "path": "openmic/dataset.py",
    "content": "import io\nimport os\nimport pathlib\nimport random\n\nimport av\nimport librosa\nimport torchaudio\nfrom torch.utils.data import Dataset as TorchDataset, ConcatDataset, DistributedSampler, WeightedRandomSampler\n\nimport torch\nfrom ba3l.ingredients.datasets import Dataset\nimport pandas as pd\nfrom sacred.config import DynamicIngredient, CMD\nfrom scipy.signal import convolve\nfrom sklearn import preprocessing\nfrom torch.utils.data import Dataset as TorchDataset\nimport numpy as np\nimport h5py\nfrom helpers.audiodatasets import  PreprocessDataset\n\n\nLMODE = os.environ.get(\"LMODE\", False)\n\ndataset = Dataset('openMIC')\n\n\n@dataset.config\ndef default_config():\n    name = 'openmic2008'  # dataset name\n    normalize = False  # normalize dataset\n    subsample = False  # subsample squares from the dataset\n    roll = True  # apply roll augmentation\n    fold = 1\n    base_dir = \"audioset_hdf5s/\"  # base directory of the dataset as downloaded\n    if LMODE:\n        base_dir = \"/system/user/publicdata/CP/audioset/audioset_hdf5s/\"\n    openmic_train_hdf5 = base_dir + \"mp3/openmic_train.csv_mp3.hdf\"\n    openmic_test_hdf5 = base_dir + \"mp3/openmic_test.csv_mp3.hdf\"\n    ir_path = base_dir + \"irs/\"\n    num_of_classes = 20\n\n\n\n\n\ndef decode_mp3(mp3_arr):\n    \"\"\"\n    decodes an array if uint8 representing an mp3 file\n    :rtype: np.array\n    \"\"\"\n    container = av.open(io.BytesIO(mp3_arr.tobytes()))\n    stream = next(s for s in container.streams if s.type == 'audio')\n    # print(stream)\n    a = []\n    for i, packet in enumerate(container.demux(stream)):\n        for frame in packet.decode():\n            a.append(frame.to_ndarray().reshape(-1))\n    waveform = np.concatenate(a)\n    if waveform.dtype != 'float32':\n        raise RuntimeError(\"Unexpected wave type\")\n    return waveform\n\n\ndef pad_or_truncate(x, audio_length):\n    \"\"\"Pad all audio to specific length.\"\"\"\n    if len(x) <= audio_length:\n     
   return np.concatenate((x, np.zeros(audio_length - len(x), dtype=np.float32)), axis=0)\n    else:\n        return x[0: audio_length]\n\n\nirs_arr = None\n\n\n@dataset.command\ndef get_ir_sample(ir_path, _run, ir_augment, cut_irs_offset=None):\n    if not ir_augment:\n        return\n    global irs_arr\n    if irs_arr is None:\n        all_paths = [path for path in pathlib.Path(os.path.expanduser(ir_path)).rglob('*.wav')]\n        all_paths = sorted(all_paths)\n        if cut_irs_offset is not None:\n            all_paths = all_paths[cut_irs_offset:cut_irs_offset + 10]\n        all_paths_name = [str(p).rsplit(\"/\", 1)[-1] for p in all_paths]\n        print(\"will use these IRs:\")\n        for i in range(len(all_paths_name)):\n            print(i, \": \", all_paths_name[i])\n        _run.info[\"ir_devices\"] = all_paths_name\n        irs_arr = [librosa.load(p, sr=32000)[0] for p in all_paths]\n    return irs_arr[int(np.random.randint(0, len(irs_arr)))]\n\n\n@dataset.command\ndef pydub_augment(waveform, gain_augment=7, ir_augment=0):\n    if ir_augment and torch.rand(1) < ir_augment:\n        ir = get_ir_sample()\n        waveform = convolve(waveform, ir, 'full')\n    if gain_augment:\n        gain = torch.randint(gain_augment * 2, (1,)).item() - gain_augment\n        amp = 10 ** (gain / 20)\n        waveform = waveform * amp\n    return waveform\n\n\nclass MixupDataset(TorchDataset):\n    \"\"\" Mixing Up wave forms\n    \"\"\"\n\n    def __init__(self, dataset, beta=2, rate=0.5):\n        self.beta = beta\n        self.rate = rate\n        self.dataset = dataset\n        print(f\"Mixing up waveforms from dataset of len {len(dataset)}\")\n\n    def __getitem__(self, index):\n        x1, f1, y1 = self.dataset[index]\n        y1 = torch.as_tensor(y1)\n        if torch.rand(1) < self.rate:\n            idx2 = torch.randint(len(self.dataset), (1,)).item()\n            x2, f2, y2 = self.dataset[idx2]\n            y2 = torch.as_tensor(y2)\n            l = 
np.random.beta(self.beta, self.beta)\n            l = max(l, 1. - l)\n            x1 = x1 - x1.mean()\n            x2 = x2 - x2.mean()\n            x = (x1 * l + x2 * (1. - l))\n            x = x - x.mean()\n            assert len(y1) == 40, \"only for openmic, works this\"\n            y_mask1 = (torch.as_tensor(y1[20:]) > 0.5).float()\n            y_mask2 = (torch.as_tensor(y2[20:]) > 0.5).float()\n            y1[:20] *= y_mask1\n            y2[:20] *= y_mask2\n            yres = (y1 * l + y2 * (1. - l))\n            yres[20:] = torch.stack([y_mask1, y_mask2]).max(dim=0).values\n            return x, f1, yres\n        return x1, f1, y1\n\n    def __len__(self):\n        return len(self.dataset)\n\n\n\nclass AudioSetDataset(TorchDataset):\n    def __init__(self, hdf5_file, sample_rate=32000, classes_num=527, clip_length=10, augment=False, in_mem=False):\n        \"\"\"\n        Reads the mp3 bytes from HDF file decodes using av and returns a fixed length audio wav\n        \"\"\"\n        self.sample_rate = sample_rate\n        self.hdf5_file = hdf5_file\n        if in_mem:\n            print(\"\\nPreloading in memory\\n\")\n            with open(hdf5_file, 'rb') as f:\n                self.hdf5_file = io.BytesIO(f.read())\n        with h5py.File(hdf5_file, 'r') as f:\n            self.length = len(f['audio_name'])\n            print(f\"Dataset from {hdf5_file} with length {self.length}.\")\n        self.dataset_file = None  # lazy init\n        self.clip_length = clip_length * sample_rate\n        self.classes_num = classes_num\n        self.augment = augment\n        if augment:\n            print(f\"Will agument data from {hdf5_file}\")\n\n    def open_hdf5(self):\n        self.dataset_file = h5py.File(self.hdf5_file, 'r')\n\n    def __len__(self):\n        return self.length\n\n    def __del__(self):\n        if self.dataset_file is not None:\n            self.dataset_file.close()\n            self.dataset_file = None\n\n    def __getitem__(self, index):\n     
   \"\"\"Load waveform and target of an audio clip.\n\n        Args:\n          meta: {\n            'hdf5_path': str,\n            'index_in_hdf5': int}\n        Returns:\n          data_dict: {\n            'audio_name': str,\n            'waveform': (clip_samples,),\n            'target': (classes_num,)}\n        \"\"\"\n        if self.dataset_file is None:\n            self.open_hdf5()\n\n        audio_name = self.dataset_file['audio_name'][index].decode()\n        waveform = decode_mp3(self.dataset_file['mp3'][index])\n        if self.augment:\n            waveform = pydub_augment(waveform)\n        waveform = pad_or_truncate(waveform, self.clip_length)\n        waveform = self.resample(waveform)\n        target = self.dataset_file['target'][index]\n        target = target.astype(np.float32)\n        return waveform.reshape(1, -1), audio_name, target\n\n    def resample(self, waveform):\n        \"\"\"Resample.\n        Args:\n          waveform: (clip_samples,)\n        Returns:\n          (resampled_clip_samples,)\n        \"\"\"\n        if self.sample_rate == 32000:\n            return waveform\n        elif self.sample_rate == 16000:\n            return waveform[0:: 2]\n        elif self.sample_rate == 8000:\n            return waveform[0:: 4]\n        else:\n            raise Exception('Incorrect sample rate!')\n\n\n@dataset.command\ndef get_base_training_set(openmic_train_hdf5):\n    ds = AudioSetDataset(openmic_train_hdf5, augment=True)\n    return ds\n\n\n@dataset.command\ndef get_ft_weighted_sampler(samples_weights=CMD(\".get_ft_cls_balanced_sample_weights\"),\n                            epoch_len=100000, sampler_replace=False):\n    num_nodes = int(os.environ.get('num_nodes', 1))\n    ddp = int(os.environ.get('DDP', 1))\n    num_nodes = max(ddp, num_nodes)\n    print(\"num_nodes= \", num_nodes)\n    rank = int(os.environ.get('NODE_RANK', 0))\n    return DistributedSamplerWrapper(sampler=WeightedRandomSampler(samples_weights,\n                      
                                             num_samples=epoch_len, replacement=sampler_replace),\n                                     dataset=range(epoch_len),\n                                     num_replicas=num_nodes,\n                                     rank=rank,\n                                     )\n\n\n@dataset.command\ndef get_base_test_set(openmic_test_hdf5):\n    ds = AudioSetDataset(openmic_test_hdf5)\n    return ds\n\n\n@dataset.command(prefix='roll_conf')\ndef get_roll_func(axis=1, shift=None, shift_range=50):\n    print(\"rolling...\")\n\n    def roll_func(b):\n        x, i, y = b\n        x = torch.as_tensor(x)\n        sf = shift\n        if shift is None:\n            sf = int(np.random.random_integers(-shift_range, shift_range))\n        global FirstTime\n\n        return x.roll(sf, axis), i, y\n\n    return roll_func\n\n\n@dataset.command\ndef get_training_set(normalize, roll, wavmix=False):\n    ds = get_base_training_set()\n    get_ir_sample()\n    if normalize:\n        print(\"normalized train!\")\n        fill_norms()\n        ds = PreprocessDataset(ds, norm_func)\n    if roll:\n        ds = PreprocessDataset(ds, get_roll_func())\n    if wavmix:\n        ds = MixupDataset(ds)\n\n    return ds\n\n\n@dataset.command\ndef get_test_set(normalize):\n    ds = get_base_test_set()\n    if normalize:\n        print(\"normalized test!\")\n        fill_norms()\n        ds = PreprocessDataset(ds, norm_func)\n    return ds\n\n\n@dataset.command\ndef print_conf(_config):\n    print(\"Config of \", dataset.path, id(dataset))\n    print(_config)\n    print()\n\n\nclass DistributedSamplerWrapper(DistributedSampler):\n    def __init__(\n            self, sampler, dataset,\n            num_replicas=None,\n            rank=None,\n            shuffle: bool = True):\n        super(DistributedSamplerWrapper, self).__init__(\n            dataset, num_replicas, rank, shuffle)\n        # source: @awaelchli 
https://github.com/PyTorchLightning/pytorch-lightning/issues/3238\n        self.sampler = sampler\n\n    def __iter__(self):\n        if self.sampler.generator is None:\n            self.sampler.generator = torch.Generator()\n        self.sampler.generator.manual_seed(self.seed + self.epoch)\n        indices = list(self.sampler)\n        if self.epoch == 0:\n            print(f\"\\n DistributedSamplerWrapper :  {indices[:10]} \\n\\n\")\n        indices = indices[self.rank:self.total_size:self.num_replicas]\n        return iter(indices)\n\n\nif __name__ == \"__main__\":\n    from sacred import Experiment\n\n    ex = Experiment(\"test_dataset\", ingredients=[dataset])\n\n\n    @ex.automain\n    def default_command():\n        ex.current_run.get_command_function(\"print_config\")()\n        get_base_training_set()\n        ds = get_test_set()\n        print(ds[0])\n        ds = get_training_set()\n        print(ds[0])\n        print(\"get_base_training_set\", len(get_base_training_set()))\n        print(\"get_base_test_set\", len(get_base_test_set()))\n        print(\"get_training_set\", len(get_training_set()))\n        print(\"get_test_set\", len(get_test_set()))\n"
  },
  {
    "path": "openmic/prepare_scripts/download_preprocess.py",
    "content": "import os\nimport tarfile\nimport multiprocessing\nimport glob\nimport h5py\nimport numpy as np\n\nfrom torch.hub import download_url_to_file\n\n# global constants\nopenmicurl = \"https://zenodo.org/record/1432913/files/openmic-2018-v1.0.0.tgz?download=1\"\ndownload_target = \"openmic-2018-v1.0.0.tgz\"\nextract_target = download_target.replace(\".tgz\", \"\")\ndataset_path = os.path.join(extract_target, \"openmic-2018/\")\nmp3_path = os.path.join(dataset_path, \"mp3/\")\nhdf5s_dir = \"../../audioset_hdf5s/\"\ntrain_files_csv = os.path.join(dataset_path, \"partitions/split01_train.csv\")\ntest_files_csv = os.path.join(dataset_path, \"partitions/split01_test.csv\")\n\n\ndef download(force=False):\n    if force or not os.path.isfile(download_target):\n        print(\"Downloading OpenMIC from zenodo...\")\n        download_url_to_file(openmicurl, download_target)\n    else:\n        print(f\"{download_target} already exists. Skipping download!\")\n\n\ndef untar():\n    my_tar = tarfile.open(download_target)\n    print(f\"Extracting openmic from {download_target} to {extract_target}\")\n\n    my_tar.extractall(extract_target)\n\n\ndef process_folder(fol=\"balanced_train_segments\"):\n    print(\"now working on \", fol)\n    os.makedirs(mp3_path + fol, exist_ok=True)\n    all_files = list(glob.glob(os.path.join(dataset_path, \"audio/\") + \"/*/*.ogg\"))  # openmic format\n    print(f\"it has {len(all_files)}\")\n    print(all_files[:5])\n    global all_num\n    all_num = len(all_files)\n    cmds = [(i, file, mp3_path + fol + \"/\" + os.path.basename(file)[:-3]) for i, file in enumerate(all_files)]\n    print(cmds[0])\n    with multiprocessing.Pool(processes=20) as pool:\n        pool.starmap(process_one, cmds)\n\n\ndef process_one(i, f1, f2):\n    if i % 100 == 0:\n        print(f\"{i}/{all_num} \\t\", f1)\n    os.system(f\"ffmpeg  -hide_banner -nostats -loglevel error -n -i {f1} -codec:a mp3 -ar 32000 {f2}mp3\")\n\n\ndef make_mp3():\n    
process_folder(\"audio\")\n\n\ndef read_metadata(csv_path, classes_num, id_to_ix, openmicf):\n    \"\"\"Read metadata of AudioSet from a csv file.\n    source: https://github.com/qiuqiangkong/audioset_tagging_cnn/blob/d2f4b8c18eab44737fcc0de1248ae21eb43f6aa4/utils/utilities.py#L59\n    Args:\n      csv_path: str\n    Returns:\n      meta_dict: {'audio_name': (audios_num,), 'target': (audios_num, classes_num)}\n    \"\"\"\n\n    with open(csv_path) as file:\n        lines = file.readlines()\n        lines = [line.rstrip() for line in lines]\n    audios_num = len(lines)\n    # class + mask\n    targets = np.zeros((audios_num, classes_num * 2), dtype=np.float32)\n    audio_names = []\n    notfound = set()\n    for i, row in enumerate(lines):\n        audio_name = '{}.mp3'.format(row)  # Audios are started with an extra 'Y' when downloading\n        audio_names.append(audio_name)\n        t = openmicf.Y_true[id_to_ix[row]] #id_to_ix[row][\"true\"]\n        m = openmicf.Y_mask[id_to_ix[row]].astype(int)#id_to_ix[row][\"mask\"]\n        # Target\n        targets[i, :classes_num] = t\n        targets[i, classes_num:] = m\n\n    print(notfound)\n    print(\"original_targets\", len(targets))\n    mask = targets.astype(np.int).sum(1) > 0\n    print(len(mask), mask.sum())\n    print(\"after: \", len(targets[mask]))\n    meta_dict = {'audio_name': np.array(audio_names)[mask], 'target': targets[mask]}\n    return meta_dict\n\n\ndef get_files_labels(balanced_csv, balanced_audio_path, d_files, openmicf, prefix=None, zip_contents=None, classes_num=20):\n    meta_csv = read_metadata(balanced_csv, classes_num, d_files, openmicf)\n    # fname,labels,mids\n    audios_num = len(meta_csv['audio_name'])\n    found = 0\n    notfound = 0\n    available_files = []\n    available_targets = []\n    for n in range(audios_num):\n        audio_path = meta_csv['audio_name'][n]\n        # print(balanced_audio_path + f\"{prefix}/{audio_path}\")\n        if n == 0:\n            print(\"checking: \", 
balanced_audio_path + f\"{prefix}/{audio_path}\")\n        if os.path.isfile(balanced_audio_path + f\"{prefix}/{audio_path}\"):\n            found += 1\n            available_files.append(meta_csv['audio_name'][n])\n            available_targets.append(meta_csv['target'][n])\n        else:\n            notfound += 1\n    print(f\"Found {found} . not found {notfound}\")\n    return available_files, available_targets\n\n\ndef pack():\n    d_files = dict()\n    opmic = np.load(os.path.join(dataset_path, \"openmic-2018.npz\"))\n    opmic.allow_pickle = True\n\n    for i, sid in enumerate(opmic.f.sample_key):\n\n        d_files[sid] = i #{\"mask\": opmic.f.Y_mask[i].astype(int),\n                        #\"true\": opmic.f.Y_true[i]}\n    print(\"len=\",len(d_files))\n\n    for read_file, prefix in [(train_files_csv, \"audio/\"), (test_files_csv, \"audio/\")]:\n        print(\"now working on \", read_file, prefix)\n        # files, y = torch.load(read_file+\".pth\")\n        files, y = get_files_labels(read_file, mp3_path, d_files=d_files, openmicf=opmic.f, prefix=prefix)\n        y = np.array(y)\n        # y = np.packbits(y, axis=-1)\n        packed_len = y.shape[1]\n        print(files[0], \"classes: \", packed_len, y.dtype)\n        available_size = len(files)\n        f = files[0]\n        a = np.fromfile(mp3_path + prefix + \"/\" + f, dtype='uint8')\n\n        dt = h5py.vlen_dtype(np.dtype('uint8'))\n        save_file = read_file.rsplit(\"/\", 1)[1].replace(\"split01\", \"openmic\")\n        os.makedirs(hdf5s_dir + \"mp3/\" ,exist_ok=True)\n        if os.path.isfile(hdf5s_dir + \"mp3/\" + save_file + \"_mp3.hdf\"):\n            print(hdf5s_dir + \"mp3/\" + save_file + \"_mp3.hdf\", \"exists!\\n\\n\\n contiue\")\n            continue\n        with h5py.File(hdf5s_dir + \"mp3/\" + save_file + \"_mp3.hdf\", 'w') as hf:\n            audio_name = hf.create_dataset('audio_name', shape=((available_size,)), dtype='S20')\n            waveform = hf.create_dataset('mp3', 
shape=((available_size,)), dtype=dt)\n            target = hf.create_dataset('target', shape=((available_size, packed_len)), dtype=y.dtype)\n            for i, file in enumerate(files):\n                if i % 1000 == 0:\n                    print(f\"{i}/{available_size}\")\n                f = file\n                a = np.fromfile(mp3_path + prefix + f, dtype='uint8')\n                audio_name[i] = f\n                waveform[i] = a\n                target[i] = y[i]\n        print(\"Saved h5py file into \", hdf5s_dir + \"mp3/\" + save_file + \"_mp3.hdf\")\n        print(a.shape)\n        print(\"Done!\", prefix)\n\n\ndef preprocess():\n    download()\n    untar()\n    make_mp3()\n    pack()\n\n\npreprocess()"
  },
  {
    "path": "pip_list.txt",
    "content": "Package                 Version\n----------------------- ------------\nabsl-py                 1.4.0\naiohttp                 3.8.4\naiosignal               1.3.1\nappdirs                 1.4.4\nasync-timeout           4.0.2\nattrs                   22.2.0\naudioread               3.0.0\nautopep8                1.6.0\nav                      10.0.0\nbrotlipy                0.7.0\ncachetools              5.3.0\ncertifi                 2022.12.7\ncffi                    1.15.1\ncharset-normalizer      2.0.4\nclick                   8.1.3\ncolorama                0.4.6\ncryptography            39.0.1\ndecorator               5.1.1\ndocker-pycreds          0.4.0\ndocopt                  0.6.2\nflit_core               3.8.0\nfrozenlist              1.3.3\nfsspec                  2023.3.0\nfuture                  0.18.3\ngitdb                   4.0.10\nGitPython               3.1.31\ngoogle-auth             2.17.1\ngoogle-auth-oauthlib    0.4.6\ngrpcio                  1.53.0\nh5py                    3.8.0\nidna                    3.4\nimportlib-metadata      6.1.0\njoblib                  1.2.0\njsonpickle              3.0.1\nkk-sacred               0.8.4\nlazy_loader             0.2\nlibrosa                 0.10.0.post2\nlightning-utilities     0.8.0\nllvmlite                0.39.1\nMarkdown                3.4.3\nMarkupSafe              2.1.2\nmkl-fft                 1.3.1\nmkl-random              1.2.2\nmkl-service             2.4.0\nmsgpack                 1.0.5\nmultidict               6.0.4\nmunch                   2.5.0\nnumba                   0.56.4\nnumpy                   1.23.5\noauthlib                3.2.2\npackaging               23.0\npandas                  1.5.3\npathtools               0.1.2\nPillow                  9.4.0\npip                     23.0.1\npooch                   1.6.0\nprotobuf                4.22.1\npsutil                  5.9.4\npy-cpuinfo              9.0.0\npyasn1                  0.4.8\npyasn1-modules          
0.2.8\npycodestyle             2.10.0\npycparser               2.21\npyDeprecate             0.3.0\npyOpenSSL               23.0.0\nPySocks                 1.7.1\npython-dateutil         2.8.2\npytorch-lightning       1.3.1\npytz                    2023.3\nPyYAML                  5.4.1\nrequests                2.28.1\nrequests-oauthlib       1.3.1\nrsa                     4.9\nscikit-learn            1.2.2\nscipy                   1.10.1\nsentry-sdk              1.19.1\nsetproctitle            1.3.2\nsetuptools              65.6.3\nsix                     1.16.0\nsmmap                   5.0.0\nsoundfile               0.12.1\nsoxr                    0.3.4\ntensorboard             2.12.0\ntensorboard-data-server 0.7.0\ntensorboard-plugin-wit  1.8.1\nthreadpoolctl           3.1.0\ntimm                    0.4.12\ntoml                    0.10.2\ntorch                   1.11.0\ntorchaudio              0.11.0\ntorchmetrics            0.2.0\ntorchvision             0.12.0\ntqdm                    4.65.0\ntyping_extensions       4.4.0\nurllib3                 1.26.15\nwandb                   0.14.2\nWerkzeug                2.2.3\nwheel                   0.38.4\nwrapt                   1.15.0\nyarl                    1.8.2\nzipp                    3.15.0\n"
  },
  {
    "path": "requirements.txt",
    "content": "av>=10.0.0\nh5py>=3.8.0\njsonpickle>=3.0.1\nkk-sacred>=0.8.4\nlibrosa>=0.10.0.post2\ntimm>=0.4.12\ntorchmetrics>=0.2.0\npytorch-lightning<2.0.0\nwandb>=0.14.2\n"
  }
]