[
  {
    "path": ".flake8",
    "content": "[flake8]\nmax-line-length = 88\nextend-ignore = E203"
  },
  {
    "path": ".github/CODEOWNERS",
    "content": "* @deezer @Faylixe @romi1502 @mmoussallam @alreadytaikeune"
  },
  {
    "path": ".github/CONTRIBUTING.md",
    "content": "# How to contribute\n\nThese are the main guidelines for contributing to this project:\n\n- Verify that your contribution does not include proprietary code or infringe any copyright.\n- Avoid adding unnecessary dependencies to the project, especially if they are not easily packaged and installed through `conda` or `pip`.\n- Python contributions must follow the [PEP 8 style guide](https://www.python.org/dev/peps/pep-0008/).\n- Use the [Pull Request](https://help.github.com/en/github/collaborating-with-issues-and-pull-requests/about-pull-requests) mechanism and please be patient while waiting for reviews.\n- Remain polite and civil in all exchanges with the maintainers and other contributors.\n- Any submitted issue that does not follow the provided template, or that lacks information, will be considered invalid and automatically closed.\n\n## Get started\n\nThis project is managed using [Poetry](https://python-poetry.org/docs/basic-usage/).\nTo contribute, the safest approach is to create your\n[own fork of spleeter](https://help.github.com/en/github/getting-started-with-github/fork-a-repo) first and then set up your development environment:\n\n```bash\n# Clone spleeter repository fork\ngit clone https://github.com/<your_name>/spleeter && cd spleeter\n# Install poetry\npip install poetry\n# Install spleeter dependencies\npoetry install\n# Run unit test suite\npoetry run pytest tests/\n```\n\nYou can then make your changes and experiment freely. Once you're done, remember to check that the tests still pass. If you've added a new feature, add tests!\n\nFinally, you're more than welcome to create a [Pull Request](https://help.github.com/en/github/collaborating-with-issues-and-pull-requests/creating-a-pull-request-from-a-fork) in the main **Spleeter** repo. 
We will look at it as soon as possible and eventually integrate your changes into the project.\n\n## PR requirements\n\nThe following commands should run successfully before a PR is considered for merging:\n\n```bash\npoetry run pytest tests/\npoetry run black spleeter\npoetry run isort spleeter\n```\n"
  },
  {
    "path": ".github/ISSUE_TEMPLATE/bug.md",
    "content": "---\nname: Bug\nabout: Report a bug\ntitle: \"[Bug] name your bug\"\nlabels: bug, invalid\n---\n\n- [ ] I didn't find a similar issue already open.\n- [ ] I read the documentation (README AND Wiki)\n- [ ] I have installed FFmpeg\n- [ ] My problem is related to Spleeter only, not a derivative product (such as a web application or GUI provided by others)\n\n## Description\n\n<!-- Give us a clear and concise description of the bug you are reporting. -->\n\n## Steps to reproduce\n\n<!-- Clearly indicate the steps to reproduce the behavior: -->\n\n1. Installed using `...`\n2. Run as `...`\n3. Got `...` error\n\n## Output\n\n```bash\nShare what your terminal says when you run the script (as well as what you would expect).\n```\n\n## Environment\n\n<!-- Fill in the following table -->\n\n|                   |                                 |\n| ----------------- | ------------------------------- |\n| OS                | Windows / Linux / MacOS / other |\n| Installation type | Conda / pip / other             |\n| RAM available     | X GB                            |\n| Hardware spec     | GPU / CPU / etc ...             |\n\n## Additional context\n\n<!-- Add any other context about the problem here, references, citations, etc. -->\n"
  },
  {
    "path": ".github/ISSUE_TEMPLATE/config.yml",
    "content": "blank_issues_enabled: false\n"
  },
  {
    "path": ".github/ISSUE_TEMPLATE/discussion.md",
    "content": "---\nname: Discussion\nabout: Idea sharing or theoretical question solving\nlabels: question\ntitle: \"[Discussion] your question\"\n---\n\n<!-- Please respect the title [Discussion] tag. -->\n"
  },
  {
    "path": ".github/ISSUE_TEMPLATE/feature.md",
    "content": "---\nname: Feature request\nabout: Submit idea for new feature\nlabels: feature, enhancement\ntitle: \"[Feature] your feature name\"\n---\n\n## Description\n\n<!-- Describe your feature request here. -->\n\n## Additional information\n\n<!-- Add any additional description -->\n"
  },
  {
    "path": ".github/PULL_REQUEST_TEMPLATE.md",
    "content": "# Pull request title\n\n- [ ] I read the [contributing guidelines](https://github.com/deezer/spleeter/blob/master/.github/CONTRIBUTING.md)\n- [ ] I didn't find a similar pull request already open.\n- [ ] My PR is related to Spleeter only, not a derivative product (such as a web application or GUI provided by others)\n\n## Description\n\nA few sentences describing the overall goals of the pull request's commits.\n\n## How this patch was tested\n\nYou tested it, right?\n\n- [ ] I implemented unit tests which ran successfully using `poetry run pytest tests/`\n- [ ] Code has been formatted using `poetry run black spleeter`\n- [ ] Imports have been formatted using `poetry run isort spleeter`\n\n## Documentation link and external references\n\nPlease provide any info that may help us better understand your code.\n"
  },
  {
    "path": ".github/workflows/conda.yml",
    "content": "name: conda\non:\n  - workflow_dispatch\njobs:\n  build-linux:\n    strategy:\n      matrix:\n        python: [3.7, 3.8]\n    runs-on: ubuntu-latest\n    steps:\n      - uses: actions/checkout@v2\n      - name: Set up Python ${{ matrix.python }}\n        uses: actions/setup-python@v2\n        with:\n          python-version: ${{ matrix.python }}\n      - name: Install dependencies\n        run: |\n          $CONDA/bin/conda install conda-build\n          $CONDA/bin/conda install anaconda-client\n      - name: Build package\n        run: |\n          $CONDA/bin/conda config --add channels anaconda\n          $CONDA/bin/conda config --add channels conda-forge\n          $CONDA/bin/conda build --python ${{ matrix.python }} conda/spleeter\n      - name: Push package\n        run: |\n          $CONDA/bin/anaconda login --username ${{ secrets.ANACONDA_USERNAME }} --password ${{ secrets.ANACONDA_PASSWORD }}\n          for package in /usr/share/miniconda/conda-bld/linux-64/spleeter*.bz2; do\n            $CONDA/bin/anaconda upload $package\n          done\n  build-windows:\n    strategy:\n      matrix:\n        python: [3.7]\n    runs-on: windows-latest\n    steps:\n      - uses: actions/checkout@v2\n      - name: Set up Python ${{ matrix.python }}\n        uses: actions/setup-python@v2\n        with:\n          python-version: ${{ matrix.python }}\n      - name: Install dependencies\n        run: |\n          C:\\Miniconda\\condabin\\conda.bat init powershell\n          C:\\Miniconda\\condabin\\conda.bat install conda-build\n          C:\\Miniconda\\condabin\\conda.bat install anaconda-client\n      - name: Build package\n        run: |\n          C:\\Miniconda\\condabin\\conda.bat config --add channels anaconda\n          C:\\Miniconda\\condabin\\conda.bat config --add channels conda-forge\n          C:\\Miniconda\\condabin\\conda.bat build --python ${{ matrix.python }} conda\\spleeter\n      - name: Push package\n        run: |\n          anaconda login 
--username ${{ secrets.ANACONDA_USERNAME }} --password ${{ secrets.ANACONDA_PASSWORD }}\n          $packages = Get-ChildItem \"C:\\Miniconda\\conda-bld\\win-64\\\"\n          foreach ($package in $packages){\n            anaconda upload $package.FullName\n          }\n"
  },
  {
    "path": ".github/workflows/docker.yml",
    "content": "name: docker\non:\n  workflow_dispatch:\n    inputs:\n      version:\n        description: \"Spleeter version to build image for\"\n        required: true\n        default: \"2.1.2\"\njobs:\n  cuda-base:\n    runs-on: ubuntu-latest\n    strategy:\n      matrix:\n        distribution: [3.6, 3.7, 3.8]\n      fail-fast: true\n    steps:\n      - uses: actions/checkout@v2\n      - name: Build CUDA base image\n        run: |\n          docker build \\\n            --build-arg BASE=python:${{ matrix.distribution }} \\\n            -t deezer/python-cuda-10-1:${{ matrix.distribution }} \\\n            -f docker/cuda-10-1.dockerfile .\n      - name: Docker login\n        run: echo ${{ secrets.DOCKERHUB_PASSWORD }} | docker login -u ${{ secrets.DOCKERHUB_USERNAME }} --password-stdin\n      - name: Push deezer/python-cuda-10-1:${{ matrix.distribution }} image\n        run: docker push deezer/python-cuda-10-1:${{ matrix.distribution }}\n  pip-images:\n    needs: cuda-base\n    runs-on: ubuntu-latest\n    strategy:\n      matrix:\n        platform: [cpu, gpu]\n        distribution: [3.6, 3.7, 3.8]\n      fail-fast: true\n    steps:\n      - uses: actions/checkout@v2\n      - if: ${{ matrix.platform == 'cpu' }}\n        run: |\n          echo \"base=python:${{ matrix.distribution }}\" >> $GITHUB_ENV\n          echo \"image=spleeter\" >> $GITHUB_ENV\n      - if: ${{ matrix.platform == 'gpu' }}\n        run: |\n          echo \"base=deezer/python-cuda-10-1:${{ matrix.distribution }}\" >> $GITHUB_ENV\n          echo \"image=spleeter-gpu\" >> $GITHUB_ENV\n      - name: Build deezer/${{ env.image }}:${{ matrix.distribution }} image\n        run: |\n          docker build \\\n            --build-arg BASE=${{ env.base }} \\\n            --build-arg SPLEETER_VERSION=${{ github.event.inputs.version }} \\\n            -t deezer/${{ env.image }}:${{ matrix.distribution }} \\\n            -f docker/spleeter.dockerfile .\n      - if: ${{ matrix.distribution == '3.8' }}\n      
  run: |\n          docker tag deezer/${{ env.image }}:${{ matrix.distribution }} deezer/${{ env.image }}:latest\n      - name: Test deezer/${{ env.image }}:${{ matrix.distribution }} image\n        run: |\n          docker run \\\n            -v $(pwd):/runtime \\\n            deezer/${{ env.image }}:${{ matrix.distribution }} \\\n            separate -o /tmp /runtime/audio_example.mp3\n      - name: Docker login\n        run: echo ${{ secrets.DOCKERHUB_PASSWORD }} | docker login -u ${{ secrets.DOCKERHUB_USERNAME }} --password-stdin\n      - name: Push deezer/${{ env.image }}:${{ matrix.distribution }} image\n        run: docker push deezer/${{ env.image }}:${{ matrix.distribution }}\n  conda-images:\n    needs: cuda-base\n    runs-on: ubuntu-latest\n    strategy:\n      matrix:\n        platform: [cpu, gpu]\n      fail-fast: true\n    steps:\n      - uses: actions/checkout@v2\n      - if: ${{ matrix.platform == 'cpu' }}\n        name: Build Conda base image\n        run: |\n          docker build -t conda:cpu -f docker/conda.dockerfile .\n          echo \"image=spleeter\" >> $GITHUB_ENV\n      - if: ${{ matrix.platform == 'gpu' }}\n        name: Build Conda base image\n        run: |\n          docker build --build-arg BASE=deezer/python-cuda-10-1:3.8 -t conda:gpu -f docker/conda.dockerfile .\n          echo \"image=spleeter-gpu\" >> $GITHUB_ENV\n      - name: Build deezer/${{ env.image }}:${{ env.tag }} image\n        run: |\n          docker build \\\n            --build-arg BASE=conda:${{ matrix.platform }} \\\n            --build-arg SPLEETER_VERSION=${{ github.event.inputs.version }} \\\n            -t deezer/${{ env.image }}:conda \\\n            -f docker/spleeter-conda.dockerfile .\n      - name: Docker login\n        run: echo ${{ secrets.DOCKERHUB_PASSWORD }} | docker login -u ${{ secrets.DOCKERHUB_USERNAME }} --password-stdin\n      - name: Push deezer/${{ env.image }}:conda image\n        run: docker push deezer/${{ env.image }}:conda\n  
images-with-model:\n    needs: [pip-images, conda-images]\n    runs-on: ubuntu-latest\n    strategy:\n      matrix:\n        platform: [cpu, gpu]\n        distribution: [3.6, 3.7, 3.8]\n        model: [2stems, 4stems, 5stems]\n      fail-fast: true\n    steps:\n      - uses: actions/checkout@v2\n      - if: ${{ matrix.platform == 'cpu' }}\n        run: echo \"image=spleeter\" >> $GITHUB_ENV\n      - if: ${{ matrix.platform == 'gpu' }}\n        run: echo \"image=spleeter-gpu\" >> $GITHUB_ENV\n      - name: Build deezer/${{ env.image }}:${{ matrix.distribution }}-${{ matrix.model }} image\n        run: |\n          docker build \\\n            --build-arg BASE=deezer/${{ env.image }}:${{ matrix.distribution }} \\\n            --build-arg MODEL=${{ matrix.model }} \\\n            -t deezer/${{ env.image }}:${{ matrix.distribution }}-${{ matrix.model }} \\\n            -f docker/spleeter-model.dockerfile .\n      - name: Test deezer/${{ env.image }}:${{ matrix.distribution }}-${{ matrix.model }} image\n        run: |\n          docker run \\\n            -v $(pwd):/runtime \\\n            deezer/${{ env.image }}:${{ matrix.distribution }}-${{ matrix.model }} \\\n            separate -o /tmp -p spleeter:${{ matrix.model }} /runtime/audio_example.mp3\n      - name: Docker login\n        run: echo ${{ secrets.DOCKERHUB_PASSWORD }} | docker login -u ${{ secrets.DOCKERHUB_USERNAME }} --password-stdin\n      - name: Push deezer/${{ env.image }}:${{ matrix.distribution }}-${{ matrix.model }} image\n        run: docker push deezer/${{ env.image }}:${{ matrix.distribution }}-${{ matrix.model }}\n"
  },
  {
    "path": ".github/workflows/pypi.yml",
    "content": "name: pypi\non:\n  - workflow_dispatch\nenv:\n  PYPI_TOKEN: ${{ secrets.PYPI_TOKEN }}\njobs:\n  package-and-deploy:\n    runs-on: ubuntu-latest\n    steps:\n      - uses: actions/checkout@v2\n      - uses: actions/setup-python@v2\n        with:\n          python-version: 3.9\n      - name: Install Poetry\n        run: |\n          pip install poetry\n          poetry config virtualenvs.in-project false\n          poetry config virtualenvs.path ~/.virtualenvs\n          poetry config pypi-token.pypi $PYPI_TOKEN\n      - name: Deploy to pypi\n        run: |\n          poetry build\n          poetry publish\n"
  },
  {
    "path": ".github/workflows/test.yml",
    "content": "name: test\non:\n  pull_request:\n    branches:\n      - master\njobs:\n  tests:\n    runs-on: ubuntu-latest\n    strategy:\n      matrix:\n        python-version: ['3.8', '3.9', '3.10', '3.11']\n    steps:\n      - uses: actions/checkout@v2\n      - name: Set up Python ${{ matrix.python-version }}\n        uses: actions/setup-python@v2\n        with:\n          python-version: ${{ matrix.python-version }}\n      - uses: actions/cache@v2\n        env:\n          model-release: 1\n        id: spleeter-model-cache\n        with:\n          path: ${{ github.workspace }}/pretrained_models\n          key: models-${{ env.model-release }}\n          restore-keys: |\n            models-${{ env.model-release }}\n      - name: Install ffmpeg\n        run: |\n          sudo apt-get update && sudo apt-get install -y ffmpeg\n      - name: Install Poetry\n        run: |\n          pip install poetry\n          poetry config virtualenvs.in-project false\n          poetry config virtualenvs.path ~/.virtualenvs\n      - name: Cache Poetry virtualenv\n        uses: actions/cache@v1\n        id: cache\n        with:\n          path: ~/.virtualenvs\n          key: poetry-${{ matrix.python-version }}-${{ hashFiles('**/poetry.lock') }}\n          restore-keys: |\n            poetry-${{ matrix.python-version }}-${{ hashFiles('**/poetry.lock') }}\n      - name: Install Dependencies\n        run: poetry install\n        if: steps.cache.outputs.cache-hit != 'true'\n      - name: Code quality checks\n        run: |\n          poetry run black spleeter --check\n          poetry run isort spleeter --check\n      - name: Test with pytest\n        run: poetry run pytest tests/\n"
  },
  {
    "path": ".gitignore",
    "content": "# Byte-compiled / optimized / DLL files\n__pycache__/\n*.py[cod]\n*$py.class\n\n# C extensions\n*.so\n\n# Distribution / packaging\n.Python\nbuild/\ndevelop-eggs/\ndist/\ndownloads/\neggs/\n.eggs/\nlib/\nlib64/\nparts/\nsdist/\nvar/\nwheels/\n*.egg-info/\n.installed.cfg\n*.egg\nMANIFEST\n\n# PyInstaller\n#  Usually these files are written by a python script from a template\n#  before PyInstaller builds the exe, so as to inject date/other infos into it.\n*.manifest\n*.spec\n\n# Installer logs\npip-log.txt\npip-delete-this-directory.txt\n\n# Unit test / coverage reports\nhtmlcov/\n.tox/\n.coverage\n.coverage.*\n.cache\nnosetests.xml\ncoverage.xml\n*.cover\n.hypothesis/\n.pytest_cache/\n\n# Translations\n*.mo\n*.pot\n\n# Django stuff:\n*.log\nlocal_settings.py\ndb.sqlite3\n\n# Flask stuff:\ninstance/\n.webassets-cache\n\n# Scrapy stuff:\n.scrapy\n\n# Sphinx documentation\ndocs/_build/\n\n# PyBuilder\ntarget/\n\n# Jupyter Notebook\n.ipynb_checkpoints\n\n# pyenv\n.python-version\n\n# celery beat schedule file\ncelerybeat-schedule\n\n# SageMath parsed files\n*.sage.py\n\n# Environments\n.env\n.venv\nenv/\nvenv/\nENV/\nenv.bak/\nvenv.bak/\n\n# Spyder project settings\n.spyderproject\n.spyproject\n\n# Rope project settings\n.ropeproject\n\n# mkdocs documentation\n/site\n\n# mypy\n.mypy_cache/\n.vscode\n.DS_Store\n__pycache__\n**/reporting\n\npretrained_models\ndocs/build\n.vscode\nspleeter-feedstock/\n*FAKE_MUSDB_DIR\n\noutput/\nuseless_config.json\n"
  },
  {
    "path": "CHANGELOG.md",
    "content": "# Changelog History\n\n## 2.4.2\n\nDependency upgrades and adding support for python 3.11 (dropping 3.7)\n\n## 2.3.2\n\nRelease constraint on specific Tensorflow, numpy and Librosa versions\nDropping explicit support of python 3.6 but adding 3.10\n\n## 2.3.0\n\nUpdating dependencies to enable TensorFlow 2.5 support (and Python 3.9 overall)\nRemoving the destructor from the `Separator` class\n\n## 2.2.0\n\nMinor changes mainly fixing some issues:\n* mono training was not working due to hardcoded filters in the dataset\n* default argument of `separate` was of wrong type\n* added a way to request spleeter version with the `--version` argument in the CLI\n\n## 2.1.0\n\nThis version introduces design related changes, especially transition to Typer for CLI management and Poetry as\nlibrary build backend.\n\n* `-i` option is now deprecated and replaced by traditional CLI input argument listing\n* Project is now built using Poetry\n* Project requires code formatting using Black and iSort\n* Dedicated GPU package `spleeter-gpu` is not supported anymore, `spleeter` package will support both CPU and GPU hardware\n\n### API changes:\n\n* function `get_default_audio_adapter` is now available as `default()` class method within `AudioAdapter` class\n* function `get_default_model_provider` is now available as `default()` class method within `ModelProvider` class\n* `STFTBackend` and `Codec` are now string enums\n* `GithubModelProvider` now uses `httpx` with HTTP/2 support\n* Commands are now located in the `__main__` module, wrapped as simple functions using Typer; the `options` module provides the specification for each available option and argument\n* `types` module provides custom type specifications and must be enhanced in a future release to provide more robust typing support with MyPy\n* `utils.logging` module has been cleaned, logger instance is now a module singleton, and a single function is used to configure it with verbose parameter\n* Added a custom logger handler (see 
tiangolo/typer#203 discussion)\n\n\n## 2.0\n\nFirst release, October 9th 2020\n\nTensorflow-2 compatible version, allowing use with Python 3.8.\n\n## 1.5.4\n\nFirst release, July 24th 2020\n\nAdd some padding of the input waveform to avoid separation artefacts on the edges due to instabilities in the inverse Fourier transforms.\nAlso add tests to ensure both librosa and tensorflow backends have same outputs.\n\n## 1.5.2\n\nFirst released, May 15th 2020\n\n### Major changes\n\n* PR #375 merged to avoid multiple tf.graph instantiation failures\n\n### Minor changes\n\n* PR #362 use tf.abs instead of numpy\n* PR #352 tempdir cleaning\n\n\n## 1.5.1\n\nFirst released, April 15th 2020\n\n### Major changes\n\n* Bugfixes on the LibRosa STFT backend\n\n### Minor changes\n\n* Typos, and small bugfixes\n\n## 1.5.0\n\nFirst released, March 20th 2020\n\n### Major changes\n\n* Implement a new STFT backend using LibRosa, faster on CPU than TF implementation\n* Switch tensorflow version to 1.15.2\n\n### Minor changes\n\n* Typos, and small bugfixes\n\n## 1.4.9\n\nFirst released, Dec 27th 2019\n\n### Major changes\n\n* Add new configuration for processing until 16 kHz\n\n### Minor changes\n\n* Typos, and small bugfixes\n"
  },
  {
    "path": "LICENSE",
    "content": "MIT License\n\nCopyright (c) 2019-present, Deezer SA.\n\nPermission is hereby granted, free of charge, to any person obtaining a copy\nof this software and associated documentation files (the \"Software\"), to deal\nin the Software without restriction, including without limitation the rights\nto use, copy, modify, merge, publish, distribute, sublicense, and/or sell\ncopies of the Software, and to permit persons to whom the Software is\nfurnished to do so, subject to the following conditions:\n\nThe above copyright notice and this permission notice shall be included in all\ncopies or substantial portions of the Software.\n\nTHE SOFTWARE IS PROVIDED \"AS IS\", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR\nIMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,\nFITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE\nAUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER\nLIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,\nOUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE\nSOFTWARE."
  },
  {
    "path": "README.md",
    "content": "<img src=\"https://github.com/deezer/spleeter/raw/master/images/spleeter_logo.png\" height=\"80\" />\n\n[![Github actions](https://github.com/deezer/spleeter/workflows/pytest/badge.svg)](https://github.com/deezer/spleeter/actions) ![PyPI - Python Version](https://img.shields.io/pypi/pyversions/spleeter) [![PyPI version](https://badge.fury.io/py/spleeter.svg)](https://badge.fury.io/py/spleeter) [![Conda](https://img.shields.io/conda/vn/deezer-research/spleeter)](https://anaconda.org/deezer-research/spleeter) [![Docker Pulls](https://img.shields.io/docker/pulls/deezer/spleeter)](https://hub.docker.com/r/deezer/spleeter) [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/deezer/spleeter/blob/master/spleeter.ipynb) [![Gitter chat](https://badges.gitter.im/gitterHQ/gitter.png)](https://gitter.im/spleeter/community) [![status](https://joss.theoj.org/papers/259e5efe669945a343bad6eccb89018b/status.svg)](https://joss.theoj.org/papers/259e5efe669945a343bad6eccb89018b)\n\n> :warning: The [Spleeter 2.1.0](https://pypi.org/project/spleeter/) release introduces some breaking changes, including new CLI option naming for input and the removal\n> of the dedicated GPU package. Please read [CHANGELOG](CHANGELOG.md) for more details.\n\n## About\n\n**Spleeter** is the [Deezer](https://www.deezer.com/) source separation library with pretrained models,\nwritten in [Python](https://www.python.org/) and using [TensorFlow](https://tensorflow.org/). 
It makes it easy\nto train source separation models (assuming you have a dataset of isolated sources), and provides\nalready trained state-of-the-art models for performing various flavours of separation:\n\n* Vocals (singing voice) / accompaniment separation ([2 stems](https://github.com/deezer/spleeter/wiki/2.-Getting-started#using-2stems-model))\n* Vocals / drums / bass / other separation ([4 stems](https://github.com/deezer/spleeter/wiki/2.-Getting-started#using-4stems-model))\n* Vocals / drums / bass / piano / other separation ([5 stems](https://github.com/deezer/spleeter/wiki/2.-Getting-started#using-5stems-model))\n\nThe 2 stems and 4 stems models have [high performance](https://github.com/deezer/spleeter/wiki/Separation-Performances) on the [musdb](https://sigsep.github.io/datasets/musdb.html) dataset. **Spleeter** is also very fast, as it can separate audio files into 4 stems 100x faster than real-time when run on a GPU.\n\nWe designed **Spleeter** so you can use it straight from the [command line](https://github.com/deezer/spleeter/wiki/2.-Getting-started#usage)\nas well as directly in your own development pipeline as a [Python library](https://github.com/deezer/spleeter/wiki/4.-API-Reference#separator). It can be installed with [pip](https://github.com/deezer/spleeter/wiki/1.-Installation#using-pip) or used with\n[Docker](https://github.com/deezer/spleeter/wiki/2.-Getting-started#using-docker-image).\n\n### Projects and Software using **Spleeter**\n\nSince its release, multiple forks have exposed **Spleeter** through either a Graphical User Interface (GUI) or a standalone free or paid website. Please note that we do not host, maintain or directly support any of these initiatives.\n\nThat being said, many cool projects have been built on top of ours. 
One notable example is the port to the *Ableton Live* ecosystem through the [Spleeter 4 Max](https://github.com/diracdeltas/spleeter4max#spleeter-for-max) project.\n\n**Spleeter** pre-trained models have also been used by professional audio software. Here's a non-exhaustive list:\n\n* [iZotope](https://www.izotope.com/en/shop/rx-8-standard.html) in its *Music Rebalance* feature within **RX 8**\n* [SpectralLayers](https://new.steinberg.net/spectralayers/) in its *Unmix* feature in **SpectralLayers 7**\n* [Acon Digital](https://acondigital.com/products/acoustica-audio-editor/) within **Acoustica 7**\n* [VirtualDJ](https://www.virtualdj.com/stems/) in their stem isolation feature\n* [Algoriddim](https://www.algoriddim.com/apps) in their **NeuralMix** and **djayPRO** app suite\n\n🆕 **Spleeter** is a baseline in the ongoing [Music Demixing Challenge](https://www.aicrowd.com/challenges/music-demixing-challenge-ismir-2021)!\n\n## Spleeter Pro (Commercial version)\n\nCheck out our commercial version: [Spleeter Pro](https://www.deezer-techservices.com/solutions/spleeter/). Benefit from our expertise for precise audio separation, faster processing speeds, and dedicated professional support.\n\n## Quick start\n\nWant to try it out but don't want to install anything? We have set up a [Google Colab](https://colab.research.google.com/github/deezer/spleeter/blob/master/spleeter.ipynb).\n\nReady to dig into it? In a few lines you can install **Spleeter** and separate the vocal and accompaniment parts from an example audio file.\nYou first need to install `ffmpeg` and `libsndfile`. 
This can be done on most platforms using [Conda](https://github.com/deezer/spleeter/wiki/1.-Installation#using-conda):\n\n```bash\n# install dependencies using conda\nconda install -c conda-forge ffmpeg libsndfile\n# install spleeter with pip\npip install spleeter\n# download an example audio file (if you don't have wget, use another tool for downloading)\nwget https://github.com/deezer/spleeter/raw/master/audio_example.mp3\n# separate the example audio into two components\nspleeter separate -p spleeter:2stems -o output audio_example.mp3\n```\n\n> :warning: Note that we no longer recommend using `conda` for installing spleeter.\n\n> ⚠️ There are known issues with Apple M1 chips, mostly due to TensorFlow compatibility. Until these are fixed, you can use [this workaround](https://github.com/deezer/spleeter/issues/607#issuecomment-1021669444).\n\nYou should get two separated audio files (`vocals.wav` and `accompaniment.wav`) in the `output/audio_example` folder.\n\nFor detailed documentation, please check the [repository wiki](https://github.com/deezer/spleeter/wiki/1.-Installation).\n\n## Development and Testing\n\nThis project is managed using [Poetry](https://python-poetry.org/docs/basic-usage/). To run the test suite, you\ncan execute the following set of commands:\n\n```bash\n# Clone spleeter repository\ngit clone https://github.com/Deezer/spleeter && cd spleeter\n# Install poetry\npip install poetry\n# Install spleeter dependencies\npoetry install\n# Run unit test suite\npoetry run pytest tests/\n```\n\n## Reference\n\n* Deezer Research - Source Separation Engine Story - deezer.io blog post:\n  * [English version](https://deezer.io/releasing-spleeter-deezer-r-d-source-separation-engine-2b88985e797e)\n  * [Japanese version](http://dzr.fm/splitterjp)\n* [Music Source Separation tool with pre-trained models / ISMIR2019 extended abstract](http://archives.ismir.net/ismir2019/latebreaking/000036.pdf)\n\nIf you use **Spleeter** in your work, please 
cite:\n\n```BibTeX\n@article{spleeter2020,\n  doi = {10.21105/joss.02154},\n  url = {https://doi.org/10.21105/joss.02154},\n  year = {2020},\n  publisher = {The Open Journal},\n  volume = {5},\n  number = {50},\n  pages = {2154},\n  author = {Romain Hennequin and Anis Khlif and Felix Voituret and Manuel Moussallam},\n  title = {Spleeter: a fast and efficient music source separation tool with pre-trained models},\n  journal = {Journal of Open Source Software},\n  note = {Deezer Research}\n}\n```\n\n## License\n\nThe code of **Spleeter** is [MIT-licensed](LICENSE).\n\n## Disclaimer\n\nIf you plan to use **Spleeter** on copyrighted material, make sure you get proper authorization from the rights owners beforehand.\n\n## Troubleshooting\n\n**Spleeter** is a complex piece of software, and although we continuously try to improve and test it, you may encounter unexpected issues running it. If that's the case, please check the [FAQ page](https://github.com/deezer/spleeter/wiki/5.-FAQ) first as well as the list of [currently open issues](https://github.com/deezer/spleeter/issues).\n\n### Windows users\n\n   It appears that sometimes the shortcut command `spleeter` does not work properly on Windows. This is a known issue that we will hopefully fix soon. In the meantime, replace `spleeter separate` with `python -m spleeter separate` on the command line and it should work.\n\n## Contributing\n\nIf you would like to participate in the development of **Spleeter** you are more than welcome to do so. Don't hesitate to throw us a pull request and we'll do our best to examine it quickly. Please check out our [guidelines](.github/CONTRIBUTING.md) first.\n\n## Note\n\nThis repository includes a demo audio file `audio_example.mp3` which is an excerpt\nfrom Slow Motion Dream by Steven M Bryant (c) copyright 2011 Licensed under a Creative\nCommons Attribution (3.0) [license](http://dig.ccmixter.org/files/stevieb357/34740)\nFt: CSoul,Alex Beroza & Robert Siekawitch\n"
  },
  {
    "path": "conda/spleeter/meta.yaml",
    "content": "{% set name = \"spleeter\" %}\n{% set version = \"2.4.0\" %}\n\npackage:\n  name: {{ name|lower }}\n  version: {{ version }}\n\nsource:\n  - url: https://pypi.io/packages/source/{{ name[0] }}/{{ name }}/{{ name }}-{{ version }}.tar.gz\n    sha256: 6cbe9e572474948515430804a22da255f774243aab77e58edb147566dbff7a42\n\nbuild:\n  number: 0\n  script: {{ PYTHON }} -m pip install . -vv\n  skip: True  # [osx]\n  entry_points:\n    - spleeter = spleeter.__main__:entrypoint\n\nrequirements:\n  host:\n    - python {{ python }}\n    - pip\n    - poetry\n  run:\n    - python {{ python }}\n    - tensorflow ==2.5.0  # [linux]\n    - tensorflow ==2.5.0  # [win]\n    - numpy\n    - pandas\n    - ffmpeg-python\n    - norbert\n    - typer\n    - httpx\n\ntest:\n  imports:\n    - spleeter\n    - spleeter.model\n    - spleeter.utils\n    - spleeter.separator\n\nabout:\n  home: https://github.com/deezer/spleeter\n  license: MIT\n  license_family: MIT\n  license_file: LICENSE\n  summary: The Deezer source separation library with pretrained models based on tensorflow.\n  doc_url: https://github.com/deezer/spleeter/wiki\n  dev_url: https://github.com/deezer/spleeter\n\nextra:\n  recipe-maintainers:\n    - Faylixe\n    - romi1502\n"
  },
  {
    "path": "configs/2stems/base_config.json",
    "content": "{\n    \"train_csv\": \"path/to/train.csv\",\n    \"validation_csv\": \"path/to/test.csv\",\n    \"model_dir\": \"2stems\",\n    \"mix_name\": \"mix\",\n    \"instrument_list\": [\"vocals\", \"accompaniment\"],\n    \"sample_rate\":44100,\n    \"frame_length\":4096,\n    \"frame_step\":1024,\n    \"T\":512,\n    \"F\":1024,\n    \"n_channels\":2,\n    \"separation_exponent\":2,\n    \"mask_extension\":\"zeros\",\n    \"learning_rate\": 1e-4,\n    \"batch_size\":4,\n    \"training_cache\":\"training_cache\",\n    \"validation_cache\":\"validation_cache\",\n    \"train_max_steps\": 1000000,\n    \"throttle_secs\":300,\n    \"random_seed\":0,\n    \"save_checkpoints_steps\":150,\n    \"save_summary_steps\":5,\n    \"model\":{\n            \"type\":\"unet.unet\",\n            \"params\":{}\n            }\n}\n"
  },
  {
    "path": "configs/4stems/base_config.json",
    "content": "{\n    \"train_csv\": \"path/to/train.csv\",\n    \"validation_csv\": \"path/to/test.csv\",\n    \"model_dir\": \"4stems\",\n    \"mix_name\": \"mix\",\n    \"instrument_list\": [\"vocals\", \"drums\", \"bass\", \"other\"],\n    \"sample_rate\":44100,\n    \"frame_length\":4096,\n    \"frame_step\":1024,\n    \"T\":512,\n    \"F\":1024,\n    \"n_channels\":2,\n    \"separation_exponent\":2,\n    \"mask_extension\":\"zeros\",\n    \"learning_rate\": 1e-4,\n    \"batch_size\":4,\n    \"training_cache\":\"training_cache\",\n    \"validation_cache\":\"validation_cache\",\n    \"train_max_steps\": 1500000,\n    \"throttle_secs\":600,\n    \"random_seed\":3,\n    \"save_checkpoints_steps\":300,\n    \"save_summary_steps\":5,\n    \"model\":{\n        \"type\":\"unet.unet\",\n        \"params\":{\n               \"conv_activation\":\"ELU\",\n               \"deconv_activation\":\"ELU\"\n        }\n    }\n}\n"
  },
  {
    "path": "configs/5stems/base_config.json",
    "content": "{\n    \"train_csv\": \"path/to/train.csv\",\n    \"validation_csv\": \"path/to/test.csv\",\n    \"model_dir\": \"5stems\",\n    \"mix_name\": \"mix\",\n    \"instrument_list\": [\"vocals\", \"piano\", \"drums\", \"bass\", \"other\"],\n    \"sample_rate\":44100,\n    \"frame_length\":4096,\n    \"frame_step\":1024,\n    \"T\":512,\n    \"F\":1024,\n    \"n_channels\":2,\n    \"separation_exponent\":2,\n    \"mask_extension\":\"zeros\",\n    \"learning_rate\": 1e-4,\n    \"batch_size\":4,\n    \"training_cache\":\"training_cache\",\n    \"validation_cache\":\"validation_cache\",\n    \"train_max_steps\": 2500000,\n    \"throttle_secs\":600,\n    \"random_seed\":8,\n    \"save_checkpoints_steps\":300,\n    \"save_summary_steps\":5,\n    \"model\":{\n        \"type\":\"unet.softmax_unet\",\n        \"params\":{\n               \"conv_activation\":\"ELU\",\n               \"deconv_activation\":\"ELU\"\n        }\n    }\n}\n"
  },
  {
    "path": "configs/musdb_config.json",
    "content": "{\n    \"train_csv\": \"configs/musdb_train.csv\",\n    \"validation_csv\": \"configs/musdb_validation.csv\",\n    \"model_dir\": \"musdb_model\",\n    \"mix_name\": \"mix\",\n    \"instrument_list\": [\"vocals\", \"drums\", \"bass\", \"other\"],\n    \"sample_rate\":44100,\n    \"frame_length\":4096,\n    \"frame_step\":1024,\n    \"T\":512,\n    \"F\":1024,\n    \"n_channels\":2,\n    \"n_chunks_per_song\":40,\n    \"separation_exponent\":2,\n    \"mask_extension\":\"zeros\",\n    \"learning_rate\": 1e-4,\n    \"batch_size\":4,\n    \"training_cache\":\"cache/training\",\n    \"validation_cache\":\"cache/validation\",\n    \"train_max_steps\": 200000,\n    \"throttle_secs\":1800,\n    \"random_seed\":3,\n    \"save_checkpoints_steps\":1000,\n    \"save_summary_steps\":5,\n    \"model\":{\n        \"type\":\"unet.unet\",\n        \"params\":{\n               \"conv_activation\":\"ELU\",\n               \"deconv_activation\":\"ELU\"\n        }\n    }\n}\n"
  },
  {
    "path": "configs/musdb_train.csv",
    "content": "mix_path,vocals_path,drums_path,bass_path,other_path,duration\ntrain/A Classic Education - NightOwl/mixture.wav,train/A Classic Education - NightOwl/vocals.wav,train/A Classic Education - NightOwl/drums.wav,train/A Classic Education - NightOwl/bass.wav,train/A Classic Education - NightOwl/other.wav,171.247166\ntrain/ANiMAL - Clinic A/mixture.wav,train/ANiMAL - Clinic A/vocals.wav,train/ANiMAL - Clinic A/drums.wav,train/ANiMAL - Clinic A/bass.wav,train/ANiMAL - Clinic A/other.wav,237.865215\ntrain/ANiMAL - Easy Tiger/mixture.wav,train/ANiMAL - Easy Tiger/vocals.wav,train/ANiMAL - Easy Tiger/drums.wav,train/ANiMAL - Easy Tiger/bass.wav,train/ANiMAL - Easy Tiger/other.wav,205.473379\ntrain/Actions - Devil's Words/mixture.wav,train/Actions - Devil's Words/vocals.wav,train/Actions - Devil's Words/drums.wav,train/Actions - Devil's Words/bass.wav,train/Actions - Devil's Words/other.wav,196.626576\ntrain/Actions - South Of The Water/mixture.wav,train/Actions - South Of The Water/vocals.wav,train/Actions - South Of The Water/drums.wav,train/Actions - South Of The Water/bass.wav,train/Actions - South Of The Water/other.wav,176.610975\ntrain/Aimee Norwich - Child/mixture.wav,train/Aimee Norwich - Child/vocals.wav,train/Aimee Norwich - Child/drums.wav,train/Aimee Norwich - Child/bass.wav,train/Aimee Norwich - Child/other.wav,189.080091\ntrain/Alexander Ross - Velvet Curtain/mixture.wav,train/Alexander Ross - Velvet Curtain/vocals.wav,train/Alexander Ross - Velvet Curtain/drums.wav,train/Alexander Ross - Velvet Curtain/bass.wav,train/Alexander Ross - Velvet Curtain/other.wav,514.298776\ntrain/Angela Thomas Wade - Milk Cow Blues/mixture.wav,train/Angela Thomas Wade - Milk Cow Blues/vocals.wav,train/Angela Thomas Wade - Milk Cow Blues/drums.wav,train/Angela Thomas Wade - Milk Cow Blues/bass.wav,train/Angela Thomas Wade - Milk Cow Blues/other.wav,210.906848\ntrain/Atlantis Bound - It Was My Fault For Waiting/mixture.wav,train/Atlantis Bound - It Was My Fault For 
Waiting/vocals.wav,train/Atlantis Bound - It Was My Fault For Waiting/drums.wav,train/Atlantis Bound - It Was My Fault For Waiting/bass.wav,train/Atlantis Bound - It Was My Fault For Waiting/other.wav,268.051156\ntrain/Auctioneer - Our Future Faces/mixture.wav,train/Auctioneer - Our Future Faces/vocals.wav,train/Auctioneer - Our Future Faces/drums.wav,train/Auctioneer - Our Future Faces/bass.wav,train/Auctioneer - Our Future Faces/other.wav,207.702494\ntrain/AvaLuna - Waterduct/mixture.wav,train/AvaLuna - Waterduct/vocals.wav,train/AvaLuna - Waterduct/drums.wav,train/AvaLuna - Waterduct/bass.wav,train/AvaLuna - Waterduct/other.wav,259.111474\ntrain/BigTroubles - Phantom/mixture.wav,train/BigTroubles - Phantom/vocals.wav,train/BigTroubles - Phantom/drums.wav,train/BigTroubles - Phantom/bass.wav,train/BigTroubles - Phantom/other.wav,146.750113\ntrain/Bill Chudziak - Children Of No-one/mixture.wav,train/Bill Chudziak - Children Of No-one/vocals.wav,train/Bill Chudziak - Children Of No-one/drums.wav,train/Bill Chudziak - Children Of No-one/bass.wav,train/Bill Chudziak - Children Of No-one/other.wav,230.736689\ntrain/Black Bloc - If You Want Success/mixture.wav,train/Black Bloc - If You Want Success/vocals.wav,train/Black Bloc - If You Want Success/drums.wav,train/Black Bloc - If You Want Success/bass.wav,train/Black Bloc - If You Want Success/other.wav,398.547302\ntrain/Celestial Shore - Die For Us/mixture.wav,train/Celestial Shore - Die For Us/vocals.wav,train/Celestial Shore - Die For Us/drums.wav,train/Celestial Shore - Die For Us/bass.wav,train/Celestial Shore - Die For Us/other.wav,278.476916\ntrain/Chris Durban - Celebrate/mixture.wav,train/Chris Durban - Celebrate/vocals.wav,train/Chris Durban - Celebrate/drums.wav,train/Chris Durban - Celebrate/bass.wav,train/Chris Durban - Celebrate/other.wav,301.603991\ntrain/Clara Berry And Wooldog - Air Traffic/mixture.wav,train/Clara Berry And Wooldog - Air Traffic/vocals.wav,train/Clara Berry And Wooldog - Air 
Traffic/drums.wav,train/Clara Berry And Wooldog - Air Traffic/bass.wav,train/Clara Berry And Wooldog - Air Traffic/other.wav,173.267302\ntrain/Clara Berry And Wooldog - Stella/mixture.wav,train/Clara Berry And Wooldog - Stella/vocals.wav,train/Clara Berry And Wooldog - Stella/drums.wav,train/Clara Berry And Wooldog - Stella/bass.wav,train/Clara Berry And Wooldog - Stella/other.wav,195.558458\ntrain/Cnoc An Tursa - Bannockburn/mixture.wav,train/Cnoc An Tursa - Bannockburn/vocals.wav,train/Cnoc An Tursa - Bannockburn/drums.wav,train/Cnoc An Tursa - Bannockburn/bass.wav,train/Cnoc An Tursa - Bannockburn/other.wav,294.521905\ntrain/Creepoid - OldTree/mixture.wav,train/Creepoid - OldTree/vocals.wav,train/Creepoid - OldTree/drums.wav,train/Creepoid - OldTree/bass.wav,train/Creepoid - OldTree/other.wav,302.02195\ntrain/Dark Ride - Burning Bridges/mixture.wav,train/Dark Ride - Burning Bridges/vocals.wav,train/Dark Ride - Burning Bridges/drums.wav,train/Dark Ride - Burning Bridges/bass.wav,train/Dark Ride - Burning Bridges/other.wav,232.663946\ntrain/Dreamers Of The Ghetto - Heavy Love/mixture.wav,train/Dreamers Of The Ghetto - Heavy Love/vocals.wav,train/Dreamers Of The Ghetto - Heavy Love/drums.wav,train/Dreamers Of The Ghetto - Heavy Love/bass.wav,train/Dreamers Of The Ghetto - Heavy Love/other.wav,294.800544\ntrain/Drumtracks - Ghost Bitch/mixture.wav,train/Drumtracks - Ghost Bitch/vocals.wav,train/Drumtracks - Ghost Bitch/drums.wav,train/Drumtracks - Ghost Bitch/bass.wav,train/Drumtracks - Ghost Bitch/other.wav,356.913923\ntrain/Faces On Film - Waiting For Ga/mixture.wav,train/Faces On Film - Waiting For Ga/vocals.wav,train/Faces On Film - Waiting For Ga/drums.wav,train/Faces On Film - Waiting For Ga/bass.wav,train/Faces On Film - Waiting For Ga/other.wav,257.439637\ntrain/Fergessen - Back From The Start/mixture.wav,train/Fergessen - Back From The Start/vocals.wav,train/Fergessen - Back From The Start/drums.wav,train/Fergessen - Back From The 
Start/bass.wav,train/Fergessen - Back From The Start/other.wav,168.553651\ntrain/Fergessen - The Wind/mixture.wav,train/Fergessen - The Wind/vocals.wav,train/Fergessen - The Wind/drums.wav,train/Fergessen - The Wind/bass.wav,train/Fergessen - The Wind/other.wav,191.820045\ntrain/Flags - 54/mixture.wav,train/Flags - 54/vocals.wav,train/Flags - 54/drums.wav,train/Flags - 54/bass.wav,train/Flags - 54/other.wav,315.164444\ntrain/Giselle - Moss/mixture.wav,train/Giselle - Moss/vocals.wav,train/Giselle - Moss/drums.wav,train/Giselle - Moss/bass.wav,train/Giselle - Moss/other.wav,201.711746\ntrain/Grants - PunchDrunk/mixture.wav,train/Grants - PunchDrunk/vocals.wav,train/Grants - PunchDrunk/drums.wav,train/Grants - PunchDrunk/bass.wav,train/Grants - PunchDrunk/other.wav,204.405261\ntrain/Helado Negro - Mitad Del Mundo/mixture.wav,train/Helado Negro - Mitad Del Mundo/vocals.wav,train/Helado Negro - Mitad Del Mundo/drums.wav,train/Helado Negro - Mitad Del Mundo/bass.wav,train/Helado Negro - Mitad Del Mundo/other.wav,181.672925\ntrain/Hezekiah Jones - Borrowed Heart/mixture.wav,train/Hezekiah Jones - Borrowed Heart/vocals.wav,train/Hezekiah Jones - Borrowed Heart/drums.wav,train/Hezekiah Jones - Borrowed Heart/bass.wav,train/Hezekiah Jones - Borrowed Heart/other.wav,241.394649\ntrain/Hollow Ground - Left Blind/mixture.wav,train/Hollow Ground - Left Blind/vocals.wav,train/Hollow Ground - Left Blind/drums.wav,train/Hollow Ground - Left Blind/bass.wav,train/Hollow Ground - Left Blind/other.wav,159.103129\ntrain/Hop Along - Sister Cities/mixture.wav,train/Hop Along - Sister Cities/vocals.wav,train/Hop Along - Sister Cities/drums.wav,train/Hop Along - Sister Cities/bass.wav,train/Hop Along - Sister Cities/other.wav,283.237007\ntrain/Invisible Familiars - Disturbing Wildlife/mixture.wav,train/Invisible Familiars - Disturbing Wildlife/vocals.wav,train/Invisible Familiars - Disturbing Wildlife/drums.wav,train/Invisible Familiars - Disturbing Wildlife/bass.wav,train/Invisible 
Familiars - Disturbing Wildlife/other.wav,218.499773\ntrain/James May - All Souls Moon/mixture.wav,train/James May - All Souls Moon/vocals.wav,train/James May - All Souls Moon/drums.wav,train/James May - All Souls Moon/bass.wav,train/James May - All Souls Moon/other.wav,220.844989\ntrain/James May - Dont Let Go/mixture.wav,train/James May - Dont Let Go/vocals.wav,train/James May - Dont Let Go/drums.wav,train/James May - Dont Let Go/bass.wav,train/James May - Dont Let Go/other.wav,241.951927\ntrain/James May - If You Say/mixture.wav,train/James May - If You Say/vocals.wav,train/James May - If You Say/drums.wav,train/James May - If You Say/bass.wav,train/James May - If You Say/other.wav,258.321995\ntrain/Jay Menon - Through My Eyes/mixture.wav,train/Jay Menon - Through My Eyes/vocals.wav,train/Jay Menon - Through My Eyes/drums.wav,train/Jay Menon - Through My Eyes/bass.wav,train/Jay Menon - Through My Eyes/other.wav,253.167166\ntrain/Johnny Lokke - Whisper To A Scream/mixture.wav,train/Johnny Lokke - Whisper To A Scream/vocals.wav,train/Johnny Lokke - Whisper To A Scream/drums.wav,train/Johnny Lokke - Whisper To A Scream/bass.wav,train/Johnny Lokke - Whisper To A Scream/other.wav,255.326621\n\"train/Jokers, Jacks & Kings - Sea Of Leaves/mixture.wav\",\"train/Jokers, Jacks & Kings - Sea Of Leaves/vocals.wav\",\"train/Jokers, Jacks & Kings - Sea Of Leaves/drums.wav\",\"train/Jokers, Jacks & Kings - Sea Of Leaves/bass.wav\",\"train/Jokers, Jacks & Kings - Sea Of Leaves/other.wav\",191.471746\ntrain/Leaf - Come Around/mixture.wav,train/Leaf - Come Around/vocals.wav,train/Leaf - Come Around/drums.wav,train/Leaf - Come Around/bass.wav,train/Leaf - Come Around/other.wav,264.382404\ntrain/Leaf - Wicked/mixture.wav,train/Leaf - Wicked/vocals.wav,train/Leaf - Wicked/drums.wav,train/Leaf - Wicked/bass.wav,train/Leaf - Wicked/other.wav,190.635828\ntrain/Lushlife - Toynbee Suite/mixture.wav,train/Lushlife - Toynbee Suite/vocals.wav,train/Lushlife - Toynbee 
Suite/drums.wav,train/Lushlife - Toynbee Suite/bass.wav,train/Lushlife - Toynbee Suite/other.wav,628.378413\ntrain/Matthew Entwistle - Dont You Ever/mixture.wav,train/Matthew Entwistle - Dont You Ever/vocals.wav,train/Matthew Entwistle - Dont You Ever/drums.wav,train/Matthew Entwistle - Dont You Ever/bass.wav,train/Matthew Entwistle - Dont You Ever/other.wav,113.824218\ntrain/Meaxic - You Listen/mixture.wav,train/Meaxic - You Listen/vocals.wav,train/Meaxic - You Listen/drums.wav,train/Meaxic - You Listen/bass.wav,train/Meaxic - You Listen/other.wav,412.525714\ntrain/Music Delta - 80s Rock/mixture.wav,train/Music Delta - 80s Rock/vocals.wav,train/Music Delta - 80s Rock/drums.wav,train/Music Delta - 80s Rock/bass.wav,train/Music Delta - 80s Rock/other.wav,36.733968\ntrain/Music Delta - Beatles/mixture.wav,train/Music Delta - Beatles/vocals.wav,train/Music Delta - Beatles/drums.wav,train/Music Delta - Beatles/bass.wav,train/Music Delta - Beatles/other.wav,36.176689\ntrain/Music Delta - Britpop/mixture.wav,train/Music Delta - Britpop/vocals.wav,train/Music Delta - Britpop/drums.wav,train/Music Delta - Britpop/bass.wav,train/Music Delta - Britpop/other.wav,36.594649\ntrain/Music Delta - Country1/mixture.wav,train/Music Delta - Country1/vocals.wav,train/Music Delta - Country1/drums.wav,train/Music Delta - Country1/bass.wav,train/Music Delta - Country1/other.wav,34.551293\ntrain/Music Delta - Country2/mixture.wav,train/Music Delta - Country2/vocals.wav,train/Music Delta - Country2/drums.wav,train/Music Delta - Country2/bass.wav,train/Music Delta - Country2/other.wav,17.275646\ntrain/Music Delta - Disco/mixture.wav,train/Music Delta - Disco/vocals.wav,train/Music Delta - Disco/drums.wav,train/Music Delta - Disco/bass.wav,train/Music Delta - Disco/other.wav,124.598277\ntrain/Music Delta - Gospel/mixture.wav,train/Music Delta - Gospel/vocals.wav,train/Music Delta - Gospel/drums.wav,train/Music Delta - Gospel/bass.wav,train/Music Delta - 
Gospel/other.wav,75.557732\ntrain/Music Delta - Grunge/mixture.wav,train/Music Delta - Grunge/vocals.wav,train/Music Delta - Grunge/drums.wav,train/Music Delta - Grunge/bass.wav,train/Music Delta - Grunge/other.wav,41.656599\ntrain/Music Delta - Hendrix/mixture.wav,train/Music Delta - Hendrix/vocals.wav,train/Music Delta - Hendrix/drums.wav,train/Music Delta - Hendrix/bass.wav,train/Music Delta - Hendrix/other.wav,19.644082\ntrain/Music Delta - Punk/mixture.wav,train/Music Delta - Punk/vocals.wav,train/Music Delta - Punk/drums.wav,train/Music Delta - Punk/bass.wav,train/Music Delta - Punk/other.wav,28.583764\ntrain/Music Delta - Reggae/mixture.wav,train/Music Delta - Reggae/vocals.wav,train/Music Delta - Reggae/drums.wav,train/Music Delta - Reggae/bass.wav,train/Music Delta - Reggae/other.wav,17.275646\ntrain/Music Delta - Rock/mixture.wav,train/Music Delta - Rock/vocals.wav,train/Music Delta - Rock/drums.wav,train/Music Delta - Rock/bass.wav,train/Music Delta - Rock/other.wav,12.910295\ntrain/Music Delta - Rockabilly/mixture.wav,train/Music Delta - Rockabilly/vocals.wav,train/Music Delta - Rockabilly/drums.wav,train/Music Delta - Rockabilly/bass.wav,train/Music Delta - Rockabilly/other.wav,25.75093\ntrain/Night Panther - Fire/mixture.wav,train/Night Panther - Fire/vocals.wav,train/Night Panther - Fire/drums.wav,train/Night Panther - Fire/bass.wav,train/Night Panther - Fire/other.wav,212.810884\ntrain/North To Alaska - All The Same/mixture.wav,train/North To Alaska - All The Same/vocals.wav,train/North To Alaska - All The Same/drums.wav,train/North To Alaska - All The Same/bass.wav,train/North To Alaska - All The Same/other.wav,247.965896\ntrain/Patrick Talbot - Set Me Free/mixture.wav,train/Patrick Talbot - Set Me Free/vocals.wav,train/Patrick Talbot - Set Me Free/drums.wav,train/Patrick Talbot - Set Me Free/bass.wav,train/Patrick Talbot - Set Me Free/other.wav,289.785034\ntrain/Phre The Eon - Everybody's Falling Apart/mixture.wav,train/Phre The Eon - Everybody's 
Falling Apart/vocals.wav,train/Phre The Eon - Everybody's Falling Apart/drums.wav,train/Phre The Eon - Everybody's Falling Apart/bass.wav,train/Phre The Eon - Everybody's Falling Apart/other.wav,224.235102\ntrain/Port St Willow - Stay Even/mixture.wav,train/Port St Willow - Stay Even/vocals.wav,train/Port St Willow - Stay Even/drums.wav,train/Port St Willow - Stay Even/bass.wav,train/Port St Willow - Stay Even/other.wav,316.836281\ntrain/Remember December - C U Next Time/mixture.wav,train/Remember December - C U Next Time/vocals.wav,train/Remember December - C U Next Time/drums.wav,train/Remember December - C U Next Time/bass.wav,train/Remember December - C U Next Time/other.wav,242.532426\ntrain/Secret Mountains - High Horse/mixture.wav,train/Secret Mountains - High Horse/vocals.wav,train/Secret Mountains - High Horse/drums.wav,train/Secret Mountains - High Horse/bass.wav,train/Secret Mountains - High Horse/other.wav,355.311746\ntrain/Skelpolu - Together Alone/mixture.wav,train/Skelpolu - Together Alone/vocals.wav,train/Skelpolu - Together Alone/drums.wav,train/Skelpolu - Together Alone/bass.wav,train/Skelpolu - Together Alone/other.wav,325.822404\ntrain/Snowmine - Curfews/mixture.wav,train/Snowmine - Curfews/vocals.wav,train/Snowmine - Curfews/drums.wav,train/Snowmine - Curfews/bass.wav,train/Snowmine - Curfews/other.wav,275.017143\ntrain/Spike Mullings - Mike's Sulking/mixture.wav,train/Spike Mullings - Mike's Sulking/vocals.wav,train/Spike Mullings - Mike's Sulking/drums.wav,train/Spike Mullings - Mike's Sulking/bass.wav,train/Spike Mullings - Mike's Sulking/other.wav,256.696599\ntrain/St Vitus - Word Gets Around/mixture.wav,train/St Vitus - Word Gets Around/vocals.wav,train/St Vitus - Word Gets Around/drums.wav,train/St Vitus - Word Gets Around/bass.wav,train/St Vitus - Word Gets Around/other.wav,247.013878\ntrain/Steven Clark - Bounty/mixture.wav,train/Steven Clark - Bounty/vocals.wav,train/Steven Clark - Bounty/drums.wav,train/Steven Clark - 
Bounty/bass.wav,train/Steven Clark - Bounty/other.wav,289.274195\ntrain/Strand Of Oaks - Spacestation/mixture.wav,train/Strand Of Oaks - Spacestation/vocals.wav,train/Strand Of Oaks - Spacestation/drums.wav,train/Strand Of Oaks - Spacestation/bass.wav,train/Strand Of Oaks - Spacestation/other.wav,243.670204\ntrain/Sweet Lights - You Let Me Down/mixture.wav,train/Sweet Lights - You Let Me Down/vocals.wav,train/Sweet Lights - You Let Me Down/drums.wav,train/Sweet Lights - You Let Me Down/bass.wav,train/Sweet Lights - You Let Me Down/other.wav,391.790295\ntrain/Swinging Steaks - Lost My Way/mixture.wav,train/Swinging Steaks - Lost My Way/vocals.wav,train/Swinging Steaks - Lost My Way/drums.wav,train/Swinging Steaks - Lost My Way/bass.wav,train/Swinging Steaks - Lost My Way/other.wav,309.963175\ntrain/The Districts - Vermont/mixture.wav,train/The Districts - Vermont/vocals.wav,train/The Districts - Vermont/drums.wav,train/The Districts - Vermont/bass.wav,train/The Districts - Vermont/other.wav,227.973515\ntrain/The Long Wait - Back Home To Blue/mixture.wav,train/The Long Wait - Back Home To Blue/vocals.wav,train/The Long Wait - Back Home To Blue/drums.wav,train/The Long Wait - Back Home To Blue/bass.wav,train/The Long Wait - Back Home To Blue/other.wav,260.458231\ntrain/The Scarlet Brand - Les Fleurs Du Mal/mixture.wav,train/The Scarlet Brand - Les Fleurs Du Mal/vocals.wav,train/The Scarlet Brand - Les Fleurs Du Mal/drums.wav,train/The Scarlet Brand - Les Fleurs Du Mal/bass.wav,train/The Scarlet Brand - Les Fleurs Du Mal/other.wav,303.438367\ntrain/The So So Glos - Emergency/mixture.wav,train/The So So Glos - Emergency/vocals.wav,train/The So So Glos - Emergency/drums.wav,train/The So So Glos - Emergency/bass.wav,train/The So So Glos - Emergency/other.wav,166.812154\ntrain/The Wrong'Uns - Rothko/mixture.wav,train/The Wrong'Uns - Rothko/vocals.wav,train/The Wrong'Uns - Rothko/drums.wav,train/The Wrong'Uns - Rothko/bass.wav,train/The Wrong'Uns - 
Rothko/other.wav,202.152925\ntrain/Tim Taler - Stalker/mixture.wav,train/Tim Taler - Stalker/vocals.wav,train/Tim Taler - Stalker/drums.wav,train/Tim Taler - Stalker/bass.wav,train/Tim Taler - Stalker/other.wav,237.633016\ntrain/Titanium - Haunted Age/mixture.wav,train/Titanium - Haunted Age/vocals.wav,train/Titanium - Haunted Age/drums.wav,train/Titanium - Haunted Age/bass.wav,train/Titanium - Haunted Age/other.wav,248.105215\ntrain/Traffic Experiment - Once More (With Feeling)/mixture.wav,train/Traffic Experiment - Once More (With Feeling)/vocals.wav,train/Traffic Experiment - Once More (With Feeling)/drums.wav,train/Traffic Experiment - Once More (With Feeling)/bass.wav,train/Traffic Experiment - Once More (With Feeling)/other.wav,435.07229\ntrain/Triviul - Dorothy/mixture.wav,train/Triviul - Dorothy/vocals.wav,train/Triviul - Dorothy/drums.wav,train/Triviul - Dorothy/bass.wav,train/Triviul - Dorothy/other.wav,187.361814\ntrain/Voelund - Comfort Lives In Belief/mixture.wav,train/Voelund - Comfort Lives In Belief/vocals.wav,train/Voelund - Comfort Lives In Belief/drums.wav,train/Voelund - Comfort Lives In Belief/bass.wav,train/Voelund - Comfort Lives In Belief/other.wav,209.90839\ntrain/Wall Of Death - Femme/mixture.wav,train/Wall Of Death - Femme/vocals.wav,train/Wall Of Death - Femme/drums.wav,train/Wall Of Death - Femme/bass.wav,train/Wall Of Death - Femme/other.wav,238.933333\ntrain/Young Griffo - Blood To Bone/mixture.wav,train/Young Griffo - Blood To Bone/vocals.wav,train/Young Griffo - Blood To Bone/drums.wav,train/Young Griffo - Blood To Bone/bass.wav,train/Young Griffo - Blood To Bone/other.wav,254.397823\ntrain/Young Griffo - Facade/mixture.wav,train/Young Griffo - Facade/vocals.wav,train/Young Griffo - Facade/drums.wav,train/Young Griffo - Facade/bass.wav,train/Young Griffo - Facade/other.wav,167.857052\n"
  },
  {
    "path": "configs/musdb_validation.csv",
    "content": "mix_path,vocals_path,drums_path,bass_path,other_path,duration\ntrain/ANiMAL - Rockshow/mixture.wav,train/ANiMAL - Rockshow/vocals.wav,train/ANiMAL - Rockshow/drums.wav,train/ANiMAL - Rockshow/bass.wav,train/ANiMAL - Rockshow/other.wav,165.511837\ntrain/Actions - One Minute Smile/mixture.wav,train/Actions - One Minute Smile/vocals.wav,train/Actions - One Minute Smile/drums.wav,train/Actions - One Minute Smile/bass.wav,train/Actions - One Minute Smile/other.wav,163.375601\ntrain/Alexander Ross - Goodbye Bolero/mixture.wav,train/Alexander Ross - Goodbye Bolero/vocals.wav,train/Alexander Ross - Goodbye Bolero/drums.wav,train/Alexander Ross - Goodbye Bolero/bass.wav,train/Alexander Ross - Goodbye Bolero/other.wav,418.632562\ntrain/Clara Berry And Wooldog - Waltz For My Victims/mixture.wav,train/Clara Berry And Wooldog - Waltz For My Victims/vocals.wav,train/Clara Berry And Wooldog - Waltz For My Victims/drums.wav,train/Clara Berry And Wooldog - Waltz For My Victims/bass.wav,train/Clara Berry And Wooldog - Waltz For My Victims/other.wav,175.240998\ntrain/Fergessen - Nos Palpitants/mixture.wav,train/Fergessen - Nos Palpitants/vocals.wav,train/Fergessen - Nos Palpitants/drums.wav,train/Fergessen - Nos Palpitants/bass.wav,train/Fergessen - Nos Palpitants/other.wav,198.228753\ntrain/James May - On The Line/mixture.wav,train/James May - On The Line/vocals.wav,train/James May - On The Line/drums.wav,train/James May - On The Line/bass.wav,train/James May - On The Line/other.wav,256.09288\ntrain/Johnny Lokke - Promises & Lies/mixture.wav,train/Johnny Lokke - Promises & Lies/vocals.wav,train/Johnny Lokke - Promises & Lies/drums.wav,train/Johnny Lokke - Promises & Lies/bass.wav,train/Johnny Lokke - Promises & Lies/other.wav,285.814422\ntrain/Leaf - Summerghost/mixture.wav,train/Leaf - Summerghost/vocals.wav,train/Leaf - Summerghost/drums.wav,train/Leaf - Summerghost/bass.wav,train/Leaf - Summerghost/other.wav,231.804807\ntrain/Meaxic - Take A 
Step/mixture.wav,train/Meaxic - Take A Step/vocals.wav,train/Meaxic - Take A Step/drums.wav,train/Meaxic - Take A Step/bass.wav,train/Meaxic - Take A Step/other.wav,282.517188\ntrain/Patrick Talbot - A Reason To Leave/mixture.wav,train/Patrick Talbot - A Reason To Leave/vocals.wav,train/Patrick Talbot - A Reason To Leave/drums.wav,train/Patrick Talbot - A Reason To Leave/bass.wav,train/Patrick Talbot - A Reason To Leave/other.wav,259.552653\ntrain/Skelpolu - Human Mistakes/mixture.wav,train/Skelpolu - Human Mistakes/vocals.wav,train/Skelpolu - Human Mistakes/drums.wav,train/Skelpolu - Human Mistakes/bass.wav,train/Skelpolu - Human Mistakes/other.wav,324.498866\ntrain/Traffic Experiment - Sirens/mixture.wav,train/Traffic Experiment - Sirens/vocals.wav,train/Traffic Experiment - Sirens/drums.wav,train/Traffic Experiment - Sirens/bass.wav,train/Traffic Experiment - Sirens/other.wav,421.279637\ntrain/Triviul - Angelsaint/mixture.wav,train/Triviul - Angelsaint/vocals.wav,train/Triviul - Angelsaint/drums.wav,train/Triviul - Angelsaint/bass.wav,train/Triviul - Angelsaint/other.wav,236.704218\ntrain/Young Griffo - Pennies/mixture.wav,train/Young Griffo - Pennies/vocals.wav,train/Young Griffo - Pennies/drums.wav,train/Young Griffo - Pennies/bass.wav,train/Young Griffo - Pennies/other.wav,277.803537\n"
  },
  {
    "path": "docker/conda-entrypoint.sh",
    "content": "#!/bin/bash\n\n######################################################################\n# Custom entrypoint that activates conda before running spleeter.\n#\n# @author Félix Voituret <fvoituret@deezer.com>\n# @version 1.0.0\n######################################################################\n\n# shellcheck disable=1091\n. \"/opt/conda/etc/profile.d/conda.sh\"\nconda activate base\nspleeter \"$@\""
  },
  {
    "path": "docker/conda.dockerfile",
    "content": "ARG BASE=python:3.7\nFROM ${BASE}\n\nRUN apt-get update --fix-missing \\\n    && apt-get install -y wget bzip2 ca-certificates curl git \\\n    && apt-get clean \\\n    && rm -rf /var/lib/apt/lists/* \\\n    && wget --quiet https://repo.anaconda.com/miniconda/Miniconda3-4.6.14-Linux-x86_64.sh -O ~/miniconda.sh \\\n    && /bin/bash ~/miniconda.sh -b -p /opt/conda \\\n    && rm ~/miniconda.sh \\\n    && /opt/conda/bin/conda clean -tipsy \\\n    && ln -s /opt/conda/etc/profile.d/conda.sh /etc/profile.d/conda.sh \\\n    && echo \". /opt/conda/etc/profile.d/conda.sh\" >> ~/.bashrc \\\n    && echo \"conda activate base\" >> ~/.bashrc \\\n    && ln -s /opt/conda/bin/conda /usr/bin/conda\nSHELL [\"/bin/bash\", \"-c\"]"
  },
  {
    "path": "docker/cuda-10-0.dockerfile",
    "content": "ARG BASE=python:3.7\nFROM ${BASE}\n\nENV CUDA_VERSION 10.0.130\nENV CUDA_PKG_VERSION 10-0=$CUDA_VERSION-1\nENV PATH /usr/local/nvidia/bin:/usr/local/cuda/bin:${PATH}\nENV LD_LIBRARY_PATH /usr/local/nvidia/lib:/usr/local/nvidia/lib64\n\nENV NVIDIA_VISIBLE_DEVICES=all\nENV NVIDIA_DRIVER_CAPABILITIES compute,utility\nENV NVIDIA_REQUIRE_CUDA \"cuda>=10.0 brand=tesla,driver>=384,driver<385 brand=tesla,driver>=410,driver<411\"\nENV NCCL_VERSION 2.4.2\nENV CUDNN_VERSION 7.6.0.64\n\nLABEL com.nvidia.cuda.version=\"${CUDA_VERSION}\"\nLABEL com.nvidia.cudnn.version=\"${CUDNN_VERSION}\"\nLABEL com.nvidia.volumes.needed=\"nvidia_driver\"\n\nRUN apt-get update \\\n    && apt-get install -y --no-install-recommends \\\n        gnupg2 \\\n        curl \\\n        ca-certificates \\\n    && curl -fsSL https://developer.download.nvidia.com/compute/cuda/repos/ubuntu1804/x86_64/7fa2af80.pub | apt-key add - \\\n    && echo \"deb https://developer.download.nvidia.com/compute/cuda/repos/ubuntu1804/x86_64 /\" > /etc/apt/sources.list.d/cuda.list \\\n    && echo \"deb https://developer.download.nvidia.com/compute/machine-learning/repos/ubuntu1804/x86_64 /\" > /etc/apt/sources.list.d/nvidia-ml.list \\\n    && apt-get purge --autoremove -y curl \\\n    && apt-get update \\\n    && apt-get install -y --no-install-recommends \\\n        cuda-cudart-$CUDA_PKG_VERSION \\\n        cuda-compat-10-0 \\\n    && ln -s cuda-10.0 /usr/local/cuda \\\n    && echo \"/usr/local/nvidia/lib\" >> /etc/ld.so.conf.d/nvidia.conf \\\n    && echo \"/usr/local/nvidia/lib64\" >> /etc/ld.so.conf.d/nvidia.conf \\\n    && apt-get install -y --no-install-recommends \\\n        cuda-toolkit-10-0 \\\n        cuda-libraries-$CUDA_PKG_VERSION \\\n        cuda-nvtx-$CUDA_PKG_VERSION \\\n        libnccl2=$NCCL_VERSION-1+cuda10.0 \\\n        libcudnn7=$CUDNN_VERSION-1+cuda10.0 \\\n    && apt-mark hold libnccl2 \\\n    && apt-mark hold libcudnn7 \\\n    && rm -rf /var/lib/apt/lists/*\n"
  },
  {
    "path": "docker/cuda-10-1.dockerfile",
    "content": "ARG BASE=python:3.8\nFROM ${BASE}\n\nENV CUDA_VERSION 10.1.243\nENV CUDA_PKG_VERSION 10-1=$CUDA_VERSION-1\nENV PATH /usr/local/nvidia/bin:/usr/local/cuda/bin:${PATH}\nENV LD_LIBRARY_PATH /usr/local/nvidia/lib:/usr/local/nvidia/lib64\n\nENV NVIDIA_VISIBLE_DEVICES all\nENV NVIDIA_DRIVER_CAPABILITIES compute,utility\nENV NVIDIA_REQUIRE_CUDA \"cuda>=10.1 brand=tesla,driver>=396,driver<397 brand=tesla,driver>=410,driver<411 brand=tesla,driver>=418,driver<419\"\nENV CUDNN_VERSION 7.6.5.32\nENV NCCL_VERSION 2.7.8\n\nLABEL com.nvidia.cuda.version=\"${CUDA_VERSION}\"\nLABEL com.nvidia.cudnn.version=\"${CUDNN_VERSION}\"\nLABEL com.nvidia.volumes.needed=\"nvidia_driver\"\n\nRUN apt-get update \\\n    && apt-get install -y --no-install-recommends \\\n        gnupg2 \\\n        curl \\\n        ca-certificates \\\n    && curl -fsSL https://developer.download.nvidia.com/compute/cuda/repos/ubuntu1804/x86_64/7fa2af80.pub | apt-key add - \\\n    && echo \"deb https://developer.download.nvidia.com/compute/cuda/repos/ubuntu1804/x86_64 /\" > /etc/apt/sources.list.d/cuda.list \\\n    && echo \"deb https://developer.download.nvidia.com/compute/machine-learning/repos/ubuntu1804/x86_64 /\" > /etc/apt/sources.list.d/nvidia-ml.list \\\n    && apt-get purge --autoremove -y curl \\\n    && apt-get update \\\n    && apt-get install -y --no-install-recommends \\\n        cuda-cudart-$CUDA_PKG_VERSION \\\n        cuda-compat-10-1 \\\n    && ln -s cuda-10.1 /usr/local/cuda \\\n    && echo \"/usr/local/nvidia/lib\" >> /etc/ld.so.conf.d/nvidia.conf \\\n    && echo \"/usr/local/nvidia/lib64\" >> /etc/ld.so.conf.d/nvidia.conf \\\n    && apt-get install -y --no-install-recommends \\\n        cuda-libraries-$CUDA_PKG_VERSION \\\n        cuda-npp-$CUDA_PKG_VERSION \\\n        cuda-nvtx-$CUDA_PKG_VERSION \\\n        libcublas10=10.2.1.243-1 \\\n        libcudnn7=$CUDNN_VERSION-1+cuda10.1 \\\n        libnccl2=$NCCL_VERSION-1+cuda10.1 \\\n    && apt-mark hold libnccl2 \\\n    && apt-mark 
hold libcudnn7 \\\n    && apt-mark hold libcublas10 \\\n    && rm -rf /var/lib/apt/lists/*\n"
  },
  {
    "path": "docker/cuda-9.2.dockerfile",
    "content": "ARG BASE=python:3.7\nFROM ${BASE}\n\n# FROM 9.2-base-ubuntu18.04\n# https://gitlab.com/nvidia/container-images/cuda/blob/ubuntu18.04/9.2/base/Dockerfile\nRUN apt-get update && apt-get install -y --no-install-recommends gnupg2 curl ca-certificates \\\n    && curl -fsSL https://developer.download.nvidia.com/compute/cuda/repos/ubuntu1710/x86_64/7fa2af80.pub | apt-key add - \\\n    && echo \"deb https://developer.download.nvidia.com/compute/cuda/repos/ubuntu1710/x86_64 /\" > /etc/apt/sources.list.d/cuda.list \\\n    && echo \"deb https://developer.download.nvidia.com/compute/machine-learning/repos/ubuntu1604/x86_64 /\" > /etc/apt/sources.list.d/nvidia-ml.list \\\n    && apt-get purge --autoremove -y curl \\\n    && rm -rf /var/lib/apt/lists/*\nENV CUDA_VERSION 9.2.148\nENV CUDA_PKG_VERSION 9-2=$CUDA_VERSION-1\nRUN apt-get update \\\n    && apt-get install -y --no-install-recommends \\\n        cuda-cudart-$CUDA_PKG_VERSION \\\n    && ln -s cuda-9.2 /usr/local/cuda \\\n    && rm -rf /var/lib/apt/lists/*\nLABEL com.nvidia.volumes.needed=\"nvidia_driver\"\nLABEL com.nvidia.cuda.version=\"${CUDA_VERSION}\"\nRUN echo \"/usr/local/nvidia/lib\" >> /etc/ld.so.conf.d/nvidia.conf \\\n    && echo \"/usr/local/nvidia/lib64\" >> /etc/ld.so.conf.d/nvidia.conf\nENV PATH /usr/local/nvidia/bin:/usr/local/cuda/bin:${PATH}\nENV LD_LIBRARY_PATH /usr/local/nvidia/lib:/usr/local/nvidia/lib64\nENV NVIDIA_VISIBLE_DEVICES all\nENV NVIDIA_DRIVER_CAPABILITIES compute,utility\nENV NVIDIA_REQUIRE_CUDA \"cuda>=9.2\"\n\n# FROM 9.2-runtime-ubuntu18.04\n# https://gitlab.com/nvidia/container-images/cuda/blob/ubuntu18.04/9.2/runtime/Dockerfile\nENV NCCL_VERSION 2.3.7\nRUN apt-get update \\\n    && apt-get install -y --no-install-recommends \\\n        cuda-libraries-$CUDA_PKG_VERSION \\\n        cuda-nvtx-$CUDA_PKG_VERSION \\\n        libnccl2=$NCCL_VERSION-1+cuda9.2 \\\n    && apt-mark hold libnccl2 \\\n    && rm -rf /var/lib/apt/lists/*\n\n# FROM 9.2-runtime-cudnn7-ubuntu18.04\n# 
https://gitlab.com/nvidia/container-images/cuda/blob/ubuntu18.04/9.2/runtime/cudnn7/Dockerfile\nENV CUDNN_VERSION 7.5.0.56\nLABEL com.nvidia.cudnn.version=\"${CUDNN_VERSION}\"\nRUN apt-get update \\\n    && apt-get install -y --no-install-recommends libcudnn7=$CUDNN_VERSION-1+cuda9.2 \\\n    && apt-mark hold libcudnn7 \\\n    && rm -rf /var/lib/apt/lists/*\n\nRUN mkdir -p /model\nENV MODEL_PATH /model\nCOPY audio_example.mp3 .\n\n# Spleeter installation.\nRUN apt-get update && apt-get install -y ffmpeg libsndfile1\nRUN pip install musdb museval\nRUN pip install spleeter-gpu==1.4.9\n\nENTRYPOINT [\"spleeter\"]"
  },
  {
    "path": "docker/spleeter-conda.dockerfile",
    "content": "ARG BASE=conda\n\nFROM ${BASE}\n\nARG SPLEETER_VERSION=1.5.3\nENV MODEL_PATH /model\n\nRUN mkdir -p /model\nRUN conda config --add channels conda-forge\nRUN conda install -y -c conda-forge musdb\nRUN conda install -y -c deezer-research spleeter \nCOPY docker/conda-entrypoint.sh spleeter-entrypoint.sh\nENTRYPOINT [\"/bin/bash\", \"spleeter-entrypoint.sh\"]"
  },
  {
    "path": "docker/spleeter-model.dockerfile",
    "content": "ARG BASE=researchdeezer/spleeter\n\nFROM ${BASE}\n\nARG MODEL=2stems\nRUN mkdir -p /model/$MODEL \\\n    && wget -O /tmp/$MODEL.tar.gz https://github.com/deezer/spleeter/releases/download/v1.4.0/$MODEL.tar.gz \\\n    && tar -xvzf /tmp/$MODEL.tar.gz -C /model/$MODEL/ \\\n    && touch /model/$MODEL/.probe\n"
  },
  {
    "path": "docker/spleeter.dockerfile",
    "content": "ARG BASE=python:3.6\n\nFROM ${BASE}\n\nARG SPLEETER_VERSION=1.5.3\nENV MODEL_PATH /model\n\nRUN mkdir -p /model\nRUN apt-get update && apt-get install -y ffmpeg libsndfile1\nRUN pip install musdb museval\nRUN pip install spleeter==${SPLEETER_VERSION}\n\nENTRYPOINT [\"spleeter\"]\n"
  },
  {
    "path": "paper.bib",
    "content": "% bibtex\n\n@inproceedings{SISEC18,\n       author = {{St{\\\"o}ter}, Fabian-Robert and {Liutkus}, Antoine and {Ito}, Nobutaka},\n        title = {The 2018 Signal Separation Evaluation Campaign},\n        year = {2018},\n        booktitle = {Latent Variable Analysis and Signal Separation. {LVA}/{ICA}},\n        vol={10891},\n        doi = {10.1007/978-3-319-93764-9_28},\n        publisher = { Springer, Cham}\n}\n\n@misc{spleeter2019,\n  title={Spleeter: A Fast And State-of-the Art Music Source Separation Tool With Pre-trained Models},\n  author={Romain Hennequin and Anis Khlif and Felix Voituret and Manuel Moussallam},\n  howpublished={Late-Breaking/Demo ISMIR 2019},\n  month={November},\n  note={Deezer Research},\n  year={2019}\n}\n\n@inproceedings{unet2017,\n  title={Singing voice separation with deep U-Net convolutional networks},\n  author={Jansson, Andreas and Humphrey, Eric J. and Montecchio, Nicola and Bittner, Rachel and Kumar, Aparna and Weyde, Tillman},\n  booktitle={Proceedings of the International Society for Music Information Retrieval Conference (ISMIR)},\n  pages={323--332},\n  year={2017}\n}\n\n@inproceedings{deezerICASSP2019,\nauthor={Laure {Pr\\'etet} and Romain {Hennequin} and Jimena {Royo-Letelier} and Andrea {Vaglio}},\nbooktitle={ICASSP 2019 - 2019 IEEE International Conference on Acoustics, Speech and Signal Processing (ICASSP)},\ntitle={Singing Voice Separation: A Study on Training Data},\nyear={2019},\nvolume={},\nnumber={},\npages={506-510},\nkeywords={feature extraction;source separation;speech processing;supervised training;separation quality;data augmentation;singing voice separation systems;singing voice separation algorithms;separation diversity;source separation;supervised learning;training data;data augmentation},\ndoi={10.1109/ICASSP.2019.8683555},\nISSN={},\nmonth={May},}\n\n\n@misc{Norbert,\n  author       = {Antoine Liutkus and\n                  Fabian-Robert St{\\\"o}ter},\n  title        = {sigsep/norbert: 
First official Norbert release},\n  month        = jul,\n  year         = 2019,\n  doi          = {10.5281/zenodo.3269749},\n  url          = {https://doi.org/10.5281/zenodo.3269749}\n}\n\n@ARTICLE{separation_metrics,\nauthor={Emmanuel {Vincent} and Remi {Gribonval} and Cedric {Fevotte}},\njournal={IEEE Transactions on Audio, Speech, and Language Processing},\ntitle={Performance measurement in blind audio source separation},\nyear={2006},\nvolume={14},\nnumber={4},\npages={1462-1469},\nkeywords={audio signal processing;blind source separation;distortion;time-varying filters;blind audio source separation;distortions;time-invariant gains;time-varying filters;source estimation;interference;additive noise;algorithmic artifacts;Source separation;Data mining;Filters;Additive noise;Microphones;Distortion measurement;Energy measurement;Independent component analysis;Interference;Image analysis;Audio source separation;evaluation;measure;performance;quality},\ndoi={10.1109/TSA.2005.858005},\nISSN={},\nmonth={July},}\n\n@misc{musdb18,\n  author       = {Rafii, Zafar and\n                  Liutkus, Antoine and\n                  Fabian-Robert St{\\\"o}ter and\n                  Mimilakis, Stylianos Ioannis and\n                  Bittner, Rachel},\n  title        = {The {MUSDB18} corpus for music separation},\n  month        = dec,\n  year         = 2017,\n  doi          = {10.5281/zenodo.1117372},\n  url          = {https://doi.org/10.5281/zenodo.1117372}\n}\n\n\n@misc{tensorflow2015-whitepaper,\ntitle={ {TensorFlow}: Large-Scale Machine Learning on Heterogeneous Systems},\nurl={https://www.tensorflow.org/},\nnote={Software available from tensorflow.org},\nauthor={\n    Abadi, Mart{\\'{\\i}}n et al.},\n  year={2015},\n}\n\n@article{2019arXiv190611139L,\n       author = {{Lee}, Kyungyun and {Nam}, Juhan},\n        title = \"{Learning a Joint Embedding Space of Monophonic and Mixed Music Signals for Singing Voice}\",\n      journal = {arXiv e-prints},\n     keywords = {Computer 
Science - Sound, Electrical Engineering and Systems Science - Audio and Speech Processing},\n         year = \"2019\",\n        month = \"Jun\",\n          eid = {arXiv:1906.11139},\n        pages = {arXiv:1906.11139},\narchivePrefix = {arXiv},\n       eprint = {1906.11139},\n primaryClass = {cs.SD},\n       adsurl = {https://ui.adsabs.harvard.edu/abs/2019arXiv190611139L},\n      adsnote = {Provided by the SAO/NASA Astrophysics Data System}\n}\n\n@article{Adam,\n       author = {{Kingma}, Diederik P. and {Ba}, Jimmy},\n        title = \"{Adam: A Method for Stochastic Optimization}\",\n      journal = {arXiv e-prints},\n     keywords = {Computer Science - Machine Learning},\n         year = \"2014\",\n        month = \"Dec\",\n          eid = {arXiv:1412.6980},\n        pages = {arXiv:1412.6980},\narchivePrefix = {arXiv},\n       eprint = {1412.6980},\n primaryClass = {cs.LG},\n       adsurl = {https://ui.adsabs.harvard.edu/abs/2014arXiv1412.6980K},\n      adsnote = {Provided by the SAO/NASA Astrophysics Data System}\n}\n\n@article{Open-Unmix,\n  author={Fabian-Robert St\\\"{o}ter and Stefan Uhlich and Antoine Liutkus and Yuki Mitsufuji},\n  title={Open-Unmix - A Reference Implementation for Music Source Separation},\n  journal={Journal of Open Source Software},\n  year=2019,\n  doi = {10.21105/joss.01667},\n  url = {https://doi.org/10.21105/joss.01667}\n}\n\n@misc{spleeter,\n  author={Romain Hennequin and Anis Khlif and Felix Voituret and Manuel Moussallam},\n  title={Spleeter},\n  year=2019,\n  url = {https://www.github.com/deezer/spleeter}\n}\n\n@misc{demucs,\n    title={Music Source Separation in the Waveform Domain},\n    author={Alexandre Défossez and Nicolas Usunier and Léon Bottou and Francis Bach},\n    year={2019},\n    eprint={1911.13254},\n    archivePrefix={arXiv},\n    primaryClass={cs.SD}\n}"
  },
  {
    "path": "paper.md",
    "content": "---\ntitle: 'Spleeter: a fast and efficient music source separation tool with pre-trained models'\ntags:\n  - Python\n  - musical signal processing\n  - source separation\n  - vocal isolation\nauthors:\n  - name:  Romain Hennequin\n    orcid: 0000-0001-8158-5562\n    affiliation: 1\n  - name: Anis Khlif\n    affiliation: 1\n  - name: Felix Voituret\n    affiliation: 1\n  - name: Manuel Moussallam\n    orcid: 0000-0003-0886-5423\n    affiliation: 1\naffiliations:\n - name: Deezer Research, Paris\n   index: 1\ndate: 04 March 2020\nbibliography: paper.bib\n\n---\n\n## Summary\n\nWe present and release a new tool for music source separation with pre-trained models called Spleeter. Spleeter was designed with ease of use, separation performance, and speed in mind. Spleeter is based on Tensorflow [@tensorflow2015-whitepaper] and makes it possible to:\n\n- split music audio files into several stems with a single command line using pre-trained models. A music audio file can be separated into $2$ stems (vocals and accompaniments), $4$ stems (vocals, drums, bass, and other) or $5$ stems (vocals, drums, bass, piano and other).\n- train source separation models or fine-tune pre-trained ones with Tensorflow (provided you have a dataset of isolated sources).\n\nThe performance of the pre-trained models are very close to the published state-of-the-art and is one of the best performing $4$ stems separation model on the common musdb18 benchmark [@musdb18] to be publicly released. 
Spleeter is also very fast as it can separate a mix audio file into $4$ stems $100$ times faster than real-time (we note, though, that the model cannot be applied in real-time as it needs buffering) on a single Graphics Processing Unit (GPU) using the pre-trained $4$-stems model.\n\n## Purpose\n\nWe release Spleeter with pre-trained state-of-the-art models in order to help the Music Information Retrieval (MIR) research community leverage the power of source separation in various MIR tasks, such as vocal lyrics analysis from audio (audio/lyrics alignment, lyrics transcription...), music transcription (chord transcription, drums transcription, bass transcription, chord estimation, beat tracking), singer identification, any type of multilabel classification (mood/genre...), vocal melody extraction or cover detection.\nWe believe that source separation has reached a level of maturity that makes it worth considering for these tasks and that specific features computed from isolated vocals, drums or bass may help increase performance, especially in low data availability scenarios (small datasets, limited annotation availability) for which supervised learning might be difficult.\nSpleeter also makes it possible to fine-tune the provided state-of-the-art models in order to adapt the system to a specific use-case.\nFinally, having an available source separation tool such as Spleeter will allow researchers to compare the performance of their new models to a state-of-the-art one on their private datasets instead of musdb18, which is usually the only dataset used for reporting separation performance for unreleased models.\nNote that we cannot release the training data for copyright reasons, and thus, sharing pre-trained models was the only way to make these results available to the community.\n\n## Implementation details\n\nSpleeter contains pre-trained models for:\n\n- vocals/accompaniment separation.\n- $4$ stems separation as in SiSec [@SISEC18] (vocals, bass, drums and 
other).\n- $5$ stems separation with an extra piano stem (vocals, bass, drums, piano, and other). It is, to the authors' knowledge, the first released model to perform such a separation.\n\nThe pre-trained models are U-nets [@unet2017] and follow specifications similar to those in [@deezerICASSP2019]. The U-net is an encoder/decoder Convolutional Neural Network (CNN) architecture with skip connections. We used $12$-layer U-nets ($6$ layers for the encoder and $6$ for the decoder). A U-net is used for estimating a soft mask for each source (stem). The training loss is an $L_1$-norm between masked input mix spectrograms and source-target spectrograms. The models were trained on Deezer's internal datasets (notably the Bean dataset that was used in [@deezerICASSP2019]) using Adam [@Adam]. Training took approximately a full week on a single GPU. Separation is then done from estimated source spectrograms using soft masking or multi-channel Wiener filtering.\n\nTraining and inference are implemented in Tensorflow, which makes it possible to run the code on a Central Processing Unit (CPU) or a GPU.\n\n## Speed\n\nAs the whole separation pipeline can be run on a GPU and the model is based on a CNN, computations are efficiently parallelized and model inference is very fast. For instance, Spleeter is able to separate the whole musdb18 test dataset (about $3$ hours and $27$ minutes of audio) into $4$ stems in less than $2$ minutes, including model loading time (about $15$ seconds) and audio wav files export, using a single GeForce RTX 2080 GPU and a double Intel Xeon Gold 6134 CPU @ 3.20GHz (CPU is used for mix files loading and stem files export only). 
In this setup, Spleeter is able to process $100$ seconds of stereo audio in less than $1$ second, which makes it very useful for efficiently processing large datasets.\n\n## Separation performances\n\nThe models compete with the state-of-the-art on the standard musdb18 dataset [@musdb18] although they were not trained, validated or optimized in any way with musdb18 data. Results in terms of standard source separation metrics [@separation_metrics], namely Signal to Distortion Ratio (SDR), Signal to Artifacts Ratio (SAR), Signal to Interference Ratio (SIR) and source Image to Spatial distortion Ratio (ISR), are presented in the following table and compared to Open-Unmix [@Open-Unmix] and Demucs [@demucs] (only SDR is reported for Demucs since other metrics are not available in the paper), which are, to the authors' knowledge, the only released systems that achieve near state-of-the-art performance.\nWe present results for soft masking and for multi-channel Wiener filtering (applied using Norbert [@Norbert]). 
As can be seen, Spleeter is competitive with Open-Unmix on most metrics, especially on SDR for all instruments, and is almost on par with Demucs.\n\n\n|           |Spleeter Mask  |Spleeter MWF   |Open-Unmix |Demucs|\n|-----------|---------------|---------------|-----------|------|\n| Vocals SDR|6.55           |6.86           |6.32       |7.05  |\n| Vocals SIR|15.19          |15.86          |13.33      |13.94 |\n| Vocals SAR|6.44           |6.99           |6.52       |7.00  |\n| Vocals ISR|12.01          |11.95          |11.93      |12.04 |\n| Bass SDR  |5.10           |5.51           |5.23       |6.70  |\n| Bass SIR  |10.01          |10.30          |10.93      |13.03 |\n| Bass SAR  |5.15           |5.96           |6.34       |6.68  |\n| Bass ISR  |9.18           |9.61           |9.23       |9.99  |\n| Drums SDR |5.93           |6.71           |5.73       |7.08  |\n| Drums SIR |12.24          |13.67          |11.12      |13.74 |\n| Drums SAR |5.78           |6.54           |6.02       |7.04  |\n| Drums ISR |10.50          |10.69          |10.51      |11.96 |\n| Other SDR |4.24           |4.55           |4.02       |4.47  |\n| Other SIR |7.86           |8.16           |6.59       |7.11  |\n| Other SAR |4.63           |4.88           |4.74       |5.26  |\n| Other ISR |9.83           |9.87           |9.31       |10.86 |\n\nSpleeter [@spleeter] source code and pre-trained models are available on [GitHub](https://www.github.com/deezer/spleeter) and distributed under an MIT license. 
This repository will eventually be used for releasing other models with improved performance or models separating into more than $5$ stems.\n\n## Distribution\n\nSpleeter is available as a standalone Python package, and is also provided as a [conda](https://github.com/conda-forge/spleeter-feedstock) recipe and as self-contained [Docker images](https://hub.docker.com/r/researchdeezer/spleeter), which makes it usable as-is on various platforms.\n\n## Acknowledgements\n\nWe acknowledge contributions from Laure Pretet, who trained the first models and wrote the first piece of code that led to Spleeter.\n\n## References\n"
  },
  {
    "path": "pyproject.toml",
    "content": "[tool.poetry]\nname = \"spleeter\"\nversion = \"2.4.2\"\ndescription = \"The Deezer source separation library with pretrained models based on tensorflow.\"\nauthors = [\"Deezer Research <spleeter@deezer.com>\"]\nlicense = \"MIT License\"\nreadme = \"README.md\"\nrepository = \"https://github.com/deezer/spleeter\"\nhomepage = \"https://github.com/deezer/spleeter\"\nclassifiers = [\n    \"Environment :: Console\",\n    \"Environment :: MacOS X\",\n    \"Intended Audience :: Developers\",\n    \"Intended Audience :: Information Technology\",\n    \"Intended Audience :: Science/Research\",\n    \"License :: OSI Approved :: MIT License\",\n    \"Natural Language :: English\",\n    \"Operating System :: MacOS\",\n    \"Operating System :: Microsoft :: Windows\",\n    \"Operating System :: POSIX :: Linux\",\n    \"Operating System :: Unix\",\n    \"Programming Language :: Python\",\n    \"Programming Language :: Python :: 3\",\n    \"Programming Language :: Python :: 3.8\",\n    \"Programming Language :: Python :: 3.9\",\n    \"Programming Language :: Python :: 3.10\",\n    \"Programming Language :: Python :: 3.11\",\n    \"Programming Language :: Python :: 3 :: Only\",\n    \"Programming Language :: Python :: Implementation :: CPython\",\n    \"Topic :: Artistic Software\",\n    \"Topic :: Multimedia\",\n    \"Topic :: Multimedia :: Sound/Audio\",\n    \"Topic :: Multimedia :: Sound/Audio :: Analysis\",\n    \"Topic :: Multimedia :: Sound/Audio :: Conversion\",\n    \"Topic :: Multimedia :: Sound/Audio :: Sound Synthesis\",\n    \"Topic :: Scientific/Engineering\",\n    \"Topic :: Scientific/Engineering :: Artificial Intelligence\",\n    \"Topic :: Scientific/Engineering :: Information Analysis\",\n    \"Topic :: Software Development\",\n    \"Topic :: Software Development :: Libraries\",\n    \"Topic :: Software Development :: Libraries :: Python Modules\",\n    \"Topic :: Utilities\"\n]\npackages = [ { include = \"spleeter\" } ]\ninclude = [\"LICENSE\", 
\"spleeter/resources/*.json\"]\n\n[tool.poetry.dependencies]\npython = \">=3.8,<3.12\"\nffmpeg-python = \"^0.2.0\"\nhttpx = {extras = [\"http2\"], version = \"^0.19.0\"}\ntyper = \"^0.3.2\"\nmusdb = {version = \"^0.4.0\", optional = true}\nmuseval = {version = \"^0.4.0\", optional = true}\ntensorflow-io-gcs-filesystem = \"0.32.0\"\ntensorflow = \"2.12.1\"\npandas = \"^1.3.0\"\nnorbert = \"^0.2.1\"\nnumpy = \"<2.0.0\"\n\n[tool.poetry.dev-dependencies]\npytest = \"^6.2.1\"\nisort = \"^5.7.0\"\nblack = \"^21.7b0\"\nmypy = \">0.790,<1.0\"\nflake8 = \"^5.0.0\"\npytest-forked = \"^1.3.0\"\nmusdb = \"^0.4.0\"\nmuseval = \"^0.4.0\"\n\n[tool.poetry.scripts]\nspleeter = 'spleeter.__main__:entrypoint'\n\n[tool.poetry.extras]\nevaluation = [\"musdb\", \"museval\"]\n\n[tool.isort]\nprofile = \"black\"\nmulti_line_output = 3\n\n[tool.pytest.ini_options]\naddopts = \"-W ignore::FutureWarning -W ignore::DeprecationWarning -vv --forked\"\n\n[build-system]\nrequires = [\"poetry-core>=1.0.0\"]\nbuild-backend = \"poetry.core.masonry.api\"\n"
  },
  {
    "path": "spleeter/__init__.py",
    "content": "#!/usr/bin/env python\n# coding: utf8\n\n\"\"\"\nSpleeter is the Deezer source separation library with pretrained models.\nThe library is based on Tensorflow:\n\n-   It provides already trained model for performing separation.\n-   It makes it easy to train source separation model with tensorflow\n    (provided you have a dataset of isolated sources).\n\nThis module allows to interact easily from command line with Spleeter\nby providing train, evaluation and source separation action.\n\"\"\"\n\n__email__ = \"spleeter@deezer.com\"\n__author__ = \"Deezer Research\"\n__license__ = \"MIT License\"\n\n\nclass SpleeterError(Exception):\n    \"\"\"Custom exception for Spleeter related error.\"\"\"\n\n    pass\n"
  },
  {
    "path": "spleeter/__main__.py",
    "content": "#!/usr/bin/env python\n# coding: utf8\n\n\"\"\"\nPython oneliner script usage.\n\nUSAGE: python -m spleeter {train,evaluate,separate} ...\n\nNotes:\n    All critical import involving TF, numpy or Pandas are deported to\n    command function scope to avoid heavy import on CLI evaluation,\n    leading to large bootstraping time.\n\"\"\"\nimport json\nfrom functools import partial\nfrom glob import glob\nfrom itertools import product\nfrom os.path import join\nfrom typing import Dict, List, Optional, Tuple\n\n# pyright: reportMissingImports=false\n# pylint: disable=import-error\nfrom typer import Exit, Typer\n\nfrom . import SpleeterError\nfrom .audio import Codec\nfrom .options import (\n    AudioAdapterOption,\n    AudioBitrateOption,\n    AudioCodecOption,\n    AudioDurationOption,\n    AudioInputArgument,\n    AudioInputOption,\n    AudioOffsetOption,\n    AudioOutputOption,\n    FilenameFormatOption,\n    ModelParametersOption,\n    MUSDBDirectoryOption,\n    MWFOption,\n    TrainingDataDirectoryOption,\n    VerboseOption,\n    VersionOption,\n)\nfrom .utils.logging import configure_logger, logger\n\n# pylint: enable=import-error\n\nspleeter: Typer = Typer(add_completion=False, no_args_is_help=True, short_help=\"-h\")\n\"\"\" CLI application. 
\"\"\"\n\n\n@spleeter.callback()\ndef default(\n    version: bool = VersionOption,\n) -> None:\n    pass\n\n\n@spleeter.command(no_args_is_help=True)\ndef train(\n    adapter: str = AudioAdapterOption,\n    data: str = TrainingDataDirectoryOption,\n    params_filename: str = ModelParametersOption,\n    verbose: bool = VerboseOption,\n) -> None:\n    \"\"\"\n    Train a source separation model\n    \"\"\"\n    import tensorflow as tf  # type: ignore\n\n    from .audio.adapter import AudioAdapter\n    from .dataset import get_training_dataset, get_validation_dataset\n    from .model import model_fn\n    from .model.provider import ModelProvider\n    from .utils.configuration import load_configuration\n\n    configure_logger(verbose)\n    audio_adapter = AudioAdapter.get(adapter)\n    params = load_configuration(params_filename)\n    session_config = tf.compat.v1.ConfigProto()\n    session_config.gpu_options.per_process_gpu_memory_fraction = 0.45\n    estimator = tf.estimator.Estimator(\n        model_fn=model_fn,\n        model_dir=params[\"model_dir\"],\n        params=params,\n        config=tf.estimator.RunConfig(\n            save_checkpoints_steps=params[\"save_checkpoints_steps\"],\n            tf_random_seed=params[\"random_seed\"],\n            save_summary_steps=params[\"save_summary_steps\"],\n            session_config=session_config,\n            log_step_count_steps=10,\n            keep_checkpoint_max=2,\n        ),\n    )\n    input_fn = partial(get_training_dataset, params, audio_adapter, data)\n    train_spec = tf.estimator.TrainSpec(\n        input_fn=input_fn, max_steps=params[\"train_max_steps\"]\n    )\n    input_fn = partial(get_validation_dataset, params, audio_adapter, data)\n    evaluation_spec = tf.estimator.EvalSpec(\n        input_fn=input_fn, steps=None, throttle_secs=params[\"throttle_secs\"]\n    )\n    logger.info(\"Start model training\")\n    tf.estimator.train_and_evaluate(estimator, train_spec, evaluation_spec)\n    
ModelProvider.writeProbe(params[\"model_dir\"])\n    logger.info(\"Model training done\")\n\n\n@spleeter.command(no_args_is_help=True)\ndef separate(\n    deprecated_files: Optional[str] = AudioInputOption,\n    files: List[str] = AudioInputArgument,\n    adapter: str = AudioAdapterOption,\n    bitrate: str = AudioBitrateOption,\n    codec: Codec = AudioCodecOption,\n    duration: float = AudioDurationOption,\n    offset: float = AudioOffsetOption,\n    output_path: str = AudioOutputOption,\n    filename_format: str = FilenameFormatOption,\n    params_filename: str = ModelParametersOption,\n    mwf: bool = MWFOption,\n    verbose: bool = VerboseOption,\n) -> None:\n    \"\"\"\n    Separate audio file(s)\n    \"\"\"\n    from .audio.adapter import AudioAdapter\n    from .separator import Separator\n\n    configure_logger(verbose)\n    if deprecated_files is not None:\n        logger.error(\n            \"⚠️ -i option is not supported anymore, audio files must be supplied \"\n            \"using input argument instead (see spleeter separate --help)\"\n        )\n        raise Exit(20)\n    audio_adapter: AudioAdapter = AudioAdapter.get(adapter)\n    separator: Separator = Separator(params_filename, MWF=mwf)\n\n    for filename in files:\n        separator.separate_to_file(\n            filename,\n            output_path,\n            audio_adapter=audio_adapter,\n            offset=offset,\n            duration=duration,\n            codec=codec,\n            bitrate=bitrate,\n            filename_format=filename_format,\n            synchronous=False,\n        )\n    separator.join()\n\n\nEVALUATION_SPLIT: str = \"test\"\nEVALUATION_METRICS_DIRECTORY: str = \"metrics\"\nEVALUATION_INSTRUMENTS: Tuple[str, ...] = (\"vocals\", \"drums\", \"bass\", \"other\")\nEVALUATION_METRICS: Tuple[str, ...] 
= (\"SDR\", \"SAR\", \"SIR\", \"ISR\")\nEVALUATION_MIXTURE: str = \"mixture.wav\"\nEVALUATION_AUDIO_DIRECTORY: str = \"audio\"\n\n\ndef _compile_metrics(metrics_output_directory: str) -> Dict:\n    \"\"\"\n    Compiles metrics from given directory and returns results as dict.\n\n    Parameters:\n        metrics_output_directory (str):\n            Directory to get metrics from.\n\n    Returns:\n        Dict:\n            Compiled metrics as dict.\n    \"\"\"\n    import numpy as np\n    import pandas as pd  # type: ignore\n\n    songs = glob(join(metrics_output_directory, EVALUATION_SPLIT, \"*.json\"))\n    index = pd.MultiIndex.from_tuples(\n        product(EVALUATION_INSTRUMENTS, EVALUATION_METRICS),\n        names=[\"instrument\", \"metric\"],\n    )\n    pd.DataFrame([], index=[\"config1\", \"config2\"], columns=index)\n    metrics: Dict = {\n        instrument: {k: [] for k in EVALUATION_METRICS}\n        for instrument in EVALUATION_INSTRUMENTS\n    }\n    for song in songs:\n        with open(song, \"r\") as stream:\n            data = json.load(stream)\n        for target in data[\"targets\"]:\n            instrument = target[\"name\"]\n            for metric in EVALUATION_METRICS:\n                # Per-song median over frames, ignoring NaN frames.\n                metric_median = np.median(\n                    [\n                        frame[\"metrics\"][metric]\n                        for frame in target[\"frames\"]\n                        if not np.isnan(frame[\"metrics\"][metric])\n                    ]\n                )\n                metrics[instrument][metric].append(metric_median)\n    return metrics\n\n\n@spleeter.command(no_args_is_help=True)\ndef evaluate(\n    adapter: str = AudioAdapterOption,\n    output_path: str = AudioOutputOption,\n    params_filename: str = ModelParametersOption,\n    mus_dir: str = MUSDBDirectoryOption,\n    mwf: bool = MWFOption,\n    verbose: bool = VerboseOption,\n) -> Dict:\n    \"\"\"\n    Evaluate a model on the musDB test dataset\n    \"\"\"\n    import numpy as np\n\n    
configure_logger(verbose)\n    try:\n        import musdb  # type: ignore\n        import museval  # type: ignore\n    except ImportError:\n        logger.error(\"Extra dependencies musdb and museval not found\")\n        logger.error(\"Please install musdb and museval first, abort\")\n        raise Exit(10)\n    # Separate musdb sources.\n    songs = glob(join(mus_dir, EVALUATION_SPLIT, \"*/\"))\n    mixtures = [join(song, EVALUATION_MIXTURE) for song in songs]\n    audio_output_directory = join(output_path, EVALUATION_AUDIO_DIRECTORY)\n    separate(\n        deprecated_files=None,\n        files=mixtures,\n        adapter=adapter,\n        bitrate=\"128k\",\n        codec=Codec.WAV,\n        duration=600.0,\n        offset=0.0,\n        output_path=join(audio_output_directory, EVALUATION_SPLIT),\n        filename_format=\"{foldername}/{instrument}.{codec}\",\n        params_filename=params_filename,\n        mwf=mwf,\n        verbose=verbose,\n    )\n    # Compute metrics with museval.\n    metrics_output_directory = join(output_path, EVALUATION_METRICS_DIRECTORY)\n    logger.info(\"Starting musdb evaluation (this could be long) ...\")\n    dataset = musdb.DB(root=mus_dir, is_wav=True, subsets=[EVALUATION_SPLIT])\n    museval.eval_mus_dir(\n        dataset=dataset,\n        estimates_dir=audio_output_directory,\n        output_dir=metrics_output_directory,\n    )\n    logger.info(\"musdb evaluation done\")\n    # Compute and pretty print median metrics.\n    metrics = _compile_metrics(metrics_output_directory)\n    for instrument, instrument_metrics in metrics.items():\n        logger.info(f\"{instrument}:\")\n        for metric, values in instrument_metrics.items():\n            logger.info(f\"{metric}: {np.median(values):.3f}\")\n    return metrics\n\n\ndef entrypoint():\n    \"\"\"Application entrypoint.\"\"\"\n    try:\n        spleeter()\n    except SpleeterError as e:\n        logger.error(e)\n\n\nif __name__ == \"__main__\":\n    entrypoint()\n"
  },
  {
    "path": "spleeter/audio/__init__.py",
    "content": "#!/usr/bin/env python\n# coding: utf8\n\n\"\"\"\n`spleeter.utils.audio` package provides various\ntools for manipulating audio content such as :\n\n- Audio adapter class for abstract interaction with audio file.\n- FFMPEG implementation for audio adapter.\n- Waveform convertion and transforming functions.\n\"\"\"\n\n# Python 3.11 is not backward compatible for String Enum\n# https://tomwojcik.com/posts/2023-01-02/python-311-str-enum-breaking-change\ntry:\n    from enum import StrEnum\nexcept ImportError:\n    from enum import Enum\n\n    class StrEnum(str, Enum):\n        pass\n\n\n__email__ = \"spleeter@deezer.com\"\n__author__ = \"Deezer Research\"\n__license__ = \"MIT License\"\n\n\nclass Codec(StrEnum):\n    \"\"\"Enumeration of supported audio codec.\"\"\"\n\n    WAV: str = \"wav\"\n    MP3: str = \"mp3\"\n    OGG: str = \"ogg\"\n    M4A: str = \"m4a\"\n    WMA: str = \"wma\"\n    FLAC: str = \"flac\"\n"
  },
  {
    "path": "spleeter/audio/adapter.py",
    "content": "#!/usr/bin/env python\n# coding: utf8\n\n\"\"\" AudioAdapter class defintion. \"\"\"\n\nfrom abc import ABC, abstractmethod\nfrom importlib import import_module\nfrom pathlib import Path\nfrom typing import Any, Dict, List, Optional, Union\n\n# pyright: reportMissingImports=false\n# pylint: disable=import-error\nimport numpy as np\nimport tensorflow as tf  # type: ignore\n\nfrom spleeter.audio import Codec\n\nfrom .. import SpleeterError\nfrom ..types import AudioDescriptor, Signal\nfrom ..utils.logging import logger\n\n# pylint: enable=import-error\n\n\n__email__ = \"spleeter@deezer.com\"\n__author__ = \"Deezer Research\"\n__license__ = \"MIT License\"\n\n\nclass AudioAdapter(ABC):\n    \"\"\"An abstract class for manipulating audio signal.\"\"\"\n\n    _DEFAULT: Optional[\"AudioAdapter\"] = None\n    \"\"\"Default audio adapter singleton instance.\"\"\"\n\n    @abstractmethod\n    def load(\n        self,\n        audio_descriptor: AudioDescriptor,\n        offset: Optional[float] = None,\n        duration: Optional[float] = None,\n        sample_rate: Optional[float] = None,\n        dtype: bytes = b\"float32\",\n    ) -> Signal:\n        \"\"\"\n        Loads the audio file denoted by the given audio descriptor and\n        returns it data as a waveform. 
Aims to be implemented by client.\n\n        Parameters:\n            audio_descriptor (AudioDescriptor):\n                Describe song to load, in case of file based audio adapter,\n                such descriptor would be a file path.\n            offset (Optional[float]):\n                (Optional) Start offset to load from in seconds.\n            duration (Optional[float]):\n                (Optional) Duration to load in seconds.\n            sample_rate (Optional[float]):\n                (Optional) Sample rate to load audio with.\n            dtype (bytes):\n                (Optional) Data type to use, default to `b'float32'`.\n\n        Returns:\n            Signal:\n                Loaded data as (wf, sample_rate) tuple.\n        \"\"\"\n        pass\n\n    def load_waveform(\n        self,\n        audio_descriptor,\n        offset: float = 0.0,\n        duration: float = 1800.0,\n        sample_rate: int = 44100,\n        dtype: bytes = b\"float32\",\n        waveform_name: str = \"waveform\",\n    ) -> Dict[str, Any]:\n        \"\"\"\n        Load the audio and convert it to a tensorflow waveform.\n\n        Parameters:\n            audio_descriptor (Any):\n                Describe song to load, in case of file based audio adapter,\n                such descriptor would be a file path.\n            offset (float):\n                (Optional) Start offset to load from in seconds.\n            duration (float):\n                (Optional) Duration to load in seconds.\n            sample_rate (float):\n                (Optional) Sample rate to load audio with.\n            dtype (bytes):\n                (Optional) Data type to use, default to `b'float32'`.\n            waveform_name (str):\n                (Optional) Name of the key in output dict, default to\n                `'waveform'`.\n\n        Returns:\n            Dict[str, Any]:\n                TF output dict with waveform as `(T x chan numpy array)`\n                and a boolean that tells 
whether there was an error while\n                trying to load the waveform.\n        \"\"\"\n        # Cast parameters to TF format.\n        offset = tf.cast(offset, tf.float64)\n        duration = tf.cast(duration, tf.float64)\n\n        # Define safe loading function.\n        def safe_load(path, offset, duration, sample_rate, dtype):\n            logger.info(f\"Loading audio {path} from {offset} to {offset + duration}\")\n            try:\n                (data, _) = self.load(\n                    path.numpy(),\n                    offset.numpy(),\n                    duration.numpy(),\n                    sample_rate.numpy(),\n                    dtype=dtype.numpy(),\n                )\n                logger.info(\"Audio data loaded successfully\")\n                return (data, False)\n            except Exception as e:\n                logger.exception(\"An error occurred while loading audio\", exc_info=e)\n            return (np.float32(-1.0), True)\n\n        # Execute function and format results.\n        results = (\n            tf.py_function(\n                safe_load,\n                [audio_descriptor, offset, duration, sample_rate, dtype],\n                (tf.float32, tf.bool),\n            ),\n        )\n        waveform, error = results[0]\n        return {waveform_name: waveform, f\"{waveform_name}_error\": error}\n\n    @abstractmethod\n    def save(\n        self,\n        path: Union[Path, str],\n        data: np.ndarray,\n        sample_rate: float,\n        codec: Codec = None,\n        bitrate: str = None,\n    ) -> None:\n        \"\"\"\n        Save the given audio data to the file denoted by the given path.\n\n        Parameters:\n            path (Union[Path, str]):\n                Path of the audio file to save data in.\n            data (np.ndarray):\n                Waveform data to write.\n            sample_rate (float):\n                Sample rate to write file in.\n            codec (Codec):\n                
(Optional) Writing codec to use, default to `None`.\n            bitrate (str):\n                (Optional) Bitrate of the written audio file, default to\n                `None`.\n        \"\"\"\n        pass\n\n    @classmethod\n    def default(cls) -> \"AudioAdapter\":\n        \"\"\"\n        Builds and returns a default audio adapter instance.\n\n        Returns:\n            AudioAdapter:\n                Default adapter instance to use.\n        \"\"\"\n        if cls._DEFAULT is None:\n            from .ffmpeg import FFMPEGProcessAudioAdapter\n\n            cls._DEFAULT = FFMPEGProcessAudioAdapter()\n        return cls._DEFAULT\n\n    @classmethod\n    def get(cls, descriptor: str) -> \"AudioAdapter\":\n        \"\"\"\n        Dynamically load an AudioAdapter from the given class descriptor.\n\n        Parameters:\n            descriptor (str):\n                Adapter class descriptor (module.Class)\n\n        Returns:\n            AudioAdapter:\n                Created adapter instance.\n        \"\"\"\n        if not descriptor:\n            return cls.default()\n        module_desc: List[str] = descriptor.split(\".\")\n        adapter_class_name: str = module_desc[-1]\n        module_path: str = \".\".join(module_desc[:-1])\n        adapter_module = import_module(module_path)\n        adapter_class = getattr(adapter_module, adapter_class_name)\n        if not issubclass(adapter_class, AudioAdapter):\n            raise SpleeterError(\n                f\"{adapter_class_name} is not a valid AudioAdapter class\"\n            )\n        return adapter_class()\n"
  },
  {
    "path": "spleeter/audio/convertor.py",
    "content": "#!/usr/bin/env python\n# coding: utf8\n\n\"\"\" This module provides audio data convertion functions. \"\"\"\n\n# pyright: reportMissingImports=false\n# pylint: disable=import-error\nimport numpy as np\nimport tensorflow as tf  # type: ignore\n\nfrom ..utils.tensor import from_float32_to_uint8, from_uint8_to_float32\n\n# pylint: enable=import-error\n\n__email__ = \"spleeter@deezer.com\"\n__author__ = \"Deezer Research\"\n__license__ = \"MIT License\"\n\n\ndef to_n_channels(waveform: tf.Tensor, n_channels: int) -> tf.Tensor:\n    \"\"\"\n    Convert a waveform to n_channels by removing or duplicating channels if\n    needed (in tensorflow).\n\n    Parameters:\n        waveform (tf.Tensor):\n            Waveform to transform.\n        n_channels (int):\n            Number of channel to reshape waveform in.\n\n    Returns:\n        tf.Tensor:\n            Reshaped waveform.\n    \"\"\"\n    return tf.cond(\n        tf.shape(waveform)[1] >= n_channels,\n        true_fn=lambda: waveform[:, :n_channels],\n        false_fn=lambda: tf.tile(waveform, [1, n_channels])[:, :n_channels],\n    )\n\n\ndef to_stereo(waveform: np.ndarray) -> np.ndarray:\n    \"\"\"\n    Convert a waveform to stereo by duplicating if mono, or truncating\n    if too many channels.\n\n    Parameters:\n        waveform (np.ndarray):\n            a `(N, d)` numpy array.\n\n    Returns:\n        np.ndarray:\n            A stereo waveform as a `(N, 1)` numpy array.\n    \"\"\"\n    if waveform.shape[1] == 1:\n        return np.repeat(waveform, 2, axis=-1)\n    if waveform.shape[1] > 2:\n        return waveform[:, :2]\n    return waveform\n\n\ndef gain_to_db(tensor: tf.Tensor, espilon: float = 10e-10) -> tf.Tensor:\n    \"\"\"\n    Convert from gain to decibel in tensorflow.\n\n    Parameters:\n        tensor (tf.Tensor):\n            Tensor to convert\n        epsilon (float):\n            Operation constant.\n\n    Returns:\n        tf.Tensor:\n            Converted tensor.\n    \"\"\"\n  
  return 20.0 / np.log(10) * tf.math.log(tf.maximum(tensor, espilon))\n\n\ndef db_to_gain(tensor: tf.Tensor) -> tf.Tensor:\n    \"\"\"\n    Convert from decibel to gain in tensorflow.\n\n    Parameters:\n        tensor (tf.Tensor):\n            Tensor to convert\n\n    Returns:\n        tf.Tensor:\n            Converted tensor.\n    \"\"\"\n    return tf.pow(10.0, (tensor / 20.0))\n\n\ndef spectrogram_to_db_uint(\n    spectrogram: tf.Tensor, db_range: float = 100.0, **kwargs\n) -> tf.Tensor:\n    \"\"\"\n    Encodes given spectrogram into uint8 using decibel scale.\n\n    Parameters:\n        spectrogram (tf.Tensor):\n            Spectrogram to be encoded as TF float tensor.\n        db_range (float):\n            Range in decibel for encoding.\n\n    Returns:\n        tf.Tensor:\n            Encoded decibel spectrogram as `uint8` tensor.\n    \"\"\"\n    db_spectrogram: tf.Tensor = gain_to_db(spectrogram)\n    max_db_spectrogram: tf.Tensor = tf.reduce_max(db_spectrogram)\n    int_db_spectrogram: tf.Tensor = tf.maximum(\n        db_spectrogram, max_db_spectrogram - db_range\n    )\n    return from_float32_to_uint8(int_db_spectrogram, **kwargs)\n\n\ndef db_uint_spectrogram_to_gain(\n    db_uint_spectrogram: tf.Tensor, min_db: tf.Tensor, max_db: tf.Tensor\n) -> tf.Tensor:\n    \"\"\"\n    Decode spectrogram from uint8 decibel scale.\n\n    Parameters:\n        db_uint_spectrogram (tf.Tensor):\n            Decibel spectrogram to decode.\n        min_db (tf.Tensor):\n            Lower bound limit for decoding.\n        max_db (tf.Tensor):\n            Upper bound limit for decoding.\n\n    Returns:\n        tf.Tensor:\n            Decoded spectrogram as `float32` tensor.\n    \"\"\"\n    db_spectrogram: tf.Tensor = from_uint8_to_float32(\n        db_uint_spectrogram, min_db, max_db\n    )\n    return db_to_gain(db_spectrogram)\n"
  },
  {
    "path": "spleeter/audio/ffmpeg.py",
    "content": "#!/usr/bin/env python\n# coding: utf8\n\n\"\"\"\nThis module provides an AudioAdapter implementation based on FFMPEG\nprocess. Such implementation is POSIXish and depends on nothing except\nstandard Python libraries. Thus this implementation is the default one\nused within this library.\n\"\"\"\n\nimport datetime as dt\nimport os\nimport shutil\nfrom pathlib import Path\nfrom typing import Dict, Optional, Union\n\n# pyright: reportMissingImports=false\n# pylint: disable=import-error\nimport ffmpeg  # type: ignore\nimport numpy as np\n\nfrom .. import SpleeterError\nfrom ..types import Signal\nfrom ..utils.logging import logger\nfrom . import Codec\nfrom .adapter import AudioAdapter\n\n# pylint: enable=import-error\n\n__email__ = \"spleeter@deezer.com\"\n__author__ = \"Deezer Research\"\n__license__ = \"MIT License\"\n\n\nclass FFMPEGProcessAudioAdapter(AudioAdapter):\n    \"\"\"\n    An AudioAdapter implementation that use FFMPEG binary through\n    subprocess in order to perform I/O operation for audio processing.\n\n    When created, FFMPEG binary path will be checked and expended,\n    raising exception if not found. Such path could be infered using\n    `FFMPEG_PATH` environment variable.\n    \"\"\"\n\n    SUPPORTED_CODECS: Dict[Codec, str] = {\n        Codec.M4A: \"aac\",\n        Codec.OGG: \"libvorbis\",\n        Codec.WMA: \"wmav2\",\n    }\n    \"\"\" FFMPEG codec name mapping. 
\"\"\"\n\n    def __init__(_) -> None:\n        \"\"\"\n        Default constructor, ensure FFMPEG binaries are available.\n\n        Raises:\n            SpleeterError:\n                If ffmpeg or ffprobe is not found.\n        \"\"\"\n        for binary in (\"ffmpeg\", \"ffprobe\"):\n            if shutil.which(binary) is None:\n                raise SpleeterError(\"{} binary not found\".format(binary))\n\n    def load(\n        _,\n        path: Union[Path, str],\n        offset: Optional[float] = None,\n        duration: Optional[float] = None,\n        sample_rate: Optional[float] = None,\n        dtype: bytes = b\"float32\",\n    ) -> Signal:\n        \"\"\"\n        Loads the audio file denoted by the given path\n        and returns it data as a waveform.\n\n        Parameters:\n            path (Union[Path, str]:\n                Path of the audio file to load data from.\n            offset (Optional[float]):\n                (Optional) Start offset to load from in seconds.\n            duration (Optional[float]):\n                (Optional) Duration to load in seconds.\n            sample_rate (Optional[float]):\n                (Optional) Sample rate to load audio with.\n            dtype (bytes):\n                (Optional) Data type to use, default to `b'float32'`.\n\n        Returns:\n            Signal:\n                Loaded data a (waveform, sample_rate) tuple.\n\n        Raises:\n            SpleeterError:\n                If any error occurs while loading audio.\n        \"\"\"\n        if isinstance(path, Path):\n            path = str(path)\n        if not isinstance(path, str):\n            path = path.decode()\n        try:\n            probe = ffmpeg.probe(path)\n        except ffmpeg._run.Error as e:\n            raise SpleeterError(\n                \"An error occurs with ffprobe (see ffprobe output below)\\n\\n{}\".format(\n                    e.stderr.decode()\n                )\n            )\n        if \"streams\" not in probe or 
len(probe[\"streams\"]) == 0:\n            raise SpleeterError(\"No stream was found with ffprobe\")\n        metadata = next(\n            stream for stream in probe[\"streams\"] if stream[\"codec_type\"] == \"audio\"\n        )\n        n_channels = metadata[\"channels\"]\n        if sample_rate is None:\n            sample_rate = metadata[\"sample_rate\"]\n        output_kwargs = {\"format\": \"f32le\", \"ar\": sample_rate}\n        if duration is not None:\n            output_kwargs[\"t\"] = str(dt.timedelta(seconds=duration))\n        if offset is not None:\n            output_kwargs[\"ss\"] = str(dt.timedelta(seconds=offset))\n        process = (\n            ffmpeg.input(path)\n            .output(\"pipe:\", **output_kwargs)\n            .run_async(pipe_stdout=True, pipe_stderr=True)\n        )\n        buffer, _ = process.communicate()\n        waveform = np.frombuffer(buffer, dtype=\"<f4\").reshape(-1, n_channels)\n        if not waveform.dtype == np.dtype(dtype):\n            waveform = waveform.astype(dtype)\n        return (waveform, sample_rate)\n\n    def save(\n        self,\n        path: Union[Path, str],\n        data: np.ndarray,\n        sample_rate: float,\n        codec: Codec = None,\n        bitrate: str = None,\n    ) -> None:\n        \"\"\"\n        Write waveform data to the file denoted by the given path using\n        FFMPEG process.\n\n        Parameters:\n            path (Union[Path, str]):\n                Path of the audio file to save data in.\n            data (np.ndarray):\n                Waveform data to write.\n            sample_rate (float):\n                Sample rate to write file in.\n            codec (Codec):\n                (Optional) Writing codec to use, default to `None`.\n            bitrate (str):\n                (Optional) Bitrate of the written audio file, default to\n                `None`.\n\n        Raises:\n            SpleeterError:\n                If any error occurs while using FFMPEG to write 
data.\n        \"\"\"\n        if isinstance(path, Path):\n            path = str(path)\n        directory = os.path.dirname(path)\n        if not os.path.exists(directory):\n            raise SpleeterError(f\"output directory does not exist: {directory}\")\n        logger.debug(f\"Writing file {path}\")\n        input_kwargs = {\"ar\": sample_rate, \"ac\": data.shape[1]}\n        output_kwargs = {\"ar\": sample_rate, \"strict\": \"-2\"}\n        if bitrate:\n            output_kwargs[\"audio_bitrate\"] = bitrate\n        if codec is not None and codec != \"wav\":\n            output_kwargs[\"codec\"] = self.SUPPORTED_CODECS.get(codec, codec)\n        process = (\n            ffmpeg.input(\"pipe:\", format=\"f32le\", **input_kwargs)\n            .output(path, **output_kwargs)\n            .overwrite_output()\n            .run_async(pipe_stdin=True, pipe_stderr=True, quiet=True)\n        )\n        try:\n            process.stdin.write(data.astype(\"<f4\").tobytes())\n            process.stdin.close()\n            process.wait()\n        except IOError:\n            raise SpleeterError(f\"FFMPEG error: {process.stderr.read()}\")\n        logger.info(f\"File {path} written successfully\")\n"
  },
  {
    "path": "spleeter/audio/spectrogram.py",
    "content": "#!/usr/bin/env python\n# coding: utf8\n\n\"\"\" Spectrogram specific data augmentation. \"\"\"\n\n# pyright: reportMissingImports=false\n# pylint: disable=import-error\nimport tensorflow as tf  # type: ignore\nfrom tensorflow.signal import hann_window, stft  # type: ignore\n\n# pylint: enable=import-error\n\n__email__ = \"spleeter@deezer.com\"\n__author__ = \"Deezer Research\"\n__license__ = \"MIT License\"\n\n\ndef compute_spectrogram_tf(\n    waveform: tf.Tensor,\n    frame_length: int = 2048,\n    frame_step: int = 512,\n    spec_exponent: float = 1.0,\n    window_exponent: float = 1.0,\n) -> tf.Tensor:\n    \"\"\"\n    Compute magnitude / power spectrogram from waveform as a\n    `n_samples x n_channels` tensor.\n\n    Parameters:\n        waveform (tf.Tensor):\n            Input waveform as `(times x number of channels)` tensor.\n        frame_length (int):\n            (Optional) Length of a STFT frame to use.\n        frame_step (int):\n            (Optional) HOP between successive frames.\n        spec_exponent (float):\n            (Optional) Exponent of the spectrogram (usually 1 for\n            magnitude spectrogram, or 2 for power spectrogram).\n        window_exponent (float):\n            (Optional) Exponent applied to the Hann windowing function\n            (may be useful for making perfect STFT/iSTFT reconstruction).\n\n    Returns:\n        tf.Tensor:\n            Computed magnitude / power spectrogram as a\n            `(T x F x n_channels)` tensor.\n    \"\"\"\n    stft_tensor: tf.Tensor = tf.transpose(\n        stft(\n            tf.transpose(waveform),\n            frame_length,\n            frame_step,\n            window_fn=lambda f, dtype: hann_window(\n                f, periodic=True, dtype=waveform.dtype\n            )\n            ** window_exponent,\n        ),\n        perm=[1, 2, 0],\n    )\n    return tf.abs(stft_tensor) ** spec_exponent\n\n\ndef time_stretch(\n    spectrogram: tf.Tensor,\n    factor: float = 1.0,\n 
   method: tf.image.ResizeMethod = tf.image.ResizeMethod.BILINEAR,\n) -> tf.Tensor:\n    \"\"\"\n    Time stretch a spectrogram preserving shape in tensorflow. Note that\n    this is an approximation in the frequency domain.\n\n    Parameters:\n        spectrogram (tf.Tensor):\n            Input spectrogram to be time stretched as tensor.\n        factor (float):\n            (Optional) Time stretch factor, must be > 0, default to `1`.\n        method (tf.image.ResizeMethod):\n            (Optional) Interpolation method, default to `BILINEAR`.\n\n    Returns:\n        tf.Tensor:\n            Time stretched spectrogram as tensor with same shape.\n    \"\"\"\n    T = tf.shape(spectrogram)[0]\n    T_ts = tf.cast(tf.cast(T, tf.float32) * factor, tf.int32)[0]\n    F = tf.shape(spectrogram)[1]\n    ts_spec = tf.image.resize_images(\n        spectrogram, [T_ts, F], method=method, align_corners=True\n    )\n    return tf.image.resize_image_with_crop_or_pad(ts_spec, T, F)\n\n\ndef random_time_stretch(\n    spectrogram: tf.Tensor, factor_min: float = 0.9, factor_max: float = 1.1, **kwargs\n) -> tf.Tensor:\n    \"\"\"\n    Time stretch a spectrogram preserving shape with random ratio in\n    tensorflow. 
Applies time_stretch to spectrogram with a random ratio\n    drawn uniformly in `[factor_min, factor_max]`.\n\n    Parameters:\n        spectrogram (tf.Tensor):\n            Input spectrogram to be time stretched as tensor.\n        factor_min (float):\n            (Optional) Min time stretch factor, default to `0.9`.\n        factor_max (float):\n            (Optional) Max time stretch factor, default to `1.1`.\n        ** kwargs:\n            Time stretch args.\n\n    Returns:\n        tf.Tensor:\n            Randomly time stretched spectrogram as tensor with same shape.\n    \"\"\"\n    factor = (\n        tf.random_uniform(shape=(1,), seed=0) * (factor_max - factor_min) + factor_min\n    )\n    return time_stretch(spectrogram, factor=factor, **kwargs)\n\n\ndef pitch_shift(\n    spectrogram: tf.Tensor,\n    semitone_shift: float = 0.0,\n    method: tf.image.ResizeMethod = tf.image.ResizeMethod.BILINEAR,\n) -> tf.Tensor:\n    \"\"\"\n    Pitch shift a spectrogram preserving shape in tensorflow. 
Note that\n    this is an approximation in the frequency domain.\n\n    Parameters:\n        spectrogram (tf.Tensor):\n            Input spectrogram to be pitch shifted as tensor.\n        semitone_shift (float):\n            (Optional) Pitch shift in semitone, default to `0.0`.\n        method (tf.image.ResizeMethod):\n            (Optional) Interpolation method, default to `BILINEAR`.\n\n    Returns:\n        tf.Tensor:\n            Pitch shifted spectrogram (same shape as spectrogram).\n    \"\"\"\n    factor = 2 ** (semitone_shift / 12.0)\n    T = tf.shape(spectrogram)[0]\n    F = tf.shape(spectrogram)[1]\n    F_ps = tf.cast(tf.cast(F, tf.float32) * factor, tf.int32)[0]\n    ps_spec = tf.image.resize_images(\n        spectrogram, [T, F_ps], method=method, align_corners=True\n    )\n    paddings = [[0, 0], [0, tf.maximum(0, F - F_ps)], [0, 0]]\n    return tf.pad(ps_spec[:, :F, :], paddings, \"CONSTANT\")\n\n\ndef random_pitch_shift(\n    spectrogram: tf.Tensor, shift_min: float = -1.0, shift_max: float = 1.0, **kwargs\n) -> tf.Tensor:\n    \"\"\"\n    Pitch shift a spectrogram preserving shape with random ratio in\n    tensorflow. Applies pitch_shift to spectrogram with a random shift\n    amount (expressed in semitones) drawn uniformly in\n    `[shift_min, shift_max]`.\n\n    Parameters:\n        spectrogram (tf.Tensor):\n            Input spectrogram to be pitch shifted as tensor.\n        shift_min (float):\n            (Optional) Min pitch shift in semitone, default to -1.\n        shift_max (float):\n            (Optional) Max pitch shift in semitone, default to 1.\n\n    Returns:\n        tf.Tensor:\n            Randomly pitch shifted spectrogram (same shape as spectrogram).\n    \"\"\"\n    semitone_shift = (\n        tf.random_uniform(shape=(1,), seed=0) * (shift_max - shift_min) + shift_min\n    )\n    return pitch_shift(spectrogram, semitone_shift=semitone_shift, **kwargs)\n"
  },
  {
    "path": "spleeter/dataset.py",
    "content": "#!/usr/bin/env python\n# coding: utf8\n\n\"\"\"\nModule for building data preprocessing pipeline using the tensorflow\ndata API. Data preprocessing such as audio loading, spectrogram\ncomputation, cropping, feature caching or data augmentation is done\nusing a tensorflow dataset object that output a tuple (input_, output)\nwhere:\n\n-   input is a dictionary with a single key that contains the (batched)\n    mix spectrogram of audio samples\n-   output is a dictionary of spectrogram of the isolated tracks\n    (ground truth)\n\"\"\"\n\nimport os\nimport time\nfrom os.path import exists\nfrom os.path import sep as SEPARATOR\nfrom typing import Any, Dict, List, Optional, Tuple\n\n# pyright: reportMissingImports=false\n# pylint: disable=import-error\nimport tensorflow as tf  # type: ignore\n\nfrom .audio.adapter import AudioAdapter\nfrom .audio.convertor import db_uint_spectrogram_to_gain, spectrogram_to_db_uint\nfrom .audio.spectrogram import (\n    compute_spectrogram_tf,\n    random_pitch_shift,\n    random_time_stretch,\n)\nfrom .utils.logging import logger\nfrom .utils.tensor import (\n    check_tensor_shape,\n    dataset_from_csv,\n    set_tensor_shape,\n    sync_apply,\n)\n\n# pylint: enable=import-error\n\n__email__ = \"spleeter@deezer.com\"\n__author__ = \"Deezer Research\"\n__license__ = \"MIT License\"\n\n# Default audio parameters to use.\nDEFAULT_AUDIO_PARAMS: Dict = {\n    \"instrument_list\": (\"vocals\", \"accompaniment\"),\n    \"mix_name\": \"mix\",\n    \"sample_rate\": 44100,\n    \"frame_length\": 4096,\n    \"frame_step\": 1024,\n    \"T\": 512,\n    \"F\": 1024,\n}\n\n\ndef get_training_dataset(\n    audio_params: Dict, audio_adapter: AudioAdapter, audio_path: str\n) -> Any:\n    \"\"\"\n    Builds training dataset.\n\n    Parameters:\n        audio_params (Dict):\n            Audio parameters.\n        audio_adapter (AudioAdapter):\n            Adapter to load audio from.\n        audio_path (str):\n            Path of directory 
containing audio.\n\n    Returns:\n        Any:\n            Built dataset.\n    \"\"\"\n    builder = DatasetBuilder(\n        audio_params,\n        audio_adapter,\n        audio_path,\n        chunk_duration=audio_params.get(\"chunk_duration\", 20.0),\n        random_seed=audio_params.get(\"random_seed\", 0),\n    )\n    return builder.build(\n        str(audio_params.get(\"train_csv\")),\n        cache_directory=audio_params.get(\"training_cache\"),\n        batch_size=audio_params.get(\"batch_size\", 8),\n        n_chunks_per_song=audio_params.get(\"n_chunks_per_song\", 2),\n        random_data_augmentation=False,\n        convert_to_uint=True,\n        wait_for_cache=False,\n    )\n\n\ndef get_validation_dataset(\n    audio_params: Dict, audio_adapter: AudioAdapter, audio_path: str\n) -> Any:\n    \"\"\"\n    Builds validation dataset.\n\n    Parameters:\n        audio_params (Dict):\n            Audio parameters.\n        audio_adapter (AudioAdapter):\n            Adapter to load audio from.\n        audio_path (str):\n            Path of directory containing audio.\n\n    Returns:\n        Any:\n            Built dataset.\n    \"\"\"\n    builder = DatasetBuilder(\n        audio_params, audio_adapter, audio_path, chunk_duration=12.0\n    )\n    return builder.build(\n        str(audio_params.get(\"validation_csv\")),\n        batch_size=audio_params.get(\"batch_size\", 8),\n        cache_directory=audio_params.get(\"validation_cache\"),\n        convert_to_uint=True,\n        infinite_generator=False,\n        n_chunks_per_song=1,\n        # should not perform data augmentation for eval:\n        random_data_augmentation=False,\n        random_time_crop=False,\n        shuffle=False,\n    )\n\n\nclass InstrumentDatasetBuilder(object):\n    \"\"\"Instrument based filter and mapper provider.\"\"\"\n\n    def __init__(self, parent: Any, instrument: Any) -> None:\n        \"\"\"\n        Default constructor.\n\n        Parameters:\n            parent (Any):\n   
             Parent dataset builder.\n            instrument (Any):\n                Target instrument.\n        \"\"\"\n        self._parent = parent\n        self._instrument = instrument\n        self._spectrogram_key = f\"{instrument}_spectrogram\"\n        self._min_spectrogram_key = f\"min_{instrument}_spectrogram\"\n        self._max_spectrogram_key = f\"max_{instrument}_spectrogram\"\n\n    def load_waveform(self, sample: Dict) -> Dict:\n        \"\"\"Load waveform for given sample.\"\"\"\n        return dict(\n            sample,\n            **self._parent._audio_adapter.load_waveform(\n                sample[f\"{self._instrument}_path\"],\n                offset=sample[\"start\"],\n                duration=self._parent._chunk_duration,\n                sample_rate=self._parent._sample_rate,\n                waveform_name=\"waveform\",\n            ),\n        )\n\n    def compute_spectrogram(self, sample: Dict) -> Dict:\n        \"\"\"Compute spectrogram of the given sample.\"\"\"\n        return dict(\n            sample,\n            **{\n                self._spectrogram_key: compute_spectrogram_tf(\n                    sample[\"waveform\"],\n                    frame_length=self._parent._frame_length,\n                    frame_step=self._parent._frame_step,\n                    spec_exponent=1.0,\n                    window_exponent=1.0,\n                )\n            },\n        )\n\n    def filter_frequencies(self, sample: Dict) -> Dict:\n        \"\"\"Keep only the first F frequency bins of the spectrogram.\"\"\"\n        return dict(\n            sample,\n            **{\n                self._spectrogram_key: sample[self._spectrogram_key][\n                    :, : self._parent._F, :\n                ]\n            },\n        )\n\n    def convert_to_uint(self, sample: Dict) -> Dict:\n        \"\"\"Convert given sample from float to uint.\"\"\"\n        return dict(\n            sample,\n            **spectrogram_to_db_uint(\n                sample[self._spectrogram_key],\n                
tensor_key=self._spectrogram_key,\n                min_key=self._min_spectrogram_key,\n                max_key=self._max_spectrogram_key,\n            ),\n        )\n\n    def filter_infinity(self, sample: Dict) -> tf.Tensor:\n        \"\"\"Filter out samples whose minimum spectrogram value is infinite.\"\"\"\n        return tf.logical_not(tf.math.is_inf(sample[self._min_spectrogram_key]))\n\n    def convert_to_float32(self, sample: Dict) -> Dict:\n        \"\"\"Convert given sample from uint to float.\"\"\"\n        return dict(\n            sample,\n            **{\n                self._spectrogram_key: db_uint_spectrogram_to_gain(\n                    sample[self._spectrogram_key],\n                    sample[self._min_spectrogram_key],\n                    sample[self._max_spectrogram_key],\n                )\n            },\n        )\n\n    def time_crop(self, sample: Dict) -> Dict:\n        \"\"\"Crop the central T frames of the spectrogram.\"\"\"\n\n        def start(sample):\n            \"\"\"Return the start index of the central segment.\"\"\"\n            return tf.cast(\n                tf.maximum(\n                    tf.shape(sample[self._spectrogram_key])[0] / 2\n                    - self._parent._T / 2,\n                    0,\n                ),\n                tf.int32,\n            )\n\n        return dict(\n            sample,\n            **{\n                self._spectrogram_key: sample[self._spectrogram_key][\n                    start(sample) : start(sample) + self._parent._T, :, :\n                ]\n            },\n        )\n\n    def filter_shape(self, sample: Dict) -> bool:\n        \"\"\"Filter badly shaped sample.\"\"\"\n        return check_tensor_shape(\n            sample[self._spectrogram_key],\n            (self._parent._T, self._parent._F, self._parent._n_channels),\n        )\n\n    def reshape_spectrogram(self, sample: Dict) -> Dict:\n        \"\"\"Reshape given sample.\"\"\"\n        return dict(\n            sample,\n            **{\n                self._spectrogram_key: set_tensor_shape(\n                    sample[self._spectrogram_key],\n              
      (self._parent._T, self._parent._F, self._parent._n_channels),\n                )\n            },\n        )\n\n\nclass DatasetBuilder(object):\n    MARGIN: float = 0.5\n    \"\"\"Margin at beginning and end of songs in seconds.\"\"\"\n\n    WAIT_PERIOD: int = 60\n    \"\"\"Wait period for cache (in seconds).\"\"\"\n\n    def __init__(\n        self,\n        audio_params: Dict,\n        audio_adapter: AudioAdapter,\n        audio_path: str,\n        random_seed: int = 0,\n        chunk_duration: float = 20.0,\n    ) -> None:\n        \"\"\"\n        Default constructor.\n        \"\"\"\n        # Length of segment in frames (if fs=22050 and\n        # frame_step=512, then T=512 corresponds to 11.89s)\n        self._T = audio_params[\"T\"]\n        # Number of frequency bins to be used (should\n        # be less than frame_length/2 + 1)\n        self._F = audio_params[\"F\"]\n        self._sample_rate = audio_params[\"sample_rate\"]\n        self._frame_length = audio_params[\"frame_length\"]\n        self._frame_step = audio_params[\"frame_step\"]\n        self._mix_name = audio_params[\"mix_name\"]\n        self._n_channels = audio_params[\"n_channels\"]\n        self._instruments = [self._mix_name] + audio_params[\"instrument_list\"]\n        self._instrument_builders: Optional[List] = None\n        self._chunk_duration = chunk_duration\n        self._audio_adapter = audio_adapter\n        self._audio_params = audio_params\n        self._audio_path = audio_path\n        self._random_seed = random_seed\n\n        self.check_parameters_compatibility()\n\n    def check_parameters_compatibility(self):\n        if self._frame_length / 2 + 1 < self._F:\n            raise ValueError(\n                \"F is too large and must be set to at most frame_length/2+1. 
\"\n                \"Decrease F or increase frame_length to fix.\"\n            )\n\n        if (\n            self._chunk_duration * self._sample_rate - self._frame_length\n        ) / self._frame_step < self._T:\n            raise ValueError(\n                \"T is too large considering STFT parameters and chunk duratoin. \"\n                \"Make sure spectrogram time dimension of chunks is larger than T \"\n                \"(for instance reducing T or frame_step or increasing chunk duration).\"\n            )\n\n    def expand_path(self, sample: Dict) -> Dict:\n        \"\"\"Expands audio paths for the given sample.\"\"\"\n        return dict(\n            sample,\n            **{\n                f\"{instrument}_path\": tf.strings.join(\n                    (self._audio_path, sample[f\"{instrument}_path\"]), SEPARATOR\n                )\n                for instrument in self._instruments\n            },\n        )\n\n    def filter_error(self, sample: Dict) -> tf.Tensor:\n        \"\"\"Filter errored sample.\"\"\"\n        return tf.logical_not(sample[\"waveform_error\"])\n\n    def filter_waveform(self, sample: Dict) -> Dict:\n        \"\"\"Filter waveform from sample.\"\"\"\n        return {k: v for k, v in sample.items() if not k == \"waveform\"}\n\n    def harmonize_spectrogram(self, sample: Dict) -> Dict:\n        \"\"\"Ensure same size for vocals and mix spectrograms.\"\"\"\n\n        def _reduce(sample):\n            return tf.reduce_min(\n                [\n                    tf.shape(sample[f\"{instrument}_spectrogram\"])[0]\n                    for instrument in self._instruments\n                ]\n            )\n\n        return dict(\n            sample,\n            **{\n                f\"{instrument}_spectrogram\": sample[f\"{instrument}_spectrogram\"][\n                    : _reduce(sample), :, :\n                ]\n                for instrument in self._instruments\n            },\n        )\n\n    def filter_short_segments(self, 
sample: Dict) -> tf.Tensor:\n        \"\"\"Filter out too short segment.\"\"\"\n        return tf.reduce_any(\n            [\n                tf.shape(sample[f\"{instrument}_spectrogram\"])[0] >= self._T\n                for instrument in self._instruments\n            ]\n        )\n\n    def random_time_crop(self, sample: Dict) -> Dict:\n        \"\"\"Random time crop of 11.88s.\"\"\"\n        return dict(\n            sample,\n            **sync_apply(\n                {\n                    f\"{instrument}_spectrogram\": sample[f\"{instrument}_spectrogram\"]\n                    for instrument in self._instruments\n                },\n                lambda x: tf.image.random_crop(\n                    x,\n                    (self._T, len(self._instruments) * self._F, self._n_channels),\n                    seed=self._random_seed,\n                ),\n            ),\n        )\n\n    def random_time_stretch(self, sample: Dict) -> Dict:\n        \"\"\"Randomly time stretch the given sample.\"\"\"\n        return dict(\n            sample,\n            **sync_apply(\n                {\n                    f\"{instrument}_spectrogram\": sample[f\"{instrument}_spectrogram\"]\n                    for instrument in self._instruments\n                },\n                lambda x: random_time_stretch(x, factor_min=0.9, factor_max=1.1),\n            ),\n        )\n\n    def random_pitch_shift(self, sample: Dict) -> Dict:\n        \"\"\"Randomly pitch shift the given sample.\"\"\"\n        return dict(\n            sample,\n            **sync_apply(\n                {\n                    f\"{instrument}_spectrogram\": sample[f\"{instrument}_spectrogram\"]\n                    for instrument in self._instruments\n                },\n                lambda x: random_pitch_shift(x, shift_min=-1.0, shift_max=1.0),\n                concat_axis=0,\n            ),\n        )\n\n    def map_features(self, sample: Dict) -> Tuple[Dict, Dict]:\n        \"\"\"Select features and 
annotation of the given sample.\"\"\"\n        input_ = {\n            f\"{self._mix_name}_spectrogram\": sample[f\"{self._mix_name}_spectrogram\"]\n        }\n        output = {\n            f\"{instrument}_spectrogram\": sample[f\"{instrument}_spectrogram\"]\n            for instrument in self._audio_params[\"instrument_list\"]\n        }\n        return (input_, output)\n\n    def compute_segments(self, dataset: Any, n_chunks_per_song: int) -> Any:\n        \"\"\"\n        Computes segments for each song of the dataset.\n\n        Parameters:\n            dataset (Any):\n                Dataset to compute segments for.\n            n_chunks_per_song (int):\n                Number of segments per song to compute.\n\n        Returns:\n            Any:\n                Segmented dataset.\n        \"\"\"\n        if n_chunks_per_song <= 0:\n            raise ValueError(\"n_chunks_per_song must be positive\")\n        datasets = []\n        for k in range(n_chunks_per_song):\n            if n_chunks_per_song > 1:\n                datasets.append(\n                    dataset.map(\n                        lambda sample: dict(\n                            sample,\n                            start=tf.maximum(\n                                k\n                                * (\n                                    sample[\"duration\"]\n                                    - self._chunk_duration\n                                    - 2 * self.MARGIN\n                                )\n                                / (n_chunks_per_song - 1)\n                                + self.MARGIN,\n                                0,\n                            ),\n                        )\n                    )\n                )\n            elif n_chunks_per_song == 1:  # Take central segment.\n                datasets.append(\n                    dataset.map(\n                        lambda sample: dict(\n                            sample,\n                            
start=tf.maximum(\n                                sample[\"duration\"] / 2 - self._chunk_duration / 2, 0\n                            ),\n                        )\n                    )\n                )\n        dataset = datasets[-1]\n        for d in datasets[:-1]:\n            dataset = dataset.concatenate(d)\n        return dataset\n\n    @property\n    def instruments(self) -> Any:\n        \"\"\"\n        Instrument dataset builder generator.\n\n        Yields:\n            Any:\n                InstrumentBuilder instance.\n        \"\"\"\n        if self._instrument_builders is None:\n            self._instrument_builders = []\n            for instrument in self._instruments:\n                self._instrument_builders.append(\n                    InstrumentDatasetBuilder(self, instrument)\n                )\n        for builder in self._instrument_builders:\n            yield builder\n\n    def cache(self, dataset: Any, cache: Optional[str], wait: bool) -> Any:\n        \"\"\"\n        Cache the given dataset if cache is enabled. 
Optionally waits for\n        cache to be available (useful if another process is already\n        computing cache) if provided wait flag is `True`.\n\n        Parameters:\n            dataset (Any):\n                Dataset to be cached if cache is required.\n            cache (Optional[str]):\n                Path of cache directory to be used, None if no cache.\n            wait (bool):\n                If caching is enabled, True if the cache should be waited for.\n\n        Returns:\n            Any:\n                Cached dataset if needed, original dataset otherwise.\n        \"\"\"\n        if cache is not None:\n            if wait:\n                while not exists(f\"{cache}.index\"):\n                    logger.info(f\"Cache not available, wait {self.WAIT_PERIOD}\")\n                    time.sleep(self.WAIT_PERIOD)\n            cache_path = os.path.split(cache)[0]\n            os.makedirs(cache_path, exist_ok=True)\n            return dataset.cache(cache)\n        return dataset\n\n    def build(\n        self,\n        csv_path: str,\n        batch_size: int = 8,\n        shuffle: bool = True,\n        convert_to_uint: bool = True,\n        random_data_augmentation: bool = False,\n        random_time_crop: bool = True,\n        infinite_generator: bool = True,\n        cache_directory: Optional[str] = None,\n        wait_for_cache: bool = False,\n        num_parallel_calls: int = 4,\n        n_chunks_per_song: int = 2,\n    ) -> Any:\n        dataset = dataset_from_csv(csv_path)\n        dataset = self.compute_segments(dataset, n_chunks_per_song)\n        # Shuffle data\n        if shuffle:\n            dataset = dataset.shuffle(\n                buffer_size=200000,\n                seed=self._random_seed,\n                # useless since it is cached :\n                reshuffle_each_iteration=True,\n            )\n        # Expand audio path.\n        dataset = dataset.map(self.expand_path)\n        # Load waveform, compute spectrogram, and filtering error,\n        
# K bins frequencies, and waveform.\n        N = num_parallel_calls\n        for instrument in self.instruments:\n            dataset = (\n                dataset.map(instrument.load_waveform, num_parallel_calls=N)\n                .filter(self.filter_error)\n                .map(instrument.compute_spectrogram, num_parallel_calls=N)\n                .map(instrument.filter_frequencies)\n            )\n        dataset = dataset.map(self.filter_waveform)\n        # Convert to uint before caching in order to save space.\n        if convert_to_uint:\n            for instrument in self.instruments:\n                dataset = dataset.map(instrument.convert_to_uint)\n        dataset = self.cache(dataset, cache_directory, wait_for_cache)\n        # Check for INFINITY (should not happen)\n        for instrument in self.instruments:\n            dataset = dataset.filter(instrument.filter_infinity)\n        # Repeat indefinitely\n        if infinite_generator:\n            dataset = dataset.repeat(count=-1)\n        # Ensure same size for vocals and mix spectrograms.\n        # NOTE: could be done before caching ?\n        dataset = dataset.map(self.harmonize_spectrogram)\n        # Filter out too short segments.\n        # NOTE: could be done before caching ?\n        dataset = dataset.filter(self.filter_short_segments)\n        # Random time crop of 11.88s\n        if random_time_crop:\n            dataset = dataset.map(self.random_time_crop, num_parallel_calls=N)\n        else:\n            # frame_duration = 11.88/T\n            # take central segment (for validation)\n            for instrument in self.instruments:\n                dataset = dataset.map(instrument.time_crop)\n        # Post cache shuffling. 
Done where the data are the lightest:\n        # after cropping but before converting back to float.\n        if shuffle:\n            dataset = dataset.shuffle(\n                buffer_size=256, seed=self._random_seed, reshuffle_each_iteration=True\n            )\n        # Convert back to float32\n        if convert_to_uint:\n            for instrument in self.instruments:\n                dataset = dataset.map(\n                    instrument.convert_to_float32, num_parallel_calls=N\n                )\n        M = 8  # Parallel call post caching.\n        # Must be applied with the same factor on mix and vocals.\n        if random_data_augmentation:\n            dataset = dataset.map(self.random_time_stretch, num_parallel_calls=M).map(\n                self.random_pitch_shift, num_parallel_calls=M\n            )\n        # Filter by shape (remove badly shaped tensors).\n        for instrument in self.instruments:\n            dataset = dataset.filter(instrument.filter_shape).map(\n                instrument.reshape_spectrogram\n            )\n        # Select features and annotation.\n        dataset = dataset.map(self.map_features)\n        # Make batch (done after selection to avoid\n        # error due to unprocessed instrument spectrogram batching).\n        dataset = dataset.batch(batch_size)\n        return dataset\n"
  },
  {
    "path": "spleeter/model/__init__.py",
    "content": "#!/usr/bin/env python\n# coding: utf8\n\n\"\"\" This package provide an estimator builder as well as model functions. \"\"\"\n\nimport importlib\nfrom typing import Any, Dict, Optional, Tuple\n\n# pyright: reportMissingImports=false\n# pylint: disable=import-error\nimport tensorflow as tf  # type: ignore\nfrom tensorflow.signal import hann_window, inverse_stft, stft  # type: ignore\n\nfrom ..utils.tensor import pad_and_partition, pad_and_reshape\n\n# pylint: enable=import-error\n\n\n__email__ = \"spleeter@deezer.com\"\n__author__ = \"Deezer Research\"\n__license__ = \"MIT License\"\n\n\nplaceholder = tf.compat.v1.placeholder\n\n\ndef get_model_function(model_type):\n    \"\"\"\n    Get tensorflow function of the model to be applied to the input tensor.\n    For instance \"unet.softmax_unet\" will return the softmax_unet function\n    in the \"unet.py\" submodule of the current module (spleeter.model).\n\n    Parameters:\n        model_type (str):\n            The relative module path to the model function.\n\n    Returns:\n        Function:\n            A tensorflow function to be applied to the input tensor to get the\n            multitrack output.\n    \"\"\"\n    relative_path_to_module = \".\".join(model_type.split(\".\")[:-1])\n    model_name = model_type.split(\".\")[-1]\n    main_module = \".\".join((__name__, \"functions\"))\n    path_to_module = f\"{main_module}.{relative_path_to_module}\"\n    module = importlib.import_module(path_to_module)\n    model_function = getattr(module, model_name)\n    return model_function\n\n\nclass InputProvider(object):\n    def __init__(self, params):\n        self.params = params\n\n    def get_input_dict_placeholders(self):\n        raise NotImplementedError()\n\n    @property\n    def input_names(self):\n        raise NotImplementedError()\n\n    def get_feed_dict(self, features, *args):\n        raise NotImplementedError()\n\n\nclass WaveformInputProvider(InputProvider):\n    @property\n    def 
input_names(self):\n        return [\"audio_id\", \"waveform\"]\n\n    def get_input_dict_placeholders(self):\n        shape = (None, self.params[\"n_channels\"])\n        features = {\n            \"waveform\": placeholder(tf.float32, shape=shape, name=\"waveform\"),\n            \"audio_id\": placeholder(tf.string, name=\"audio_id\"),\n        }\n        return features\n\n    def get_feed_dict(self, features, waveform, audio_id):\n        return {features[\"audio_id\"]: audio_id, features[\"waveform\"]: waveform}\n\n\nclass InputProviderFactory(object):\n    @staticmethod\n    def get(params):\n        return WaveformInputProvider(params)\n\n\nclass EstimatorSpecBuilder(object):\n    \"\"\"\n    A builder class that builds a multitrack unet model\n    estimator. The built model estimator has a different behaviour when\n    used in a train/eval mode and in predict mode.\n\n    * In train/eval mode:\n        Input and output are magnitude spectrograms.\n    * In predict mode:\n        Input and output are waveforms.\n        The whole separation process is then done in this function\n        for performance reasons: it makes it possible to run the whole\n        separation process (including STFT and inverse STFT) on GPU.\n\n    Example:\n    >>> from spleeter.model import EstimatorSpecBuilder\n    >>> builder = EstimatorSpecBuilder(features, params)\n    >>> builder.build_predict_model()\n    >>> builder.build_evaluation_model(labels)\n    >>> builder.build_train_model(labels)\n\n    >>> from spleeter.model import model_fn\n    >>> estimator = tf.estimator.Estimator(model_fn=model_fn, ...)\n    \"\"\"\n\n    # Supported model functions.\n    DEFAULT_MODEL = \"unet.unet\"\n\n    # Supported loss functions.\n    L1_MASK = \"L1_mask\"\n    WEIGHTED_L1_MASK = \"weighted_L1_mask\"\n\n    # Supported optimizers.\n    ADADELTA = \"Adadelta\"\n    SGD = \"SGD\"\n\n    # Math constants.\n    WINDOW_COMPENSATION_FACTOR = 2.0 / 3.0\n    EPSILON = 1e-10\n\n    def __init__(self, 
features: Dict, params: Dict) -> None:\n        \"\"\"\n        Default constructor. Depending on built model usage,\n        the provided features should be different:\n\n        * In train/eval mode:\n            Features is a dictionary with a \"mix_spectrogram\" key,\n            associated to the mix magnitude spectrogram.\n        * In predict mode:\n            Features is a dictionary with a \"waveform\" key,\n            associated to the waveform of the sound to be separated.\n\n        Parameters:\n            features ():\n                The input features for the estimator.\n            params ():\n                Some hyperparameters as a dictionary.\n        \"\"\"\n\n        self._features = features\n        self._params = params\n        # Get instrument name.\n        self._mix_name = params[\"mix_name\"]\n        self._instruments = params[\"instrument_list\"]\n        # Get STFT/signals parameters\n        self._n_channels = params[\"n_channels\"]\n        self._T = params[\"T\"]\n        self._F = params[\"F\"]\n        self._frame_length = params[\"frame_length\"]\n        self._frame_step = params[\"frame_step\"]\n\n    def _build_model_outputs(self):\n        \"\"\"\n        Created a batch_sizexTxFxn_channels input tensor containing\n        mix magnitude spectrogram, then an output dict from it\n        according to the selected model in internal parameters.\n\n        Raises:\n            ValueError:\n                If required model_type is not supported.\n        \"\"\"\n        input_tensor = self.spectrogram_feature\n        model = self._params.get(\"model\", None)\n        if model is not None:\n            model_type = model.get(\"type\", self.DEFAULT_MODEL)\n        else:\n            model_type = self.DEFAULT_MODEL\n        try:\n            apply_model = get_model_function(model_type)\n        except ModuleNotFoundError:\n            raise ValueError(f\"No model function {model_type} found\")\n        self._model_outputs = 
apply_model(\n            input_tensor, self._instruments, self._params[\"model\"][\"params\"]\n        )\n\n    def _build_loss(self, labels: Dict) -> Tuple[tf.Tensor, Dict]:\n        \"\"\"\n        Construct tensorflow loss and metrics\n\n        Parameters:\n            labels (Dict):\n                Dictionary of target outputs (key: instrument name,\n                value: ground truth spectrogram of the instrument)\n\n        Returns:\n            Tuple[tf.Tensor, Dict]:\n                Tensorflow (loss, metrics) tuple.\n        \"\"\"\n        output_dict = self.model_outputs\n        loss_type = self._params.get(\"loss_type\", self.L1_MASK)\n        if loss_type == self.L1_MASK:\n            losses = {\n                name: tf.reduce_mean(tf.abs(output - labels[name]))\n                for name, output in output_dict.items()\n            }\n        elif loss_type == self.WEIGHTED_L1_MASK:\n            losses = {\n                name: tf.reduce_mean(\n                    tf.reduce_mean(labels[name], axis=[1, 2, 3], keepdims=True)\n                    * tf.abs(output - labels[name])\n                )\n                for name, output in output_dict.items()\n            }\n        else:\n            raise ValueError(f\"Unknown loss type: {loss_type}\")\n        loss = tf.reduce_sum(list(losses.values()))\n        # Add metrics for monitoring each instrument.\n        metrics = {k: tf.compat.v1.metrics.mean(v) for k, v in losses.items()}\n        metrics[\"absolute_difference\"] = tf.compat.v1.metrics.mean(loss)\n        return loss, metrics\n\n    def _build_optimizer(self) -> tf.Tensor:\n        \"\"\"\n        Builds an optimizer instance from internal parameter values.\n        Defaults to AdamOptimizer if not specified.\n\n        Returns:\n            tf.Tensor:\n                Optimizer instance from internal configuration.\n        \"\"\"\n        name = self._params.get(\"optimizer\")\n        if name == self.ADADELTA:\n            return 
tf.compat.v1.train.AdadeltaOptimizer()\n        rate = self._params[\"learning_rate\"]\n        if name == self.SGD:\n            return tf.compat.v1.train.GradientDescentOptimizer(rate)\n        return tf.compat.v1.train.AdamOptimizer(rate)\n\n    @property\n    def instruments(self):\n        return self._instruments\n\n    @property\n    def stft_name(self):\n        return f\"{self._mix_name}_stft\"\n\n    @property\n    def spectrogram_name(self):\n        return f\"{self._mix_name}_spectrogram\"\n\n    def _build_stft_feature(self):\n        \"\"\"\n        Compute STFT of waveform and slice the STFT in segment\n        with the right length to feed the network.\n        \"\"\"\n        stft_name = self.stft_name\n        spec_name = self.spectrogram_name\n\n        if stft_name not in self._features:\n            # pad input with a frame of zeros\n            waveform = tf.concat(\n                [\n                    tf.zeros((self._frame_length, self._n_channels)),\n                    self._features[\"waveform\"],\n                ],\n                0,\n            )\n            stft_feature = tf.transpose(\n                stft(\n                    tf.transpose(waveform),\n                    self._frame_length,\n                    self._frame_step,\n                    window_fn=lambda frame_length, dtype: (\n                        hann_window(frame_length, periodic=True, dtype=dtype)\n                    ),\n                    pad_end=True,\n                ),\n                perm=[1, 2, 0],\n            )\n            self._features[f\"{self._mix_name}_stft\"] = stft_feature\n        if spec_name not in self._features:\n            self._features[spec_name] = tf.abs(\n                pad_and_partition(self._features[stft_name], self._T)\n            )[:, :, : self._F, :]\n\n    @property\n    def model_outputs(self):\n        if not hasattr(self, \"_model_outputs\"):\n            self._build_model_outputs()\n        return 
self._model_outputs\n\n    @property\n    def outputs(self):\n        if not hasattr(self, \"_outputs\"):\n            self._build_outputs()\n        return self._outputs\n\n    @property\n    def stft_feature(self):\n        if self.stft_name not in self._features:\n            self._build_stft_feature()\n        return self._features[self.stft_name]\n\n    @property\n    def spectrogram_feature(self):\n        if self.spectrogram_name not in self._features:\n            self._build_stft_feature()\n        return self._features[self.spectrogram_name]\n\n    @property\n    def masks(self):\n        if not hasattr(self, \"_masks\"):\n            self._build_masks()\n        return self._masks\n\n    @property\n    def masked_stfts(self):\n        if not hasattr(self, \"_masked_stfts\"):\n            self._build_masked_stfts()\n        return self._masked_stfts\n\n    def _inverse_stft(\n        self, stft_t: tf.Tensor, time_crop: Optional[Any] = None\n    ) -> tf.Tensor:\n        \"\"\"\n        Inverse and reshape the given STFT\n\n        Parameters:\n            stft_t (tf.Tensor):\n                Input STFT.\n            time_crop (Optional[Any]):\n                Time cropping.\n\n        Returns:\n            tf.Tensor:\n                Inverse STFT (waveform).\n        \"\"\"\n        inversed = (\n            inverse_stft(\n                tf.transpose(stft_t, perm=[2, 0, 1]),\n                self._frame_length,\n                self._frame_step,\n                window_fn=lambda frame_length, dtype: (\n                    hann_window(frame_length, periodic=True, dtype=dtype)\n                ),\n            )\n            * self.WINDOW_COMPENSATION_FACTOR\n        )\n        reshaped = tf.transpose(inversed)\n        if time_crop is None:\n            time_crop = tf.shape(self._features[\"waveform\"])[0]\n        return reshaped[self._frame_length : self._frame_length + time_crop, :]\n\n    def _build_mwf_output_waveform(self) -> Dict:\n        \"\"\"\n   
     Perform separation with multichannel Wiener Filtering using Norbert.\n\n        Note: multichannel Wiener Filtering is not coded in Tensorflow\n        and thus may be quite slow.\n\n        Returns:\n            Dict:\n                Dictionary of separated waveforms (key: instrument name,\n                value: estimated waveform of the instrument)\n        \"\"\"\n        import norbert  # type: ignore # pylint: disable=import-error\n\n        output_dict = self.model_outputs\n        x = self.stft_feature\n        v = tf.stack(\n            [\n                pad_and_reshape(\n                    output_dict[f\"{instrument}_spectrogram\"],\n                    self._frame_length,\n                    self._F,\n                )[: tf.shape(x)[0], ...]\n                for instrument in self._instruments\n            ],\n            axis=3,\n        )\n        input_args = [v, x]\n        stft_function = (\n            tf.py_function(\n                lambda v, x: norbert.wiener(v.numpy(), x.numpy()),\n                input_args,\n                tf.complex64,\n            ),\n        )\n        return {\n            instrument: self._inverse_stft(stft_function[0][:, :, :, k])\n            for k, instrument in enumerate(self._instruments)\n        }\n\n    def _extend_mask(self, mask: tf.Tensor) -> tf.Tensor:\n        \"\"\"\n        Extend mask, from reduced number of frequency bin to\n        the number of frequency bin in the STFT.\n\n        Parameters:\n            mask (tf.Tensor):\n                Restricted mask.\n\n        Returns:\n            tf.Tensor:\n                Extended mask\n\n        Raises:\n            ValueError:\n                If invalid mask_extension parameter is set.\n        \"\"\"\n        extension = self._params[\"mask_extension\"]\n        # Extend with average\n        # (dispatch according to energy in the processed band)\n        if extension == \"average\":\n            extension_row = tf.reduce_mean(mask, axis=2, 
keepdims=True)\n        # Extend with 0\n        # (avoid extension artifacts but not conservative separation)\n        elif extension == \"zeros\":\n            mask_shape = tf.shape(mask)\n            extension_row = tf.zeros((mask_shape[0], mask_shape[1], 1, mask_shape[-1]))\n        else:\n            raise ValueError(f\"Invalid mask_extension parameter {extension}\")\n        n_extra_row = self._frame_length // 2 + 1 - self._F\n        extension = tf.tile(extension_row, [1, 1, n_extra_row, 1])\n        return tf.concat([mask, extension], axis=2)\n\n    def _build_masks(self):\n        \"\"\"\n        Compute masks from the output spectrograms of the model.\n        \"\"\"\n        output_dict = self.model_outputs\n        stft_feature = self.stft_feature\n        separation_exponent = self._params[\"separation_exponent\"]\n        output_sum = (\n            tf.reduce_sum(\n                [e ** separation_exponent for e in output_dict.values()], axis=0\n            )\n            + self.EPSILON\n        )\n        out = {}\n        for instrument in self._instruments:\n            output = output_dict[f\"{instrument}_spectrogram\"]\n            # Compute mask with the model.\n            instrument_mask = (\n                output ** separation_exponent + (self.EPSILON / len(output_dict))\n            ) / output_sum\n            # Extend mask;\n            instrument_mask = self._extend_mask(instrument_mask)\n            # Stack back mask.\n            old_shape = tf.shape(instrument_mask)\n            new_shape = tf.concat(\n                [[old_shape[0] * old_shape[1]], old_shape[2:]], axis=0\n            )\n            instrument_mask = tf.reshape(instrument_mask, new_shape)\n            # Remove padded part (for mask having the same size as STFT);\n\n            instrument_mask = instrument_mask[: tf.shape(stft_feature)[0], ...]\n            out[instrument] = instrument_mask\n        self._masks = out\n\n    def _build_masked_stfts(self):\n        
input_stft = self.stft_feature\n        out = {}\n        for instrument, mask in self.masks.items():\n            out[instrument] = tf.cast(mask, dtype=tf.complex64) * input_stft\n        self._masked_stfts = out\n\n    def _build_manual_output_waveform(self, masked_stft: Dict) -> Dict:\n        \"\"\"\n        Perform ratio mask separation\n\n        Parameters:\n            masked_stft (Dict):\n                Dictionary of estimated spectrogram (key: instrument name,\n                value: estimated spectrogram of the instrument).\n\n        Returns:\n            Dict:\n                Dictionary of separated waveforms (key: instrument name,\n                value: estimated waveform of the instrument).\n        \"\"\"\n        output_waveform = {}\n        for instrument, stft_data in masked_stft.items():\n            output_waveform[instrument] = self._inverse_stft(stft_data)\n        return output_waveform\n\n    def _build_output_waveform(self, masked_stft: Dict) -> Dict:\n        \"\"\"\n        Build output waveform from given output dict in order\n        to be used in prediction context. 
The configuration\n        building method will be using MWF.\n\n        masked_stft (Dict):\n                Dictionary of estimated spectrogram (key: instrument name,\n                value: estimated spectrogram of the instrument).\n\n        Returns:\n            Dict:\n                Built output waveform.\n        \"\"\"\n\n        if self._params.get(\"MWF\", False):\n            output_waveform = self._build_mwf_output_waveform()\n        else:\n            output_waveform = self._build_manual_output_waveform(masked_stft)\n        return output_waveform\n\n    def _build_outputs(self):\n        self._outputs = self._build_output_waveform(self.masked_stfts)\n\n        if \"audio_id\" in self._features:\n            self._outputs[\"audio_id\"] = self._features[\"audio_id\"]\n\n    def build_predict_model(self) -> tf.Tensor:\n        \"\"\"\n        Builder interface for creating model instance that aims to perform\n        prediction / inference over given track. The output of such estimator\n        will be a dictionary with a \"<instrument>\" key per separated instrument,\n        associated to the estimated separated waveform of the instrument.\n\n        Returns:\n            tf.Tensor:\n                An estimator for performing prediction.\n        \"\"\"\n\n        return tf.estimator.EstimatorSpec(\n            tf.estimator.ModeKeys.PREDICT, predictions=self.outputs\n        )\n\n    def build_evaluation_model(self, labels: Dict) -> tf.Tensor:\n        \"\"\"\n        Builder interface for creating model instance that aims\n        to perform model evaluation. 
The output of such estimator\n        will be a dictionary with a key \"<instrument>_spectrogram\"\n        per separated instrument, associated to the estimated\n        separated instrument magnitude spectrogram.\n\n        Parameters:\n            labels (Dict):\n                Model labels.\n\n        Returns:\n            tf.Tensor:\n                An estimator for performing model evaluation.\n        \"\"\"\n        loss, metrics = self._build_loss(labels)\n        return tf.estimator.EstimatorSpec(\n            tf.estimator.ModeKeys.EVAL, loss=loss, eval_metric_ops=metrics\n        )\n\n    def build_train_model(self, labels: Dict) -> tf.Tensor:\n        \"\"\"\n        Builder interface for creating model instance that aims to perform\n        model training. The output of such estimator will be a dictionary\n        with a key \"<instrument>_spectrogram\" per separated instrument,\n        associated to the estimated separated instrument magnitude spectrogram.\n\n        Parameters:\n            labels (Dict):\n                Model labels.\n\n        Returns:\n            tf.Tensor:\n                An estimator for performing model training.\n        \"\"\"\n        loss, metrics = self._build_loss(labels)\n        optimizer = self._build_optimizer()\n        train_operation = optimizer.minimize(\n            loss=loss, global_step=tf.compat.v1.train.get_global_step()\n        )\n        return tf.estimator.EstimatorSpec(\n            mode=tf.estimator.ModeKeys.TRAIN,\n            loss=loss,\n            train_op=train_operation,\n            eval_metric_ops=metrics,\n        )\n\n\ndef model_fn(features, labels, mode, params):\n    builder = EstimatorSpecBuilder(features, params)\n    if mode == tf.estimator.ModeKeys.PREDICT:\n        return builder.build_predict_model()\n    elif mode == tf.estimator.ModeKeys.EVAL:\n        return builder.build_evaluation_model(labels)\n    elif mode == tf.estimator.ModeKeys.TRAIN:\n        return 
builder.build_train_model(labels)\n    raise ValueError(f\"Unknown mode {mode}\")\n"
  },
  {
    "path": "spleeter/model/functions/__init__.py",
    "content": "#!/usr/bin/env python\n# coding: utf8\n\n\"\"\" This package provide model functions. \"\"\"\n\nfrom typing import Callable, Dict, Iterable, Optional\n\n# pyright: reportMissingImports=false\n# pylint: disable=import-error\nimport tensorflow as tf  # type: ignore\n\n# pylint: enable=import-error\n\n__email__ = \"spleeter@deezer.com\"\n__author__ = \"Deezer Research\"\n__license__ = \"MIT License\"\n\n\ndef apply(\n    function: Callable,\n    input_tensor: tf.Tensor,\n    instruments: Iterable[str],\n    params: Optional[Dict] = None,\n) -> Dict:\n    \"\"\"\n    Apply given function to the input tensor.\n\n    Parameters:\n        function (Callable):\n            Function to be applied to tensor.\n        input_tensor (tf.Tensor):\n            Tensor to apply blstm to.\n        instruments (Iterable[str]):\n            Iterable that provides a collection of instruments.\n        params (Optional[Dict]):\n            (Optional) dict of BLSTM parameters.\n\n    Returns:\n        Dict:\n            Created output tensor dict.\n    \"\"\"\n    output_dict: Dict = {}\n    for instrument in instruments:\n        out_name = f\"{instrument}_spectrogram\"\n        output_dict[out_name] = function(\n            input_tensor, output_name=out_name, params=params or {}\n        )\n    return output_dict\n"
  },
  {
    "path": "spleeter/model/functions/blstm.py",
    "content": "#!/usr/bin/env python\n# coding: utf8\n\n\"\"\"\nThis system (UHL1) uses a bi-directional LSTM network as described in :\n\n`S. Uhlich, M. Porcu, F. Giron, M. Enenkl, T. Kemp, N. Takahashi and\nY. Mitsufuji.\n\n\"Improving music source separation based on deep neural networks through\ndata augmentation and network blending\", Proc. ICASSP, 2017.`\n\nIt has three BLSTM layers, each having 500 cells.  For each instrument,\na network is trained which predicts the target instrument amplitude from\nthe mixture amplitude in the STFT domain (frame size: 4096, hop size:\n1024). The raw output of each network is then combined by a multichannel\nWiener filter. The network is trained on musdb where we split train into\ntrain_train and train_valid with 86 and 14 songs, respectively. The\nvalidation set is used to perform early stopping and hyperparameter\nselection (LSTM layer dropout rate, regularization strength).\n\"\"\"\n\nfrom typing import Dict, Optional\n\n# pyright: reportMissingImports=false\n# pylint: disable=import-error\nimport tensorflow as tf  # type: ignore\nfrom tensorflow.compat.v1.keras.initializers import he_uniform  # type: ignore\nfrom tensorflow.compat.v1.keras.layers import CuDNNLSTM  # type: ignore\nfrom tensorflow.keras.layers import (  # type: ignore\n    Bidirectional,\n    Dense,\n    Flatten,\n    Reshape,\n    TimeDistributed,\n)\n\nfrom . 
import apply\n\n# pylint: enable=import-error\n\n__email__ = \"spleeter@deezer.com\"\n__author__ = \"Deezer Research\"\n__license__ = \"MIT License\"\n\n\ndef apply_blstm(\n    input_tensor: tf.Tensor, output_name: str = \"output\", params: Optional[Dict] = None\n) -> tf.Tensor:\n    \"\"\"\n    Apply BLSTM to the given input_tensor.\n\n    Parameters:\n        input_tensor (tf.Tensor):\n            Input of the model.\n        output_name (str):\n            (Optional) name of the output, default to 'output'.\n        params (Optional[Dict]):\n            (Optional) dict of BLSTM parameters.\n\n    Returns:\n        tf.Tensor:\n            Output tensor.\n    \"\"\"\n    if params is None:\n        params = {}\n    units: int = params.get(\"lstm_units\", 250)\n    kernel_initializer = he_uniform(seed=50)\n    flatten_input = TimeDistributed(Flatten())((input_tensor))\n\n    def create_bidirectional():\n        return Bidirectional(\n            CuDNNLSTM(\n                units, kernel_initializer=kernel_initializer, return_sequences=True\n            )\n        )\n\n    l1 = create_bidirectional()((flatten_input))\n    l2 = create_bidirectional()((l1))\n    l3 = create_bidirectional()((l2))\n    dense = TimeDistributed(\n        Dense(\n            int(flatten_input.shape[2]),\n            activation=\"relu\",\n            kernel_initializer=kernel_initializer,\n        )\n    )((l3))\n    output: tf.Tensor = TimeDistributed(\n        Reshape(input_tensor.shape[2:]), name=output_name\n    )(dense)\n    return output\n\n\ndef blstm(\n    input_tensor: tf.Tensor, output_name: str = \"output\", params: Optional[Dict] = None\n) -> tf.Tensor:\n    \"\"\"Model function applier.\"\"\"\n    return apply(apply_blstm, input_tensor, output_name, params)\n"
  },
  {
    "path": "spleeter/model/functions/unet.py",
    "content": "#!/usr/bin/env python\n# coding: utf8\n\n\"\"\"\nThis module contains building functions for U-net source\nseparation models in a similar way as in A. Jansson et al. :\n\n\"Singing voice separation with deep u-net convolutional networks\",\nISMIR 2017\n\nEach instrument is modeled by a single U-net\nconvolutional / deconvolutional network that take a mix spectrogram\nas input and the estimated sound spectrogram as output.\n\"\"\"\n\nfrom functools import partial\nfrom typing import Any, Dict, Iterable, Optional\n\n# pyright: reportMissingImports=false\n# pylint: disable=import-error\nimport tensorflow as tf  # type: ignore\nfrom tensorflow.compat.v1 import logging  # type: ignore\nfrom tensorflow.compat.v1.keras.initializers import he_uniform  # type: ignore\nfrom tensorflow.keras.layers import (  # type: ignore\n    ELU,\n    BatchNormalization,\n    Concatenate,\n    Conv2D,\n    Conv2DTranspose,\n    Dropout,\n    LeakyReLU,\n    Multiply,\n    ReLU,\n    Softmax,\n)\n\nfrom . 
import apply\n\n# pylint: enable=import-error\n\n__email__ = \"spleeter@deezer.com\"\n__author__ = \"Deezer Research\"\n__license__ = \"MIT License\"\n\n\ndef _get_conv_activation_layer(params: Dict) -> Any:\n    \"\"\"\n    Parameters:\n        params (Dict):\n            Model parameters.\n\n    Returns:\n        Any:\n            Required Activation function.\n    \"\"\"\n    conv_activation: str = str(params.get(\"conv_activation\"))\n    if conv_activation == \"ReLU\":\n        return ReLU()\n    elif conv_activation == \"ELU\":\n        return ELU()\n    return LeakyReLU(0.2)\n\n\ndef _get_deconv_activation_layer(params: Dict) -> Any:\n    \"\"\"\n    Parameters:\n        params (Dict):\n            Model parameters.\n\n    Returns:\n        Any:\n            Required Activation function.\n    \"\"\"\n    deconv_activation: str = str(params.get(\"deconv_activation\"))\n    if deconv_activation == \"LeakyReLU\":\n        return LeakyReLU(0.2)\n    elif deconv_activation == \"ELU\":\n        return ELU()\n    return ReLU()\n\n\ndef apply_unet(\n    input_tensor: tf.Tensor,\n    output_name: str = \"output\",\n    params: Dict = {},\n    output_mask_logit: bool = False,\n) -> tf.Tensor:\n    \"\"\"\n    Apply a convolutionnal U-net to model a single instrument (one U-net\n    is used for each instrument).\n\n    Parameters:\n        input_tensor (tf.Tensor):\n            Input of the model.\n        output_name (str):\n            (Optional) name of the output, default to 'output'.\n        params (Dict):\n            (Optional) dict of BLSTM parameters.\n        output_mask_logit (bool):\n            (Optional) Sigmoid or logit?\n\n    Returns:\n        tf.Tensor:\n            Output tensor.\n    \"\"\"\n    logging.info(f\"Apply unet for {output_name}\")\n    conv_n_filters = params.get(\"conv_n_filters\", [16, 32, 64, 128, 256, 512])\n    conv_activation_layer = _get_conv_activation_layer(params)\n    deconv_activation_layer = 
_get_deconv_activation_layer(params)\n    kernel_initializer = he_uniform(seed=50)\n    conv2d_factory = partial(\n        Conv2D, strides=(2, 2), padding=\"same\", kernel_initializer=kernel_initializer\n    )\n    # First layer.\n    conv1 = conv2d_factory(conv_n_filters[0], (5, 5))(input_tensor)\n    batch1 = BatchNormalization(axis=-1)(conv1)\n    rel1 = conv_activation_layer(batch1)\n    # Second layer.\n    conv2 = conv2d_factory(conv_n_filters[1], (5, 5))(rel1)\n    batch2 = BatchNormalization(axis=-1)(conv2)\n    rel2 = conv_activation_layer(batch2)\n    # Third layer.\n    conv3 = conv2d_factory(conv_n_filters[2], (5, 5))(rel2)\n    batch3 = BatchNormalization(axis=-1)(conv3)\n    rel3 = conv_activation_layer(batch3)\n    # Fourth layer.\n    conv4 = conv2d_factory(conv_n_filters[3], (5, 5))(rel3)\n    batch4 = BatchNormalization(axis=-1)(conv4)\n    rel4 = conv_activation_layer(batch4)\n    # Fifth layer.\n    conv5 = conv2d_factory(conv_n_filters[4], (5, 5))(rel4)\n    batch5 = BatchNormalization(axis=-1)(conv5)\n    rel5 = conv_activation_layer(batch5)\n    # Sixth layer\n    conv6 = conv2d_factory(conv_n_filters[5], (5, 5))(rel5)\n    batch6 = BatchNormalization(axis=-1)(conv6)\n    _ = conv_activation_layer(batch6)\n    #\n    #\n    conv2d_transpose_factory = partial(\n        Conv2DTranspose,\n        strides=(2, 2),\n        padding=\"same\",\n        kernel_initializer=kernel_initializer,\n    )\n    #\n    up1 = conv2d_transpose_factory(conv_n_filters[4], (5, 5))((conv6))\n    up1 = deconv_activation_layer(up1)\n    batch7 = BatchNormalization(axis=-1)(up1)\n    drop1 = Dropout(0.5)(batch7)\n    merge1 = Concatenate(axis=-1)([conv5, drop1])\n    #\n    up2 = conv2d_transpose_factory(conv_n_filters[3], (5, 5))((merge1))\n    up2 = deconv_activation_layer(up2)\n    batch8 = BatchNormalization(axis=-1)(up2)\n    drop2 = Dropout(0.5)(batch8)\n    merge2 = Concatenate(axis=-1)([conv4, drop2])\n    #\n    up3 = 
conv2d_transpose_factory(conv_n_filters[2], (5, 5))((merge2))\n    up3 = deconv_activation_layer(up3)\n    batch9 = BatchNormalization(axis=-1)(up3)\n    drop3 = Dropout(0.5)(batch9)\n    merge3 = Concatenate(axis=-1)([conv3, drop3])\n    #\n    up4 = conv2d_transpose_factory(conv_n_filters[1], (5, 5))((merge3))\n    up4 = deconv_activation_layer(up4)\n    batch10 = BatchNormalization(axis=-1)(up4)\n    merge4 = Concatenate(axis=-1)([conv2, batch10])\n    #\n    up5 = conv2d_transpose_factory(conv_n_filters[0], (5, 5))((merge4))\n    up5 = deconv_activation_layer(up5)\n    batch11 = BatchNormalization(axis=-1)(up5)\n    merge5 = Concatenate(axis=-1)([conv1, batch11])\n    #\n    up6 = conv2d_transpose_factory(1, (5, 5), strides=(2, 2))((merge5))\n    up6 = deconv_activation_layer(up6)\n    batch12 = BatchNormalization(axis=-1)(up6)\n    # Last layer to ensure initial shape reconstruction.\n    if not output_mask_logit:\n        up7 = Conv2D(\n            2,\n            (4, 4),\n            dilation_rate=(2, 2),\n            activation=\"sigmoid\",\n            padding=\"same\",\n            kernel_initializer=kernel_initializer,\n        )((batch12))\n        output = Multiply(name=output_name)([up7, input_tensor])\n        return output\n    return Conv2D(\n        2,\n        (4, 4),\n        dilation_rate=(2, 2),\n        padding=\"same\",\n        kernel_initializer=kernel_initializer,\n    )((batch12))\n\n\ndef unet(\n    input_tensor: tf.Tensor, instruments: Iterable[str], params: Optional[Dict] = None\n) -> Dict:\n    \"\"\"Model function applier.\"\"\"\n    return apply(apply_unet, input_tensor, instruments, params)\n\n\ndef softmax_unet(\n    input_tensor: tf.Tensor, instruments: Iterable[str], params: Dict = {}\n) -> Dict:\n    \"\"\"\n    Apply softmax to multitrack unet in order to have mask suming to one.\n\n    Parameters:\n        input_tensor (tf.Tensor):\n            Tensor to apply blstm to.\n        instruments (Iterable[str]):\n            
Iterable that provides a collection of instruments.\n        params (Dict):\n            (Optional) dict of U-net parameters.\n\n    Returns:\n        Dict:\n            Created output tensor dict.\n    \"\"\"\n    logit_mask_list = []\n    for instrument in instruments:\n        out_name = f\"{instrument}_spectrogram\"\n        logit_mask_list.append(\n            apply_unet(\n                input_tensor,\n                output_name=out_name,\n                params=params,\n                output_mask_logit=True,\n            )\n        )\n    masks = Softmax(axis=4)(tf.stack(logit_mask_list, axis=4))\n    output_dict = {}\n    for i, instrument in enumerate(instruments):\n        out_name = f\"{instrument}_spectrogram\"\n        output_dict[out_name] = Multiply(name=out_name)([masks[..., i], input_tensor])\n    return output_dict\n"
  },
  {
    "path": "spleeter/model/provider/__init__.py",
    "content": "#!/usr/bin/env python\n# coding: utf8\n\n\"\"\"\nThis package provides tools for downloading models from the network\nusing a remote storage abstraction.\n\nExample:\n```python\n>>> provider = MyProviderImplementation()\n>>> provider.get('/path/to/local/storage')\n```\n\"\"\"\n\nfrom abc import ABC, abstractmethod\nfrom os import environ, makedirs\nfrom os.path import exists, isabs, join, sep\n\n__email__ = \"spleeter@deezer.com\"\n__author__ = \"Deezer Research\"\n__license__ = \"MIT License\"\n\n\nclass ModelProvider(ABC):\n    \"\"\"\n    A ModelProvider manages model files on disk and\n    downloads them if they are not available.\n    \"\"\"\n\n    DEFAULT_MODEL_PATH: str = environ.get(\"MODEL_PATH\", \"pretrained_models\")\n    MODEL_PROBE_PATH: str = \".probe\"\n\n    @abstractmethod\n    def download(_, name: str, path: str) -> None:\n        \"\"\"\n        Download model denoted by the given name to disk.\n\n        Parameters:\n            name (str):\n                Name of the model to download.\n            path (str):\n                Path of the directory to save model into.\n        \"\"\"\n        pass\n\n    @staticmethod\n    def writeProbe(directory: str) -> None:\n        \"\"\"\n        Write a model probe file into the given directory.\n\n        Parameters:\n            directory (str):\n                Directory to write probe into.\n        \"\"\"\n        probe: str = join(directory, ModelProvider.MODEL_PROBE_PATH)\n        with open(probe, \"w\") as stream:\n            stream.write(\"OK\")\n\n    def get(self, model_directory: str) -> str:\n        \"\"\"\n        Ensures the required model is available at the given location.\n\n        Parameters:\n            model_directory (str):\n                Expected model_directory to be available.\n\n        Raises:\n            IOError:\n                If model can not be retrieved.\n\n        Returns:\n            str:\n                Available model directory.\n        \"\"\"\n        # 
Expand model directory if needed.\n        if not isabs(model_directory):\n            model_directory = join(self.DEFAULT_MODEL_PATH, model_directory)\n        # Download it if it does not exist yet.\n        model_probe: str = join(model_directory, self.MODEL_PROBE_PATH)\n        if not exists(model_probe):\n            if not exists(model_directory):\n                makedirs(model_directory)\n                self.download(model_directory.split(sep)[-1], model_directory)\n                self.writeProbe(model_directory)\n        return model_directory\n\n    @classmethod\n    def default(_: type) -> \"ModelProvider\":\n        \"\"\"\n        Builds and returns a default model provider.\n\n        Returns:\n            ModelProvider:\n                A default model provider instance to use.\n        \"\"\"\n        from .github import GithubModelProvider\n\n        return GithubModelProvider.from_environ()\n"
  },
  {
    "path": "spleeter/model/provider/github.py",
    "content": "#!/usr/bin/env python\n# coding: utf8\n\n\"\"\"\nA ModelProvider backed by Github Release feature.\n\nExamples:\n\n```python\n>>> from spleeter.model.provider import github\n>>> provider = github.GithubModelProvider(\n        'github.com',\n        'Deezer/spleeter',\n        'latest')\n>>> provider.download('2stems', '/path/to/local/storage')\n```\n\"\"\"\n\nimport hashlib\nimport os\nimport tarfile\nfrom os import environ\nfrom tempfile import NamedTemporaryFile\nfrom typing import Dict\n\n# pyright: reportMissingImports=false\n# pylint: disable=import-error\nimport httpx\n\nfrom ...utils.logging import logger\nfrom . import ModelProvider\n\n# pylint: enable=import-error\n\n__email__ = \"spleeter@deezer.com\"\n__author__ = \"Deezer Research\"\n__license__ = \"MIT License\"\n\n\ndef compute_file_checksum(path):\n    \"\"\"\n    Computes given path file sha256.\n\n    Parameters:\n        path (str):\n            Path of the file to compute checksum for.\n\n    Returns:\n        str:\n            File checksum.\n    \"\"\"\n    sha256 = hashlib.sha256()\n    with open(path, \"rb\") as stream:\n        for chunk in iter(lambda: stream.read(4096), b\"\"):\n            sha256.update(chunk)\n    return sha256.hexdigest()\n\n\nclass GithubModelProvider(ModelProvider):\n    \"\"\"A ModelProvider implementation backed on Github for remote storage.\"\"\"\n\n    DEFAULT_HOST: str = \"https://github.com\"\n    DEFAULT_REPOSITORY: str = \"deezer/spleeter\"\n\n    CHECKSUM_INDEX: str = \"checksum.json\"\n    LATEST_RELEASE: str = \"v1.4.0\"\n    RELEASE_PATH: str = \"releases/download\"\n\n    def __init__(self, host: str, repository: str, release: str) -> None:\n        \"\"\"Default constructor.\n\n        Parameters:\n            host (str):\n                Host to the Github instance to reach.\n            repository (str):\n                Repository path within target Github.\n            release (str):\n                Release name to get models from.\n 
       \"\"\"\n        self._host: str = host\n        self._repository: str = repository\n        self._release: str = release\n\n    @classmethod\n    def from_environ(cls) -> \"GithubModelProvider\":\n        \"\"\"\n        Factory method that creates provider from envvars.\n\n        Returns:\n            GithubModelProvider:\n                Created instance.\n        \"\"\"\n        return cls(\n            environ.get(\"GITHUB_HOST\", cls.DEFAULT_HOST),\n            environ.get(\"GITHUB_REPOSITORY\", cls.DEFAULT_REPOSITORY),\n            environ.get(\"GITHUB_RELEASE\", cls.LATEST_RELEASE),\n        )\n\n    def checksum(self, name: str) -> str:\n        \"\"\"\n        Downloads and returns reference checksum for the given model name.\n\n        Parameters:\n            name (str):\n                Name of the model to get checksum for.\n\n        Returns:\n            str:\n                Checksum of the required model.\n\n        Raises:\n            ValueError:\n                If the given model name is not indexed.\n        \"\"\"\n        url: str = \"/\".join(\n            (\n                self._host,\n                self._repository,\n                self.RELEASE_PATH,\n                self._release,\n                self.CHECKSUM_INDEX,\n            )\n        )\n        response: httpx.Response = httpx.get(url)\n        response.raise_for_status()\n        index: Dict = response.json()\n        if name not in index:\n            raise ValueError(f\"No checksum for model {name}\")\n        return index[name]\n\n    def download(self, name: str, path: str) -> None:\n        \"\"\"\n        Download model denoted by the given name to disk.\n\n        Parameters:\n            name (str):\n                Name of the model to download.\n            path (str):\n                Path of the directory to save model into.\n        \"\"\"\n        url: str = \"/\".join(\n            (self._host, self._repository, self.RELEASE_PATH, self._release, 
name)\n        )\n        url = f\"{url}.tar.gz\"\n        logger.info(f\"Downloading model archive {url}\")\n        with httpx.Client(http2=True) as client:\n            with client.stream(\"GET\", url) as response:\n                response.raise_for_status()\n                archive = NamedTemporaryFile(delete=False)\n                try:\n                    with archive as stream:\n                        for chunk in response.iter_raw():\n                            stream.write(chunk)\n                    logger.info(\"Validating archive checksum\")\n                    checksum: str = compute_file_checksum(archive.name)\n                    if checksum != self.checksum(name):\n                        raise IOError(\"Downloaded file is corrupted, please retry\")\n                    logger.info(f\"Extracting downloaded {name} archive\")\n                    with tarfile.open(name=archive.name) as tar:\n                        tar.extractall(path=path)\n                finally:\n                    os.unlink(archive.name)\n        logger.info(f\"{name} model file(s) extracted\")\n"
  },
  {
    "path": "spleeter/options.py",
    "content": "#!/usr/bin/env python\n# coding: utf8\n\n\"\"\"This modules provides spleeter command as well as CLI parsing methods.\"\"\"\n\nfrom os.path import join\nfrom tempfile import gettempdir\n\nfrom typer import Argument, Exit, Option, echo\nfrom typer.models import List, Optional\n\nfrom .audio import Codec\n\n__email__ = \"spleeter@deezer.com\"\n__author__ = \"Deezer Research\"\n__license__ = \"MIT License\"\n\nAudioInputArgument: List[str] = Argument(\n    ...,\n    help=\"List of input audio file path\",\n    exists=True,\n    file_okay=True,\n    dir_okay=False,\n    readable=True,\n    resolve_path=True,\n)\n\nAudioInputOption: Optional[str] = Option(\n    None, \"--inputs\", \"-i\", help=\"(DEPRECATED) placeholder for deprecated input option\"\n)\n\nAudioAdapterOption: str = Option(\n    \"spleeter.audio.ffmpeg.FFMPEGProcessAudioAdapter\",\n    \"--adapter\",\n    \"-a\",\n    help=\"Name of the audio adapter to use for audio I/O\",\n)\n\nAudioOutputOption: str = Option(\n    join(gettempdir(), \"separated_audio\"),\n    \"--output_path\",\n    \"-o\",\n    help=\"Path of the output directory to write audio files in\",\n)\n\nAudioOffsetOption: float = Option(\n    0.0, \"--offset\", \"-s\", help=\"Set the starting offset to separate audio from\"\n)\n\nAudioDurationOption: float = Option(\n    600.0,\n    \"--duration\",\n    \"-d\",\n    help=(\n        \"Set a maximum duration for processing audio \"\n        \"(only separate offset + duration first seconds of \"\n        \"the input file)\"\n    ),\n)\n\nAudioCodecOption: Codec = Option(\n    Codec.WAV, \"--codec\", \"-c\", help=\"Audio codec to be used for the separated output\"\n)\n\nAudioBitrateOption: str = Option(\n    \"128k\", \"--bitrate\", \"-b\", help=\"Audio bitrate to be used for the separated output\"\n)\n\nFilenameFormatOption: str = Option(\n    \"{filename}/{instrument}.{codec}\",\n    \"--filename_format\",\n    \"-f\",\n    help=(\n        \"Template string that will be 
formatted to generate \"\n        \"output filename. Such template should be Python formattable \"\n        \"string, and could use {filename}, {instrument}, and {codec} \"\n        \"variables\"\n    ),\n)\n\nModelParametersOption: str = Option(\n    \"spleeter:2stems\",\n    \"--params_filename\",\n    \"-p\",\n    help=\"JSON filename that contains params\",\n)\n\n\nMWFOption: bool = Option(\n    False, \"--mwf\", help=\"Whether to use multichannel Wiener filtering for separation\"\n)\n\nMUSDBDirectoryOption: str = Option(\n    ...,\n    \"--mus_dir\",\n    exists=True,\n    dir_okay=True,\n    file_okay=False,\n    readable=True,\n    resolve_path=True,\n    help=\"Path to musDB dataset directory\",\n)\n\nTrainingDataDirectoryOption: str = Option(\n    ...,\n    \"--data\",\n    \"-d\",\n    exists=True,\n    dir_okay=True,\n    file_okay=False,\n    readable=True,\n    resolve_path=True,\n    help=\"Path of the folder containing audio data for training\",\n)\n\nVerboseOption: bool = Option(False, \"--verbose\", help=\"Enable verbose logs\")\n\n\ndef version_callback(value: bool):\n    if value:\n        from importlib.metadata import version\n\n        echo(f\"Spleeter Version: {version('spleeter')}\")\n        raise Exit()\n\n\nVersionOption: bool = Option(\n    None,\n    \"--version\",\n    callback=version_callback,\n    is_eager=True,\n    help=\"Return Spleeter version\",\n)\n"
  },
  {
    "path": "spleeter/py.typed",
    "content": ""
  },
  {
    "path": "spleeter/resources/2stems-16kHz.json",
    "content": "{\n    \"train_csv\": \"path/to/train.csv\",\n    \"validation_csv\": \"path/to/test.csv\",\n    \"model_dir\": \"2stems\",\n    \"mix_name\": \"mix\",\n    \"instrument_list\": [\"vocals\", \"accompaniment\"],\n    \"sample_rate\":44100,\n    \"frame_length\":4096,\n    \"frame_step\":1024,\n    \"T\":512,\n    \"F\":1536,\n    \"n_channels\":2,\n    \"separation_exponent\":2,\n    \"mask_extension\":\"zeros\",\n    \"learning_rate\": 1e-4,\n    \"batch_size\":4,\n    \"training_cache\":\"training_cache\",\n    \"validation_cache\":\"validation_cache\",\n    \"train_max_steps\": 1000000,\n    \"throttle_secs\":300,\n    \"random_seed\":0,\n    \"save_checkpoints_steps\":150,\n    \"save_summary_steps\":5,\n    \"model\":{\n            \"type\":\"unet.unet\",\n            \"params\":{}\n            }\n}\n"
  },
  {
    "path": "spleeter/resources/2stems.json",
    "content": "{\n    \"train_csv\": \"path/to/train.csv\",\n    \"validation_csv\": \"path/to/test.csv\",\n    \"model_dir\": \"2stems\",\n    \"mix_name\": \"mix\",\n    \"instrument_list\": [\"vocals\", \"accompaniment\"],\n    \"sample_rate\":44100,\n    \"frame_length\":4096,\n    \"frame_step\":1024,\n    \"T\":512,\n    \"F\":1024,\n    \"n_channels\":2,\n    \"separation_exponent\":2,\n    \"mask_extension\":\"zeros\",\n    \"learning_rate\": 1e-4,\n    \"batch_size\":4,\n    \"training_cache\":\"training_cache\",\n    \"validation_cache\":\"validation_cache\",\n    \"train_max_steps\": 1000000,\n    \"throttle_secs\":300,\n    \"random_seed\":0,\n    \"save_checkpoints_steps\":150,\n    \"save_summary_steps\":5,\n    \"model\":{\n            \"type\":\"unet.unet\",\n            \"params\":{}\n            }\n}\n"
  },
  {
    "path": "spleeter/resources/4stems-16kHz.json",
    "content": "{\n    \"train_csv\": \"path/to/train.csv\",\n    \"validation_csv\": \"path/to/val.csv\",\n    \"model_dir\": \"4stems\",\n    \"mix_name\": \"mix\",\n    \"instrument_list\": [\"vocals\", \"drums\", \"bass\", \"other\"],\n    \"sample_rate\":44100,\n    \"frame_length\":4096,\n    \"frame_step\":1024,\n    \"T\":512,\n    \"F\":1536,\n    \"n_channels\":2,\n    \"separation_exponent\":2,\n    \"mask_extension\":\"zeros\",\n    \"learning_rate\": 1e-4,\n    \"batch_size\":4,\n    \"training_cache\":\"training_cache\",\n    \"validation_cache\":\"validation_cache\",\n    \"train_max_steps\": 1500000,\n    \"throttle_secs\":600,\n    \"random_seed\":3,\n    \"save_checkpoints_steps\":300,\n    \"save_summary_steps\":5,\n    \"model\":{\n        \"type\":\"unet.unet\",\n        \"params\":{\n               \"conv_activation\":\"ELU\",\n               \"deconv_activation\":\"ELU\"\n        }\n    }\n}\n"
  },
  {
    "path": "spleeter/resources/4stems.json",
    "content": "{\n    \"train_csv\": \"path/to/train.csv\",\n    \"validation_csv\": \"path/to/val.csv\",\n    \"model_dir\": \"4stems\",\n    \"mix_name\": \"mix\",\n    \"instrument_list\": [\"vocals\", \"drums\", \"bass\", \"other\"],\n    \"sample_rate\":44100,\n    \"frame_length\":4096,\n    \"frame_step\":1024,\n    \"T\":512,\n    \"F\":1024,\n    \"n_channels\":2,\n    \"separation_exponent\":2,\n    \"mask_extension\":\"zeros\",\n    \"learning_rate\": 1e-4,\n    \"batch_size\":4,\n    \"training_cache\":\"training_cache\",\n    \"validation_cache\":\"validation_cache\",\n    \"train_max_steps\": 1500000,\n    \"throttle_secs\":600,\n    \"random_seed\":3,\n    \"save_checkpoints_steps\":300,\n    \"save_summary_steps\":5,\n    \"model\":{\n        \"type\":\"unet.unet\",\n        \"params\":{\n               \"conv_activation\":\"ELU\",\n               \"deconv_activation\":\"ELU\"\n        }\n    }\n}\n"
  },
  {
    "path": "spleeter/resources/5stems-16kHz.json",
    "content": "{\n    \"train_csv\": \"path/to/train.csv\",\n    \"validation_csv\": \"path/to/test.csv\",\n    \"model_dir\": \"5stems\",\n    \"mix_name\": \"mix\",\n    \"instrument_list\": [\"vocals\", \"piano\", \"drums\", \"bass\", \"other\"],\n    \"sample_rate\":44100,\n    \"frame_length\":4096,\n    \"frame_step\":1024,\n    \"T\":512,\n    \"F\":1536,\n    \"n_channels\":2,\n    \"separation_exponent\":2,\n    \"mask_extension\":\"zeros\",\n    \"learning_rate\": 1e-4,\n    \"batch_size\":4,\n    \"training_cache\":\"training_cache\",\n    \"validation_cache\":\"validation_cache\",\n    \"train_max_steps\": 2500000,\n    \"throttle_secs\":600,\n    \"random_seed\":8,\n    \"save_checkpoints_steps\":300,\n    \"save_summary_steps\":5,\n    \"model\":{\n        \"type\":\"unet.softmax_unet\",\n        \"params\":{\n               \"conv_activation\":\"ELU\",\n               \"deconv_activation\":\"ELU\"\n        }\n    }\n}\n"
  },
  {
    "path": "spleeter/resources/5stems.json",
    "content": "{\n    \"train_csv\": \"path/to/train.csv\",\n    \"validation_csv\": \"path/to/test.csv\",\n    \"model_dir\": \"5stems\",\n    \"mix_name\": \"mix\",\n    \"instrument_list\": [\"vocals\", \"piano\", \"drums\", \"bass\", \"other\"],\n    \"sample_rate\":44100,\n    \"frame_length\":4096,\n    \"frame_step\":1024,\n    \"T\":512,\n    \"F\":1024,\n    \"n_channels\":2,\n    \"separation_exponent\":2,\n    \"mask_extension\":\"zeros\",\n    \"learning_rate\": 1e-4,\n    \"batch_size\":4,\n    \"training_cache\":\"training_cache\",\n    \"validation_cache\":\"validation_cache\",\n    \"train_max_steps\": 2500000,\n    \"throttle_secs\":600,\n    \"random_seed\":8,\n    \"save_checkpoints_steps\":300,\n    \"save_summary_steps\":5,\n    \"model\":{\n        \"type\":\"unet.softmax_unet\",\n        \"params\":{\n               \"conv_activation\":\"ELU\",\n               \"deconv_activation\":\"ELU\"\n        }\n    }\n}\n"
  },
  {
    "path": "spleeter/resources/__init__.py",
    "content": "#!/usr/bin/env python\n# coding: utf8\n\n\"\"\" Package that provides static resource files for the library. \"\"\"\n\n__email__ = \"spleeter@deezer.com\"\n__author__ = \"Deezer Research\"\n__license__ = \"MIT License\"\n"
  },
  {
    "path": "spleeter/resources/musdb.json",
    "content": "{\n    \"train_csv\": \"configs/musdb_train.csv\",\n    \"validation_csv\": \"configs/musdb_validation.csv\",\n    \"model_dir\": \"musdb_model\",\n    \"mix_name\": \"mix\",\n    \"instrument_list\": [\"vocals\", \"drums\", \"bass\", \"other\"],\n    \"sample_rate\":44100,\n    \"frame_length\":4096,\n    \"frame_step\":1024,\n    \"T\":512,\n    \"F\":1024,\n    \"n_channels\":2,\n    \"n_chunks_per_song\":1,\n    \"separation_exponent\":2,\n    \"mask_extension\":\"zeros\",\n    \"learning_rate\": 1e-4,\n    \"batch_size\":4,\n    \"training_cache\":\"training_cache\",\n    \"validation_cache\":\"validation_cache\",\n    \"train_max_steps\": 100000,\n    \"throttle_secs\":600,\n    \"random_seed\":3,\n    \"save_checkpoints_steps\":300,\n    \"save_summary_steps\":5,\n    \"model\":{\n        \"type\":\"unet.unet\",\n        \"params\":{\n               \"conv_activation\":\"ELU\",\n               \"deconv_activation\":\"ELU\"\n        }\n    }\n}\n"
  },
  {
    "path": "spleeter/separator.py",
    "content": "#!/usr/bin/env python\n# coding: utf8\n\n\"\"\"\nModule that provides a class wrapper for source separation.\n\nExamples:\n\n```python\n>>> from spleeter.separator import Separator\n>>> separator = Separator('spleeter:2stems')\n>>> separator.separate(waveform, lambda instrument, data: ...)\n>>> separator.separate_to_file(...)\n```\n\"\"\"\n\nimport atexit\nimport os\nfrom multiprocessing import Pool\nfrom os.path import basename, dirname, join, splitext\nfrom typing import Any, Dict, Generator, List, Optional\n\n# pyright: reportMissingImports=false\n# pylint: disable=import-error\nimport numpy as np\nimport tensorflow as tf  # type: ignore\n\nfrom . import SpleeterError\nfrom .audio import Codec\nfrom .audio.adapter import AudioAdapter\nfrom .audio.convertor import to_stereo\nfrom .model import EstimatorSpecBuilder, InputProviderFactory, model_fn\nfrom .model.provider import ModelProvider\nfrom .types import AudioDescriptor\nfrom .utils.configuration import load_configuration\n\n# pylint: enable=import-error\n\n__email__ = \"spleeter@deezer.com\"\n__author__ = \"Deezer Research\"\n__license__ = \"MIT License\"\n\n\ndef create_estimator(params: Dict, MWF: bool) -> tf.Tensor:\n    \"\"\"\n    Initialize tensorflow estimator that will perform separation\n\n    Parameters:\n        params (Dict):\n            A dictionary of parameters for building the model\n        MWF (bool):\n            Wiener filter enabled?\n\n    Returns:\n        tf.Tensor:\n            A tensorflow estimator\n    \"\"\"\n    # Load model.\n    provider: ModelProvider = ModelProvider.default()\n    params[\"model_dir\"] = provider.get(params[\"model_dir\"])\n    params[\"MWF\"] = MWF\n    # Setup config\n    session_config = tf.compat.v1.ConfigProto()\n    session_config.gpu_options.per_process_gpu_memory_fraction = 0.7\n    config = tf.estimator.RunConfig(session_config=session_config)\n    # Setup estimator\n    estimator = tf.estimator.Estimator(\n        model_fn=model_fn, 
model_dir=params[\"model_dir\"], params=params, config=config\n    )\n    return estimator\n\n\nclass Separator(object):\n    \"\"\"A wrapper class for performing separation.\"\"\"\n\n    def __init__(\n        self,\n        params_descriptor: str,\n        MWF: bool = False,\n        multiprocess: bool = True,\n    ) -> None:\n        \"\"\"\n        Default constructor.\n\n        Parameters:\n            params_descriptor (str):\n                Descriptor for TF params to be used.\n            MWF (bool):\n                (Optional) `True` if MWF should be used, `False` otherwise.\n            multiprocess (bool):\n                (Optional) Enable multi-processing.\n        \"\"\"\n        self._params = load_configuration(params_descriptor)\n        self._sample_rate = self._params[\"sample_rate\"]\n        self._MWF = MWF\n        self._tf_graph = tf.Graph()\n        self._prediction_generator: Optional[Generator] = None\n        self._input_provider = None\n        self._builder = None\n        self._features = None\n        self._session = None\n        if multiprocess:\n            self._pool: Optional[Any] = Pool()\n            atexit.register(self._pool.close)\n        else:\n            self._pool = None\n        self._tasks: List = []\n        self.estimator = None\n\n    def _get_prediction_generator(self, data: dict) -> Generator:\n        \"\"\"\n        Lazy loading access method for internal prediction generator\n        returned by the predict method of a tensorflow estimator.\n\n        Returns:\n            Generator:\n                Generator of prediction.\n        \"\"\"\n        if not self.estimator:\n            self.estimator = create_estimator(self._params, self._MWF)\n\n        def get_dataset():\n            return tf.data.Dataset.from_tensors(data)\n\n        return self.estimator.predict(get_dataset, yield_single_examples=False)\n\n    def join(self, timeout: int = 200) -> None:\n        \"\"\"\n        Wait for all pending tasks 
to be finished.\n\n        Parameters:\n            timeout (int):\n                (Optional) Per-task waiting timeout, in seconds.\n        \"\"\"\n        while len(self._tasks) > 0:\n            task = self._tasks.pop()\n            # Wait at most `timeout` seconds; re-raises worker errors.\n            task.get(timeout=timeout)\n\n    def _get_input_provider(self):\n        if self._input_provider is None:\n            self._input_provider = InputProviderFactory.get(self._params)\n        return self._input_provider\n\n    def _get_features(self):\n        if self._features is None:\n            provider = self._get_input_provider()\n            self._features = provider.get_input_dict_placeholders()\n        return self._features\n\n    def _get_builder(self):\n        if self._builder is None:\n            self._builder = EstimatorSpecBuilder(self._get_features(), self._params)\n        return self._builder\n\n    def _get_session(self):\n        if self._session is None:\n            saver = tf.compat.v1.train.Saver()\n            provider = ModelProvider.default()\n            model_directory: str = provider.get(self._params[\"model_dir\"])\n            latest_checkpoint = tf.train.latest_checkpoint(model_directory)\n            self._session = tf.compat.v1.Session()\n            saver.restore(self._session, latest_checkpoint)\n        return self._session\n\n    def _separate_tensorflow(\n        self, waveform: np.ndarray, audio_descriptor: AudioDescriptor\n    ) -> Dict:\n        \"\"\"\n        Performs source separation over the given waveform with tensorflow\n        backend.\n\n        Parameters:\n            waveform (np.ndarray):\n                Waveform to be separated (as a numpy array)\n            audio_descriptor (AudioDescriptor):\n                Audio descriptor to be used.\n\n        Returns:\n            Dict:\n                Separated waveforms.\n        \"\"\"\n        if waveform.shape[-1] != 2:\n            waveform = to_stereo(waveform)\n        prediction_generator = 
self._get_prediction_generator(\n            {\"waveform\": waveform, \"audio_id\": np.array(audio_descriptor)}\n        )\n        # NOTE: perform separation.\n        prediction = next(prediction_generator)\n        prediction.pop(\"audio_id\")\n        return prediction\n\n    def separate(\n        self, waveform: np.ndarray, audio_descriptor: Optional[str] = \"\"\n    ) -> Dict:\n        \"\"\"\n        Performs separation on a waveform.\n\n        Parameters:\n            waveform (np.ndarray):\n                Waveform to be separated (as a numpy array)\n            audio_descriptor (Optional[str]):\n                (Optional) string describing the waveform (e.g. filename).\n\n        Returns:\n            Dict:\n                Separated waveforms.\n        \"\"\"\n        return self._separate_tensorflow(waveform, audio_descriptor)\n\n    def separate_to_file(\n        self,\n        audio_descriptor: AudioDescriptor,\n        destination: str,\n        audio_adapter: Optional[AudioAdapter] = None,\n        offset: float = 0,\n        duration: float = 600.0,\n        codec: Codec = Codec.WAV,\n        bitrate: str = \"128k\",\n        filename_format: str = \"{filename}/{instrument}.{codec}\",\n        synchronous: bool = True,\n    ) -> None:\n        \"\"\"\n        Performs source separation and export result to file using\n        given audio adapter.\n\n        Filename format should be a Python formattable string that could\n        use following parameters :\n\n        - {instrument}\n        - {filename}\n        - {foldername}\n        - {codec}.\n\n        Parameters:\n            audio_descriptor (AudioDescriptor):\n                Describe song to separate, used by audio adapter to\n                retrieve and load audio data, in case of file based\n                audio adapter, such descriptor would be a file path.\n            destination (str):\n                Target directory to write output to.\n            audio_adapter 
(AudioAdapter):\n                (Optional) Audio adapter to use for I/O.\n            offset (float):\n                (Optional) Offset of loaded song.\n            duration (float):\n                (Optional) Duration of loaded song (default: 600s).\n            codec (Codec):\n                (Optional) Export codec.\n            bitrate (str):\n                (Optional) Export bitrate.\n            filename_format (str):\n                (Optional) Filename format.\n            synchronous (bool):\n                (Optional) `True` if export should be synchronous.\n        \"\"\"\n        if audio_adapter is None:\n            audio_adapter = AudioAdapter.default()\n        waveform, _ = audio_adapter.load(\n            audio_descriptor,\n            offset=offset,\n            duration=duration,\n            sample_rate=self._sample_rate,\n        )\n        sources = self.separate(waveform, audio_descriptor)\n        self.save_to_file(\n            sources,\n            audio_descriptor,\n            destination,\n            filename_format,\n            codec,\n            audio_adapter,\n            bitrate,\n            synchronous,\n        )\n\n    def save_to_file(\n        self,\n        sources: Dict,\n        audio_descriptor: AudioDescriptor,\n        destination: str,\n        filename_format: str = \"{filename}/{instrument}.{codec}\",\n        codec: Codec = Codec.WAV,\n        audio_adapter: Optional[AudioAdapter] = None,\n        bitrate: str = \"128k\",\n        synchronous: bool = True,\n    ) -> None:\n        \"\"\"\n        Export dictionary of sources to files.\n\n        Parameters:\n            sources (Dict):\n                Dictionary of sources to be exported. 
The keys are the names\n                of the instruments, and the values are `N x 2` numpy arrays\n                containing the corresponding instrument waveform, as\n                returned by the separate method.\n            audio_descriptor (AudioDescriptor):\n                Describe song to separate, used by audio adapter to\n                retrieve and load audio data, in case of file based audio\n                adapter, such descriptor would be a file path.\n            destination (str):\n                Target directory to write output to.\n            filename_format (str):\n                (Optional) Filename format.\n            codec (Codec):\n                (Optional) Export codec.\n            audio_adapter (Optional[AudioAdapter]):\n                (Optional) Audio adapter to use for I/O.\n            bitrate (str):\n                (Optional) Export bitrate.\n            synchronous (bool):\n                (Optional) `True` if export should be synchronous.\n        \"\"\"\n        if audio_adapter is None:\n            audio_adapter = AudioAdapter.default()\n        foldername = basename(dirname(audio_descriptor))\n        filename = splitext(basename(audio_descriptor))[0]\n        generated = []\n        for instrument, data in sources.items():\n            path = join(\n                destination,\n                filename_format.format(\n                    filename=filename,\n                    instrument=instrument,\n                    foldername=foldername,\n                    codec=codec,\n                ),\n            )\n            directory = os.path.dirname(path)\n            if not os.path.exists(directory):\n                os.makedirs(directory)\n            if path in generated:\n                raise SpleeterError(\n                    (\n                        f\"Separated source path conflict: {path}, \"\n                        \"please check your filename format\"\n                    )\n                )\n            
generated.append(path)\n            if self._pool:\n                task = self._pool.apply_async(\n                    audio_adapter.save, (path, data, self._sample_rate, codec, bitrate)\n                )\n                self._tasks.append(task)\n            else:\n                audio_adapter.save(path, data, self._sample_rate, codec, bitrate)\n        if synchronous and self._pool:\n            self.join()\n"
  },
  {
    "path": "spleeter/types.py",
    "content": "#!/usr/bin/env python\n# coding: utf8\n\n\"\"\" Custom types definition. \"\"\"\n\nfrom typing import Any, Tuple\n\n# pyright: reportMissingImports=false\n# pylint: disable=import-error\nimport numpy as np\n\n# pylint: enable=import-error\n\nAudioDescriptor = Any\nSignal = Tuple[np.ndarray, float]\n"
  },
  {
    "path": "spleeter/utils/__init__.py",
    "content": "#!/usr/bin/env python\n# coding: utf8\n\n\"\"\" This package provides utility function and classes. \"\"\"\n\n__email__ = \"spleeter@deezer.com\"\n__author__ = \"Deezer Research\"\n__license__ = \"MIT License\"\n"
  },
  {
    "path": "spleeter/utils/configuration.py",
    "content": "#!/usr/bin/env python\n# coding: utf8\n\n\"\"\" Module that provides configuration loading function. \"\"\"\n\nimport importlib.resources as loader\nimport json\nfrom os.path import exists\nfrom typing import Dict\n\nfrom .. import SpleeterError, resources\n\n__email__ = \"spleeter@deezer.com\"\n__author__ = \"Deezer Research\"\n__license__ = \"MIT License\"\n\n_EMBEDDED_CONFIGURATION_PREFIX: str = \"spleeter:\"\n\n\ndef load_configuration(descriptor: str) -> Dict:\n    \"\"\"\n    Load configuration from the given descriptor.\n    Could be either a `spleeter:` prefixed embedded configuration name\n    or a file system path to read configuration from.\n\n    Parameters:\n        descriptor (str):\n            Configuration descriptor to use for lookup.\n\n    Returns:\n        Dict:\n            Loaded description as dict.\n\n    Raises:\n        ValueError:\n            If required embedded configuration does not exists.\n        SpleeterError:\n            If required configuration file does not exists.\n    \"\"\"\n    # Embedded configuration reading.\n    if descriptor.startswith(_EMBEDDED_CONFIGURATION_PREFIX):\n        name = descriptor[len(_EMBEDDED_CONFIGURATION_PREFIX) :]\n        if not loader.is_resource(resources, f\"{name}.json\"):\n            raise SpleeterError(f\"No embedded configuration {name} found\")\n        with loader.open_text(resources, f\"{name}.json\") as stream:\n            return json.load(stream)\n    # Standard file reading.\n    if not exists(descriptor):\n        raise SpleeterError(f\"Configuration file {descriptor} not found\")\n    with open(descriptor, \"r\") as stream:\n        return json.load(stream)\n"
  },
  {
    "path": "spleeter/utils/logging.py",
    "content": "#!/usr/bin/env python\n# coding: utf8\n\n\"\"\"Centralized logging facilities for Spleeter.\"\"\"\n\nimport logging\nimport warnings\nfrom os import environ\n\n# pyright: reportMissingImports=false\n# pylint: disable=import-error\nfrom typer import echo\n\n# pylint: enable=import-error\n\n__email__ = \"spleeter@deezer.com\"\n__author__ = \"Deezer Research\"\n__license__ = \"MIT License\"\n\nenviron[\"TF_CPP_MIN_LOG_LEVEL\"] = \"3\"\n\n\nclass TyperLoggerHandler(logging.Handler):\n    \"\"\"A custom logger handler that use Typer echo.\"\"\"\n\n    def emit(self, record: logging.LogRecord) -> None:\n        echo(self.format(record))\n\n\nformatter = logging.Formatter(\"%(levelname)s:%(name)s:%(message)s\")\nhandler = TyperLoggerHandler()\nhandler.setFormatter(formatter)\nlogger: logging.Logger = logging.getLogger(\"spleeter\")\nlogger.addHandler(handler)\nlogger.setLevel(logging.INFO)\n\n\ndef configure_logger(verbose: bool) -> None:\n    \"\"\"\n    Configure application logger.\n\n    Parameters:\n        verbose (bool):\n            `True` to use verbose logger, `False` otherwise.\n    \"\"\"\n    from tensorflow import get_logger  # type: ignore\n    from tensorflow.compat.v1 import logging as tf_logging  # type: ignore\n\n    tf_logger = get_logger()\n    tf_logger.handlers = [handler]\n    if verbose:\n        tf_logging.set_verbosity(tf_logging.INFO)\n        logger.setLevel(logging.DEBUG)\n    else:\n        warnings.filterwarnings(\"ignore\")\n        tf_logging.set_verbosity(tf_logging.ERROR)\n"
  },
  {
    "path": "spleeter/utils/tensor.py",
    "content": "#!/usr/bin/env python\n# coding: utf8\n\n\"\"\" Utility function for tensorflow. \"\"\"\n\nfrom typing import Any, Callable, Dict\n\nimport pandas as pd  # type: ignore\n\n# pyright: reportMissingImports=false\n# pylint: disable=import-error\nimport tensorflow as tf  # type: ignore\n\n# pylint: enable=import-error\n\n__email__ = \"spleeter@deezer.com\"\n__author__ = \"Deezer Research\"\n__license__ = \"MIT License\"\n\n\ndef sync_apply(\n    tensor_dict: Dict[str, tf.Tensor], func: Callable, concat_axis: int = 1\n) -> Dict[str, tf.Tensor]:\n    \"\"\"\n    Return a function that applies synchronously the provided func on the\n    provided dictionnary of tensor. This means that func is applied to the\n    concatenation of the tensors in tensor_dict. This is useful for\n    performing random operation that needs the same drawn value on multiple\n    tensor, such as a random time-crop on both input data and label (the\n    same crop should be applied to both input data and label, so random\n    crop cannot be applied separately on each of them).\n\n    Note:\n        All tensor are assumed to be the same shape.\n\n    Parameters:\n        tensor_dict (Dict[str, tf.Tensor]):\n            A dictionary of tensor.\n        func (Callable):\n            Function to be applied to the concatenation of the tensors in\n            `tensor_dict`.\n        concat_axis (int):\n            (Optional) The axis on which to perform the concatenation.\n\n    Returns:\n        Dict[str, tf.Tensor]:\n            Processed tensors dictionary with the same name (keys) as input\n            tensor_dict.\n    \"\"\"\n    if concat_axis not in {0, 1}:\n        raise NotImplementedError(\n            \"Function only implemented for concat_axis equal to 0 or 1\"\n        )\n    tensor_list = list(tensor_dict.values())\n    concat_tensor = tf.concat(tensor_list, concat_axis)\n    processed_concat_tensor = func(concat_tensor)\n    tensor_shape = 
tf.shape(list(tensor_dict.values())[0])\n    D = tensor_shape[concat_axis]\n    if concat_axis == 0:\n        return {\n            name: processed_concat_tensor[index * D : (index + 1) * D, :, :]\n            for index, name in enumerate(tensor_dict)\n        }\n    return {\n        name: processed_concat_tensor[:, index * D : (index + 1) * D, :]\n        for index, name in enumerate(tensor_dict)\n    }\n\n\ndef from_float32_to_uint8(\n    tensor: tf.Tensor,\n    tensor_key: str = \"tensor\",\n    min_key: str = \"min\",\n    max_key: str = \"max\",\n) -> tf.Tensor:\n    tensor_min = tf.reduce_min(tensor)\n    tensor_max = tf.reduce_max(tensor)\n    return {\n        tensor_key: tf.cast(\n            (tensor - tensor_min) / (tensor_max - tensor_min + 1e-16) * 255.9999,\n            dtype=tf.uint8,\n        ),\n        min_key: tensor_min,\n        max_key: tensor_max,\n    }\n\n\ndef from_uint8_to_float32(\n    tensor: tf.Tensor, tensor_min: tf.Tensor, tensor_max: tf.Tensor\n) -> tf.Tensor:\n    return (\n        tf.cast(tensor, tf.float32) * (tensor_max - tensor_min) / 255.9999 + tensor_min\n    )\n\n\ndef pad_and_partition(tensor: tf.Tensor, segment_len: int) -> tf.Tensor:\n    \"\"\"\n    Pad and partition a tensor into segment of len `segment_len`\n    along the first dimension. 
The tensor is padded with 0 in order\n    to ensure that the first dimension is a multiple of `segment_len`.\n\n    Examples:\n    ```python\n    >>> tensor = tf.constant([[1, 2], [3, 4], [5, 6]])\n    >>> segment_len = 2\n    >>> pad_and_partition(tensor, segment_len).numpy().tolist()\n    [[[1, 2], [3, 4]], [[5, 6], [0, 0]]]\n    ```\n\n    Parameters:\n        tensor (tf.Tensor):\n            Tensor of known fixed rank.\n        segment_len (int):\n            Segment length.\n\n    Returns:\n        tf.Tensor:\n            Padded and partitioned tensor.\n    \"\"\"\n    tensor_size = tf.math.floormod(tf.shape(tensor)[0], segment_len)\n    pad_size = tf.math.floormod(segment_len - tensor_size, segment_len)\n    padded = tf.pad(tensor, [[0, pad_size]] + [[0, 0]] * (len(tensor.shape) - 1))\n    split = (tf.shape(padded)[0] + segment_len - 1) // segment_len\n    return tf.reshape(\n        padded, tf.concat([[split, segment_len], tf.shape(padded)[1:]], axis=0)\n    )\n\n\ndef pad_and_reshape(instr_spec, frame_length, F) -> Any:\n    spec_shape = tf.shape(instr_spec)\n    extension_row = tf.zeros((spec_shape[0], spec_shape[1], 1, spec_shape[-1]))\n    n_extra_row = (frame_length) // 2 + 1 - F\n    extension = tf.tile(extension_row, [1, 1, n_extra_row, 1])\n    extended_spec = tf.concat([instr_spec, extension], axis=2)\n    old_shape = tf.shape(extended_spec)\n    new_shape = tf.concat([[old_shape[0] * old_shape[1]], old_shape[2:]], axis=0)\n    processed_instr_spec = tf.reshape(extended_spec, new_shape)\n    return processed_instr_spec\n\n\ndef dataset_from_csv(csv_path: str, **kwargs) -> Any:\n    \"\"\"\n    Load dataset from a CSV file using Pandas.\n    kwargs if any are forwarded to the `pandas.read_csv` function.\n\n    Parameters:\n        csv_path (str):\n            Path of the CSV file to load dataset from.\n\n    Returns:\n        Any:\n            Loaded dataset.\n    \"\"\"\n    df = pd.read_csv(csv_path, **kwargs)\n    dataset = tf.data.Dataset.from_tensor_slices({key: df[key].values 
for key in df})\n    return dataset\n\n\ndef check_tensor_shape(tensor_tf: tf.Tensor, target_shape: Any) -> bool:\n    \"\"\"\n    Return a Tensorflow boolean graph that indicates whether\n    sample[features_key] has the specified target shape.\n    Only check not None entries of target_shape.\n\n    Parameters:\n        tensor_tf (tensorflow.Tensor):\n            Tensor to check shape for.\n        target_shape (Any):\n            Target shape to compare tensor to.\n\n    Returns:\n        bool:\n            `True` if shape is valid, `False` otherwise (as TF boolean).\n    \"\"\"\n    result = tf.constant(True)\n    for i, target_length in enumerate(target_shape):\n        if target_length:\n            result = tf.logical_and(\n                result, tf.equal(tf.constant(target_length), tf.shape(tensor_tf)[i])\n            )\n    return result\n\n\ndef set_tensor_shape(tensor: tf.Tensor, tensor_shape: Any) -> tf.Tensor:\n    \"\"\"\n    Set shape for a tensor (not in place, as opposed to tf.set_shape).\n\n    Parameters:\n        tensor (tensorflow.Tensor):\n            Tensor to reshape.\n        tensor_shape (Any):\n            Shape to apply to the tensor.\n\n    Returns:\n        tensorflow.Tensor:\n            A reshaped tensor.\n    \"\"\"\n    tensor.set_shape(tensor_shape)\n    return tensor\n"
  },
  {
    "path": "spleeter.ipynb",
    "content": "{\n \"cells\": [\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {\n    \"colab_type\": \"text\",\n    \"id\": \"K6mcSc0mmp3i\"\n   },\n   \"source\": [\n    \"# Install spleeter\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {\n    \"colab\": {\n     \"base_uri\": \"https://localhost:8080/\",\n     \"height\": 109\n    },\n    \"colab_type\": \"code\",\n    \"id\": \"f8Brdfh6mzEz\",\n    \"outputId\": \"c63dae8e-1d33-48f2-879f-dd15393a5034\"\n   },\n   \"outputs\": [],\n   \"source\": [\n    \"!apt install ffmpeg\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {\n    \"colab\": {\n     \"base_uri\": \"https://localhost:8080/\",\n     \"height\": 1000\n    },\n    \"colab_type\": \"code\",\n    \"id\": \"V_6Ram1lmc1F\",\n    \"outputId\": \"26a8df7b-6b6c-41e7-d874-acea0247d181\"\n   },\n   \"outputs\": [],\n   \"source\": [\n    \"pip install spleeter\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {\n    \"colab\": {},\n    \"colab_type\": \"code\",\n    \"id\": \"W0LktyMypXqE\"\n   },\n   \"outputs\": [],\n   \"source\": [\n    \"from IPython.display import Audio\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {\n    \"colab_type\": \"text\",\n    \"id\": \"afbcUSken16L\"\n   },\n   \"source\": [\n    \"# Separate from command line\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {\n    \"colab\": {\n     \"base_uri\": \"https://localhost:8080/\",\n     \"height\": 311\n    },\n    \"colab_type\": \"code\",\n    \"id\": \"O1kQaoJSoAD0\",\n    \"outputId\": \"cd1868b4-6992-47c3-8a2b-920e6f288614\"\n   },\n   \"outputs\": [],\n   \"source\": [\n    \"!wget https://github.com/deezer/spleeter/raw/master/audio_example.mp3\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {\n    
\"colab\": {\n     \"base_uri\": \"https://localhost:8080/\",\n     \"height\": 60\n    },\n    \"colab_type\": \"code\",\n    \"id\": \"ibG6uF55p4lH\",\n    \"outputId\": \"f2785922-0ee1-4769-807a-6ee69313993c\"\n   },\n   \"outputs\": [],\n   \"source\": [\n    \"Audio('audio_example.mp3')\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {\n    \"colab\": {\n     \"base_uri\": \"https://localhost:8080/\",\n     \"height\": 660\n    },\n    \"colab_type\": \"code\",\n    \"id\": \"kOAqBcPhn6IU\",\n    \"outputId\": \"23e14ad5-209d-4ed6-b909-7c0cd966bd0c\"\n   },\n   \"outputs\": [],\n   \"source\": [\n    \"!spleeter separate -h\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {\n    \"colab\": {\n     \"base_uri\": \"https://localhost:8080/\",\n     \"height\": 533\n    },\n    \"colab_type\": \"code\",\n    \"id\": \"dGL-k5xxoKbu\",\n    \"outputId\": \"dd8d6a7f-515c-47f0-8388-39e179ef652a\"\n   },\n   \"outputs\": [],\n   \"source\": [\n    \"!spleeter separate -o output/ audio_example.mp3\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {\n    \"colab\": {\n     \"base_uri\": \"https://localhost:8080/\",\n     \"height\": 63\n    },\n    \"colab_type\": \"code\",\n    \"id\": \"IDuPWcAMoZP_\",\n    \"outputId\": \"3f9a05fd-afab-41c7-d47c-433fc614283b\"\n   },\n   \"outputs\": [],\n   \"source\": [\n    \"!ls output/audio_example\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {\n    \"colab\": {\n     \"base_uri\": \"https://localhost:8080/\",\n     \"height\": 60\n    },\n    \"colab_type\": \"code\",\n    \"id\": \"e7CHpyiloxrk\",\n    \"outputId\": \"d1ff17ac-8cef-4b9d-913a-01c2688ffef1\"\n   },\n   \"outputs\": [],\n   \"source\": [\n    \"Audio('output/audio_example/vocals.wav')\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   
\"metadata\": {\n    \"colab\": {\n     \"base_uri\": \"https://localhost:8080/\",\n     \"height\": 60\n    },\n    \"colab_type\": \"code\",\n    \"id\": \"ibXd-WCTpT0w\",\n    \"outputId\": \"6716708d-1cdb-4be5-da22-593075de78ca\"\n   },\n   \"outputs\": [],\n   \"source\": [\n    \"Audio('output/audio_example/accompaniment.wav')\"\n   ]\n  }\n ],\n \"metadata\": {\n  \"colab\": {\n   \"name\": \"spleeter.ipynb\",\n   \"provenance\": []\n  },\n  \"kernelspec\": {\n   \"display_name\": \"Python 3\",\n   \"name\": \"python3\"\n  }\n },\n \"nbformat\": 4,\n \"nbformat_minor\": 0\n}\n"
  },
  {
    "path": "tests/__init__.py",
    "content": "#!/usr/bin/env python\n# coding: utf8\n\n\"\"\" Unit testing package. \"\"\"\n\n__email__ = \"spleeter@deezer.com\"\n__author__ = \"Deezer Research\"\n__license__ = \"MIT License\"\n"
  },
  {
    "path": "tests/test_command.py",
    "content": "#!/usr/bin/env python\n# coding: utf8\n\n\"\"\" Unit testing for Separator class. \"\"\"\n\n__email__ = \"research@deezer.com\"\n__author__ = \"Deezer Research\"\n__license__ = \"MIT License\"\n\nfrom typer.testing import CliRunner\n\nfrom spleeter.__main__ import spleeter\n\n\ndef test_version():\n\n    runner = CliRunner()\n\n    # execute spleeter version command\n    _ = runner.invoke(\n        spleeter,\n        [\n            \"--version\",\n        ],\n    )\n"
  },
  {
    "path": "tests/test_eval.py",
    "content": "#!/usr/bin/env python\n# coding: utf8\n\n\"\"\" Unit testing for Separator class. \"\"\"\n\n__email__ = \"spleeter@deezer.com\"\n__author__ = \"Deezer Research\"\n__license__ = \"MIT License\"\n\nfrom os import makedirs\nfrom os.path import join\nfrom tempfile import TemporaryDirectory\n\nimport numpy as np\n\nfrom spleeter.__main__ import evaluate\nfrom spleeter.audio.adapter import AudioAdapter\n\nres_4stems = {\n    \"vocals\": {\"SDR\": 3.25e-05, \"SAR\": -11.153575, \"SIR\": -1.3849, \"ISR\": 2.75e-05},\n    \"drums\": {\"SDR\": -0.079505, \"SAR\": -15.7073575, \"SIR\": -4.972755, \"ISR\": 0.0013575},\n    \"bass\": {\"SDR\": 2.5e-06, \"SAR\": -10.3520575, \"SIR\": -4.272325, \"ISR\": 2.5e-06},\n    \"other\": {\"SDR\": -1.359175, \"SAR\": -14.7076775, \"SIR\": -4.761505, \"ISR\": -0.01528},\n}\n\n\ndef generate_fake_eval_dataset(path):\n    \"\"\"\n    Generate fake evaluation dataset\n    \"\"\"\n    aa = AudioAdapter.default()\n    n_songs = 2\n    fs = 44100\n    duration = 3\n    n_channels = 2\n    rng = np.random.RandomState(seed=0)\n    for song in range(n_songs):\n        song_path = join(path, \"test\", f\"song{song}\")\n        makedirs(song_path, exist_ok=True)\n        for instr in [\"mixture\", \"vocals\", \"bass\", \"drums\", \"other\"]:\n            filename = join(song_path, f\"{instr}.wav\")\n            data = rng.rand(duration * fs, n_channels) - 0.5\n            aa.save(filename, data, fs)\n\n\ndef test_evaluate():\n    with TemporaryDirectory() as dataset:\n        with TemporaryDirectory() as evaluation:\n            generate_fake_eval_dataset(dataset)\n            metrics = evaluate(\n                adapter=\"spleeter.audio.ffmpeg.FFMPEGProcessAudioAdapter\",\n                output_path=evaluation,\n                params_filename=\"spleeter:4stems\",\n                mus_dir=dataset,\n                mwf=False,\n                verbose=False,\n            )\n            for instrument, metric in metrics.items():\n     
           for m, value in metric.items():\n                    assert np.allclose(\n                        np.median(value), res_4stems[instrument][m], atol=1e-3\n                    )\n"
  },
  {
    "path": "tests/test_ffmpeg_adapter.py",
    "content": "#!/usr/bin/env python\n# coding: utf8\n\n\"\"\" Unit testing for audio adapter. \"\"\"\n\n__email__ = \"spleeter@deezer.com\"\n__author__ = \"Deezer Research\"\n__license__ = \"MIT License\"\n\nfrom os.path import join\nfrom tempfile import TemporaryDirectory\n\nimport ffmpeg  # type: ignore\nimport numpy as np\n\n# pyright: reportMissingImports=false\n# pylint: disable=import-error\nfrom pytest import fixture, raises\n\nfrom spleeter import SpleeterError\nfrom spleeter.audio.adapter import AudioAdapter\nfrom spleeter.audio.ffmpeg import FFMPEGProcessAudioAdapter\n\n# pylint: enable=import-error\n\nTEST_AUDIO_DESCRIPTOR = \"audio_example.mp3\"\nTEST_OFFSET = 0\nTEST_DURATION = 600.0\nTEST_SAMPLE_RATE = 44100\n\n\n@fixture(scope=\"session\")\ndef adapter():\n    \"\"\"Target test audio adapter fixture.\"\"\"\n    return AudioAdapter.default()\n\n\n@fixture(scope=\"session\")\ndef audio_data(adapter):\n    \"\"\"Audio data fixture based on sample loading from adapter.\"\"\"\n    return adapter.load(\n        TEST_AUDIO_DESCRIPTOR, TEST_OFFSET, TEST_DURATION, TEST_SAMPLE_RATE\n    )\n\n\ndef test_default_adapter(adapter):\n    \"\"\"Test adapter as default adapter.\"\"\"\n    assert isinstance(adapter, FFMPEGProcessAudioAdapter)\n    assert adapter is AudioAdapter._DEFAULT\n\n\ndef test_load(audio_data):\n    \"\"\"Test audio loading.\"\"\"\n    waveform, sample_rate = audio_data\n    assert sample_rate == TEST_SAMPLE_RATE\n    assert waveform is not None\n    assert waveform.dtype == np.dtype(\"float32\")\n    assert len(waveform.shape) == 2\n    assert waveform.shape[0] == 479832\n    assert waveform.shape[1] == 2\n\n\ndef test_load_error(adapter):\n    \"\"\"Test load ffprobe exception\"\"\"\n    with raises(SpleeterError):\n        adapter.load(\"Paris City Jazz\", TEST_OFFSET, TEST_DURATION, TEST_SAMPLE_RATE)\n\n\ndef test_save(adapter, audio_data):\n    \"\"\"Test audio saving.\"\"\"\n    with TemporaryDirectory() as directory:\n        path = 
join(directory, \"ffmpeg-save.mp3\")\n        adapter.save(path, audio_data[0], audio_data[1])\n        probe = ffmpeg.probe(TEST_AUDIO_DESCRIPTOR)\n        assert len(probe[\"streams\"]) == 1\n        stream = probe[\"streams\"][0]\n        assert stream[\"codec_type\"] == \"audio\"\n        assert stream[\"channels\"] == 2\n        assert stream[\"duration\"] == \"10.919184\"\n"
  },
  {
    "path": "tests/test_github_model_provider.py",
    "content": "#!/usr/bin/env python\n# coding: utf8\n\n\"\"\" TO DOCUMENT \"\"\"\n\nfrom pytest import raises\n\nfrom spleeter.model.provider import ModelProvider\n\n\ndef test_checksum():\n    \"\"\"Test archive checksum index retrieval.\"\"\"\n    provider = ModelProvider.default()\n    assert (\n        provider.checksum(\"2stems\")\n        == \"f3a90b39dd2874269e8b05a48a86745df897b848c61f3958efc80a39152bd692\"\n    )\n    assert (\n        provider.checksum(\"4stems\")\n        == \"3adb4a50ad4eb18c7c4d65fcf4cf2367a07d48408a5eb7d03cd20067429dfaa8\"\n    )\n    assert (\n        provider.checksum(\"5stems\")\n        == \"25a1e87eb5f75cc72a4d2d5467a0a50ac75f05611f877c278793742513cc7218\"\n    )\n    with raises(ValueError):\n        provider.checksum(\"laisse moi stems stems stems\")\n"
  },
  {
    "path": "tests/test_separator.py",
    "content": "#!/usr/bin/env python\n# coding: utf8\n\n\"\"\" Unit testing for Separator class. \"\"\"\n\n__email__ = \"spleeter@deezer.com\"\n__author__ = \"Deezer Research\"\n__license__ = \"MIT License\"\n\nimport itertools\nfrom os.path import basename, exists, join, splitext\nfrom tempfile import TemporaryDirectory\n\nimport numpy as np\nimport pytest\nimport tensorflow as tf  # type: ignore\n\nfrom spleeter import SpleeterError\nfrom spleeter.audio.adapter import AudioAdapter\nfrom spleeter.separator import Separator\n\nTEST_AUDIO_DESCRIPTORS = [\"audio_example.mp3\", \"audio_example_mono.mp3\"]\nMODELS = [\"spleeter:2stems\", \"spleeter:4stems\", \"spleeter:5stems\"]\n\nMODEL_TO_INST = {\n    \"spleeter:2stems\": (\"vocals\", \"accompaniment\"),\n    \"spleeter:4stems\": (\"vocals\", \"drums\", \"bass\", \"other\"),\n    \"spleeter:5stems\": (\"vocals\", \"drums\", \"bass\", \"piano\", \"other\"),\n}\n\n\nMODELS_AND_TEST_FILES = list(itertools.product(TEST_AUDIO_DESCRIPTORS, MODELS))\nTEST_CONFIGURATIONS = list(itertools.product(TEST_AUDIO_DESCRIPTORS, MODELS))\n\n\nprint(\"RUNNING TESTS WITH TF VERSION {}\".format(tf.__version__))\n\n\n@pytest.mark.parametrize(\"test_file, configuration\", TEST_CONFIGURATIONS)\ndef test_separate(test_file, configuration):\n    \"\"\"Test separation from raw data.\"\"\"\n    instruments = MODEL_TO_INST[configuration]\n    adapter = AudioAdapter.default()\n    waveform, _ = adapter.load(test_file)\n    separator = Separator(configuration, multiprocess=False)\n    prediction = separator.separate(waveform, test_file)\n    assert len(prediction) == len(instruments)\n    for instrument in instruments:\n        assert instrument in prediction\n    for instrument in instruments:\n        track = prediction[instrument]\n        assert waveform.shape[:-1] == track.shape[:-1]\n        assert not np.allclose(waveform, track)\n        for compared in instruments:\n            if instrument != compared:\n                assert not 
np.allclose(track, prediction[compared])\n\n\n@pytest.mark.parametrize(\"test_file, configuration\", TEST_CONFIGURATIONS)\ndef test_separate_to_file(test_file, configuration):\n    \"\"\"Test file based separation.\"\"\"\n    instruments = MODEL_TO_INST[configuration]\n    separator = Separator(configuration, multiprocess=False)\n    name = splitext(basename(test_file))[0]\n    with TemporaryDirectory() as directory:\n        separator.separate_to_file(test_file, directory)\n        for instrument in instruments:\n            assert exists(join(directory, \"{}/{}.wav\".format(name, instrument)))\n\n\n@pytest.mark.parametrize(\"test_file, configuration\", TEST_CONFIGURATIONS)\ndef test_filename_format(test_file, configuration):\n    \"\"\"Test custom filename format.\"\"\"\n    instruments = MODEL_TO_INST[configuration]\n    separator = Separator(configuration, multiprocess=False)\n    name = splitext(basename(test_file))[0]\n    with TemporaryDirectory() as directory:\n        separator.separate_to_file(\n            test_file,\n            directory,\n            filename_format=\"export/{filename}/{instrument}.{codec}\",\n        )\n        for instrument in instruments:\n            assert exists(join(directory, \"export/{}/{}.wav\".format(name, instrument)))\n\n\n@pytest.mark.parametrize(\"test_file, configuration\", MODELS_AND_TEST_FILES)\ndef test_filename_conflict(test_file, configuration):\n    \"\"\"Test error handling with static pattern.\"\"\"\n    separator = Separator(configuration, multiprocess=False)\n    with TemporaryDirectory() as directory:\n        with pytest.raises(SpleeterError):\n            separator.separate_to_file(\n                test_file, directory, filename_format=\"I wanna be your lover\"\n            )\n"
  },
  {
    "path": "tests/test_train.py",
    "content": "#!/usr/bin/env python\n# coding: utf8\n\n\"\"\" Unit testing for Separator class. \"\"\"\n\n__email__ = \"research@deezer.com\"\n__author__ = \"Deezer Research\"\n__license__ = \"MIT License\"\n\nimport json\nimport os\nfrom os import makedirs\nfrom os.path import join\nfrom tempfile import TemporaryDirectory\n\nimport numpy as np\nimport pandas as pd  # type: ignore\nfrom typer.testing import CliRunner\n\nfrom spleeter.__main__ import spleeter\nfrom spleeter.audio.adapter import AudioAdapter\n\nTRAIN_CONFIG = {\n    \"mix_name\": \"mix\",\n    \"instrument_list\": [\"vocals\", \"other\"],\n    \"sample_rate\": 44100,\n    \"frame_length\": 4096,\n    \"frame_step\": 1024,\n    \"T\": 128,\n    \"F\": 128,\n    \"n_channels\": 2,\n    \"chunk_duration\": 4,\n    \"n_chunks_per_song\": 1,\n    \"separation_exponent\": 2,\n    \"mask_extension\": \"zeros\",\n    \"learning_rate\": 1e-4,\n    \"batch_size\": 2,\n    \"train_max_steps\": 10,\n    \"throttle_secs\": 20,\n    \"save_checkpoints_steps\": 100,\n    \"save_summary_steps\": 5,\n    \"random_seed\": 0,\n    \"model\": {\n        \"type\": \"unet.unet\",\n        \"params\": {\"conv_activation\": \"ELU\", \"deconv_activation\": \"ELU\"},\n    },\n}\n\n\ndef generate_fake_training_dataset(\n    path,\n    instrument_list=[\"vocals\", \"other\"],\n    n_channels=2,\n    n_songs=2,\n    fs=44100,\n    duration=6,\n):\n    \"\"\"\n    generates a fake training dataset in path:\n    - generates audio files\n    - generates a csv file describing the dataset\n    \"\"\"\n    aa = AudioAdapter.default()\n    rng = np.random.RandomState(seed=0)\n    dataset_df = pd.DataFrame(\n        columns=[\"mix_path\"]\n        + [f\"{instr}_path\" for instr in instrument_list]\n        + [\"duration\"]\n    )\n    for song in range(n_songs):\n        song_path = join(path, \"train\", f\"song{song}\")\n        makedirs(song_path, exist_ok=True)\n        dataset_df.loc[song, \"duration\"] = duration\n        for 
instr in instrument_list + [\"mix\"]:\n            filename = join(song_path, f\"{instr}.wav\")\n            data = rng.rand(duration * fs, n_channels) - 0.5\n            aa.save(filename, data, fs)\n            dataset_df.loc[song, f\"{instr}_path\"] = join(\n                \"train\", f\"song{song}\", f\"{instr}.wav\"\n            )\n    dataset_df.to_csv(join(path, \"train\", \"train.csv\"), index=False)\n\n\ndef test_train():\n    \"\"\"Test the train command on a fake dataset, in mono and stereo.\"\"\"\n    with TemporaryDirectory() as path:\n        for n_channels in [1, 2]:\n            # generate training dataset\n            TRAIN_CONFIG[\"n_channels\"] = n_channels\n            generate_fake_training_dataset(\n                path, n_channels=n_channels, fs=TRAIN_CONFIG[\"sample_rate\"]\n            )\n            # set training command arguments\n            runner = CliRunner()\n\n            model_dir = join(path, f\"model_{n_channels}\")\n            train_dir = join(path, \"train\")\n            cache_dir = join(path, f\"cache_{n_channels}\")\n\n            TRAIN_CONFIG[\"train_csv\"] = join(train_dir, \"train.csv\")\n            TRAIN_CONFIG[\"validation_csv\"] = join(train_dir, \"train.csv\")\n            TRAIN_CONFIG[\"model_dir\"] = model_dir\n            TRAIN_CONFIG[\"training_cache\"] = join(cache_dir, \"training\")\n            TRAIN_CONFIG[\"validation_cache\"] = join(cache_dir, \"validation\")\n            # write the configuration inside the temporary directory so that\n            # the test leaves no file behind in the working directory\n            config_path = join(path, f\"config_{n_channels}.json\")\n            with open(config_path, \"w\") as stream:\n                json.dump(TRAIN_CONFIG, stream)\n\n            # execute training\n            result = runner.invoke(\n                spleeter,\n                [\"train\", \"-p\", config_path, \"-d\", path, \"--verbose\"],\n            )\n\n            # check the exit code first so that a failed run is reported as such\n            assert result.exit_code == 0\n            # assert that model checkpoint was created.\n            assert os.path.exists(join(model_dir, \"model.ckpt-10.index\"))\n            assert os.path.exists(join(model_dir, \"checkpoint\"))\n            assert os.path.exists(join(model_dir, \"model.ckpt-0.meta\"))\n"
  }
]